From 2b097e36153d050a1cd764fc9a2223e9f544888f Mon Sep 17 00:00:00 2001
From: subangstrom
Date: Fri, 15 Aug 2025 14:43:03 -0400
Subject: [PATCH 1/8] copied over, will need to update

---
 phaser/extras/calc_drift.py       | 426 ++++++++++++++++++++++++++++++
 phaser/extras/calc_tilt.py        | 164 ++++++++++++
 phaser/extras/extract_params.py   | 173 ++++++++++++
 phaser/extras/main.py             |  43 +++
 phaser/extras/prepare.py          | 139 ++++++++++
 phaser/extras/process_metadata.py | 240 +++++++++++++++++
 6 files changed, 1185 insertions(+)
 create mode 100644 phaser/extras/calc_drift.py
 create mode 100644 phaser/extras/calc_tilt.py
 create mode 100755 phaser/extras/extract_params.py
 create mode 100644 phaser/extras/main.py
 create mode 100644 phaser/extras/prepare.py
 create mode 100644 phaser/extras/process_metadata.py

diff --git a/phaser/extras/calc_drift.py b/phaser/extras/calc_drift.py
new file mode 100644
index 0000000..69e1e9b
--- /dev/null
+++ b/phaser/extras/calc_drift.py
@@ -0,0 +1,426 @@
+from pathlib import Path
+from queue import Queue, Empty
+from threading import Thread, Event
+import fnmatch
+import typing as t
+
+import click
+import numpy
+from numpy.typing import NDArray, ArrayLike
+import scipy.ndimage
+import json
+import h5py
+from matplotlib import pyplot
+from matplotlib.patches import Circle, PathPatch
+from matplotlib.path import Path as MplPath
+from matplotlib.backend_bases import MouseEvent, MouseButton, PickEvent, KeyEvent
+from rich.console import Console
+from rich.prompt import Prompt, FloatPrompt, Confirm
+#from rich.pretty import Pretty
+
+from .raw import load_4d
+from .metadata import AnyMetadata
+
+
+def load_adf(path: t.Union[str, Path]) -> t.Tuple[numpy.ndarray, t.Any]:
+    f = h5py.File(path)
+
+    images = t.cast(h5py.Group, f['Data/Image'])
+    if len(images) == 0:
+        raise ValueError("No images found in dataset.")
+    if len(images) > 1:
+        raise ValueError("Multi-image files not currently supported.")
+    image = t.cast(h5py.Group, next(iter(images.values())))
+    raw_meta: numpy.ndarray = t.cast(h5py.Dataset, image['Metadata'])[:, 0][()]
+    meta_bytes = raw_meta.tobytes()
+    meta_bytes = meta_bytes[:meta_bytes.find(b"\0")]
+    meta = json.loads(meta_bytes)
+    data = t.cast(h5py.Dataset, image['Data'])[..., 0][()]
+
+    return (data, meta)
+
+
+def normed_to_uint8(data: numpy.ndarray) -> NDArray[numpy.uint8]:
+    return numpy.floor(numpy.clip(data, 0, 1) * 255.999).astype(numpy.uint8)
+
+
+def normed_to_color(data: numpy.ndarray, color: ArrayLike) -> NDArray[numpy.uint8]:
+    return numpy.floor(numpy.clip(data, 0, 1)[..., None] * color).astype(numpy.uint8)
+
+
+def signed_angle(v1: numpy.ndarray, v2: numpy.ndarray) -> float:
+    return numpy.pi - numpy.mod(numpy.arctan2(v1[0] * v2[1] - v1[1] * v2[0], numpy.dot(v1, v2)), 2.*numpy.pi)
+
+
+def load_files(paths: t.Iterable[t.Union[str, Path]], inner: float, outer: float) -> t.Iterable[t.Tuple[Path, AnyMetadata, NDArray[numpy.float32]]]:
+    queue: Queue[t.Tuple[Path, AnyMetadata, NDArray[numpy.float32]]] = Queue(1)
+    finished = Event()
+
+    # producer thread
+    # should hold one in queue, one waiting for queue, one processing
+    def producer():
+        try:
+            for path in paths:
+                # eagerly load and put on queue
+                path = Path(path)
+                meta = AnyMetadata.parse_file(path)
+
+                exp_path = meta.path or Path('.')
+                raw_path = exp_path / (meta.raw_filename or "scan_x128_y128.raw")
+                if not raw_path.exists():
+                    raise ValueError(f"Can't find raw data at path '{raw_path}'")
+                raw = load_4d(raw_path)
+
+                kx = numpy.arange(raw.shape[-2], dtype=numpy.float32) - raw.shape[-2] / 2.
+                ky = numpy.arange(raw.shape[-1], dtype=numpy.float32) - raw.shape[-1] / 2.
+                kyy, kxx = numpy.meshgrid(kx, ky, indexing='ij')
+                k2 = kyy**2 + kxx**2
+                virtual_aperture = numpy.zeros(raw.shape[-2:], dtype=bool)
+                virtual_aperture[(k2 >= inner**2) & (k2 <= outer**2)] = 1
+                virtual_img = numpy.sum(raw * virtual_aperture, axis=(-1, -2))
+
+                queue.put((path, meta, virtual_img))
+            # finished all files
+            finished.set()
+        except BaseException:
+            import traceback
+            traceback.print_exc()
+
+    thread = Thread(target=producer, name='loader', daemon=True)
+    thread.start()
+
+    # periodically check that thread is still running,
+    # to prevent deadlock
+    while thread.is_alive():
+        try:
+            val = queue.get(timeout=1.)
+        except Empty:
+            continue
+        yield val
+
+    thread.join()
+
+    # drain queue once thread is finished
+    while True:
+        try:
+            val = queue.get(False)
+        except Empty:
+            break
+        yield val
+
+    if not finished.is_set():
+        raise ValueError("Error in file loading")
+
+    print("Finished processing files!")
+
+
+def calibrated_meta_path(meta_path: Path):
+    meta_path_name = meta_path.stem
+    if meta_path_name.endswith('_orig'):
+        meta_path_name = meta_path_name[:-5]
+    new_meta_path = meta_path.with_stem(meta_path_name + "_calib")
+    return new_meta_path
+
+
+@click.command()
+@click.argument('path', type=click.Path(exists=True, dir_okay=True, file_okay=True))
+@click.option('--include', type=str, multiple=True,
+              help="Glob of filenames to include. If not specified, include all '.json' files")
+@click.option('--exclude', type=str, multiple=True,
+              help="Glob of filenames to exclude.")
+@click.option('--skip-existing/--no-skip-existing', default=True,
+              help="Whether to skip datasets which have already been calibrated. Defaults to true.")
+def calc_drift(path: t.Union[str, Path], include: t.Sequence[str] = (), exclude: t.Sequence[str] = (), skip_existing: bool = True):
+    """
+    Calculate the linear drift present in a ptychography dataset, or group of datasets.
+
+    PATH should be the path to a JSON metadata file, or to a directory
+    which will be searched for JSON metadata files.
+
+    Datasets can be included or excluded with the `--include` and `--exclude` options.
+    These can be repeated multiple times.
+
+    Calculated drift is stored in a new metadata file with the suffix `_calib`.
+    By default, datasets which already have this file are skipped. This behavior can be changed
+    using the `--no-skip-existing` option.
+    """
+
+    console = Console()
+    path = Path(path)
+
+    if path.is_file():
+        paths = [path]
+    else:
+        exclude = (*exclude, "*_calib.json", '._*')
+        paths = list(path.glob('**/*.json'))
+        if len(include):
+            paths = list(filter(lambda path: any(fnmatch.fnmatch(path.name, pat) for pat in include), paths))
+        paths = list(filter(lambda path: not any(fnmatch.fnmatch(path.name, pat) for pat in exclude), paths))
+
+    if skip_existing:
+        paths = list(filter(lambda path: not calibrated_meta_path(path).exists(), paths))
+
+    paths.sort()
+
+    console.print(f"{len(paths)} file(s) to process.")
+
+    params = {
+        'det': 'bf',
+        'inner': 0.0,
+        'outer': 2.0,
+        'd1': 6.,
+        'd2': 6.
+    }
+
+    params['det'] = Prompt.ask("Detector type", choices=['bf', 'adf', 'af'], default=params['det'], console=console)
+    if params['det'] in ('af', 'adf'):
+        params['inner'] = float(FloatPrompt.ask(r"Inner radius \[mrad]", default=params['inner'], console=console))
+        inner = params['inner']
+    else:
+        inner = 0.
+
+    if params['det'] in ('af', 'bf'):
+        params['outer'] = float(FloatPrompt.ask(r"Outer radius \[mrad]", default=params['outer'], console=console))
+        outer = params['outer']
+    else:
+        outer = numpy.inf
+
+    params['scale1'] = FloatPrompt.ask("Distance 1 scale", default=1., console=console)
+    params['scale2'] = FloatPrompt.ask("Distance 2 scale", default=params['scale1'], console=console)
+    #params['angle'] = FloatPrompt.ask("Signed angle from distance 1 -> distance 2 (degree, CCW is +)", default=90., console=console)
+    #params['angle'] *= numpy.pi/180.
+
+    console.print("Loading files...")
+    for (meta_path, meta, virtual_img) in load_files(paths, inner, outer):
+        console.print(f"Loaded file '{meta.path}'")
+        if not bool(Confirm.ask("Process this file?", default=True)):
+            console.print("Skipping file...")
+            continue
+
+        while True:
+            correction_matrix = calc_drift_one(console, meta, virtual_img, params)
+
+            choice = Prompt.ask("Save calibration?", choices=['y', 'n', 'abort'], default='y', console=console)
+            if choice == 'abort':
+                break
+            if choice.lower() not in ('y', 'yes'):
+                continue
+
+            new_meta_path = calibrated_meta_path(meta_path)
+            new_meta = meta.copy()
+            new_meta.scan_correction = tuple(map(tuple, correction_matrix))  # type: ignore
+
+            with open(new_meta_path, 'w') as f:
+                f.write(new_meta.json(indent=4, exclude={'path'}))
+
+            console.print(f"New metadata written to '{new_meta_path}'!")
+            break
+
+
+def calc_drift_one(console: Console, meta: AnyMetadata, virtual_img: NDArray[numpy.float32], params: t.Dict[str, t.Any]):
+    print(meta)
+    scan_step_4d = numpy.array(meta.scan_step) * 1e10  # m to angstrom
+    scan_size_4d = scan_step_4d * meta.scan_shape
+
+    console.print(f" 4D pixel size: {scan_step_4d[0]:.3f} x {scan_step_4d[1]:.3f} A", style='logging.level.info')
+    console.print(f" 4D image size: {scan_size_4d[0]:.2f} x {scan_size_4d[1]:.2f} A", style='logging.level.info')
+
+    fig, ax = pyplot.subplots(constrained_layout=True)
+    canvas = fig.canvas
+    ax.set_xlabel('x [A]')
+    ax.set_ylabel('y [A]')
+
+    ax.set_xlim(0., scan_size_4d[0])
+    ax.set_ylim(scan_size_4d[1], 0.)
+
+    img = ax.imshow(virtual_img, extent=(-0.5 * scan_step_4d[0], (virtual_img.shape[1] + 0.5) * scan_step_4d[0], (virtual_img.shape[0] + 0.5) * scan_step_4d[1], -0.5 * scan_step_4d[1]))
+    img.set_picker(True)
+    img.set_animated(True)
+
+    path: t.Optional[PathPatch] = None
+    circles: t.List[Circle] = []
+
+    selected: t.Optional[Circle] = None
+    bg = canvas.copy_from_bbox(ax.bbox)  # type: ignore
+
+    def draw_artists():
+        ax.draw_artist(img)
+        if path is not None:
+            ax.draw_artist(path)
+        for circle in circles:
+            ax.draw_artist(circle)
+        canvas.blit(ax.bbox)  # type: ignore
+
+    def draw(event=None):
+        nonlocal path, bg
+
+        vertices = numpy.array([circle.center for circle in circles])
+        if len(vertices):
+            if path is None:
+                path = PathPatch(MplPath(vertices), fill=False, lw=3.)
+                ax.add_patch(path)
+            else:
+                path.set_path(MplPath(vertices))
+
+        bg = canvas.copy_from_bbox(ax.bbox)  # type: ignore
+        draw_artists()
+
+    """
+    def on_click(event: MouseEvent):
+        if not event.button == MouseButton.LEFT:
+            return
+    """
+
+    def on_release(event: MouseEvent):
+        nonlocal selected
+        if not event.button == MouseButton.LEFT:
+            return
+        selected = None
+
+    def on_pick(event: PickEvent):
+        nonlocal selected
+        nonlocal path
+        if event.mouseevent.button != MouseButton.LEFT:
+            return
+        if event.artist is not img:
+            return
+
+        for circle in circles:
+            if circle.contains_point((event.mouseevent.x, event.mouseevent.y)):  # type: ignore
+                selected = circle
+                return
+
+        if len(circles) < 3:
+            pos = t.cast(
+                t.Tuple[float, float],
+                list(ax.transData.inverted().transform((event.mouseevent.x, event.mouseevent.y)))
+            )
+
+            node = Circle(pos, radius=0.5, fc='white', ec='black', transform=ax.transData)
+            node.set_animated(True)
+            ax.add_patch(node)
+            circles.append(node)
+            selected = node
+            draw()
+
+    def on_move(event: MouseEvent):
+        if selected is None or event.inaxes is None or event.button != MouseButton.LEFT:
+            return
+        if event.x is None or event.y is None:
+            return
+        new_pt = ax.transData.inverted().transform((event.x, event.y))
+        selected.center = new_pt
+
+        canvas.restore_region(bg)  # type: ignore
+        draw()
+        #draw_artists()
+
+    warped_img = None
+    correction_matrix = None
+
+    def on_press(event: KeyEvent):
+        if event.key != 'enter':
+            return
+        vertices = numpy.array([circle.center for circle in circles])
+
+        if len(vertices) != 3:
+            console.print("Need 3 points to compute drift")
+            return
+
+        vecs = numpy.diff(vertices, axis=0)
+        #vecs_next = numpy.roll(vecs, 1, axis=0)
+
+        print(f"angle: {signed_angle(vecs[:, 0], vecs[:, 1]) * 180./numpy.pi}")
+
+        while True:
+            dists = []
+            for (i, scale, default) in zip(range(len(vertices)-1), (params['scale1'], params['scale2']), (params['d1'], params['d2'])):
+                while True:
+                    try:
+                        s = Prompt.ask(f"Distance {i+1} (* {scale:.3f} A)", default=default, console=console)
+                        val = float(eval(s))
+                        dists.append(val * scale)
+                    except Exception:
+                        continue
+                    break
+            console.print(f"Distances [A]: {dists[0]:.3f}, {dists[1]:.3f}")
+            if Confirm.ask("Distances correct?", default=True):
+                break
+
+        # first, we use some logic to determine which orthogonal basis the measurements are closest to.
+        horz_vec = numpy.argmax(numpy.abs(vecs[:, 0]))  # vec with maximum x component is horizontal
+        horz_sign = numpy.sign(vecs[horz_vec, 0])  # whether to flip horizontal vector
+        vert_sign = numpy.sign(vecs[horz_vec-1, 1])  # whether to flip vertical vector
+
+        # make target_vecs using the determinations above
+        target_vecs = numpy.diag(numpy.array(dists))
+        if horz_vec != 0:
+            # v1 should be vertical
+            target_vecs = target_vecs[::-1, :]
+        # flip target_vecs based on desired signs
+        target_vecs = numpy.diag([horz_sign, vert_sign]) @ target_vecs
+        console.print(f"finding transformation\n{vecs.T}\nto\n{target_vecs}")
+
+        # we try to find A which transforms `vecs` into `target_vecs`.
+        a = target_vecs @ numpy.linalg.inv(vecs.T)
+
+        nonlocal warped_img, correction_matrix
+
+        # because we only know distances, `a` is only determined up to a rotation;
+        # split it with a QR decomposition and keep only the stretch/shear part `r`.
+        q, r = numpy.linalg.qr(a)
+        r = numpy.diag(numpy.sign(numpy.diagonal(r))) @ r  # flip to ensure diagonals are positive. These are absorbed into `q`.
+ """ + console = Console() + + path = Path(path) + meta = AnyMetadata.parse_file(path) + + exp_path = meta.path or Path('.') + raw_path = exp_path / (meta.raw_filename or "scan_x128_y128.raw") + if not raw_path.exists(): + raise ValueError(f"Can't find raw data at path '{raw_path}'") + + raw = load_4d(raw_path) + + params = { + 'det': 'bf', + 'inner': 0.0, + 'outer': 2.0, + } + + params['det'] = Prompt.ask("Detector type", choices=['bf', 'adf', 'af'], default=params['det'], console=console) + if params['det'] in ('af', 'adf'): + params['inner'] = float(FloatPrompt.ask(r"Inner radius \[mrad]", default=params['inner'], console=console)) + inner = params['inner'] + else: + inner = 0. + + if params['det'] in ('af', 'bf'): + params['outer'] = float(FloatPrompt.ask(r"Outer radius \[mrad]", default=params['outer'], console=console)) + outer = params['outer'] + else: + outer = numpy.inf + + def get_pattern_norm(pattern) -> Normalize: + max_int = numpy.nanquantile(pattern, 0.99) + min_int = numpy.nanmin(pattern) + return Normalize( + 0.4*(max_int - min_int) + min_int, + max_int + ) + + console.print("Loading dataset...") + kx = numpy.arange(raw.shape[-2], dtype=numpy.float32) - raw.shape[-2] / 2. + kx *= meta.diff_step + ky = numpy.arange(raw.shape[-1], dtype=numpy.float32) - raw.shape[-1] / 2. + ky *= meta.diff_step + kyy, kxx = numpy.meshgrid(kx, ky, indexing='ij') + k2 = kyy**2 + kxx**2 + virtual_aperture = numpy.zeros(raw.shape[-2:], dtype=bool) + virtual_aperture[(k2 >= inner**2) & (k2 <= outer**2)] = 1 + virtual_img = numpy.nansum(raw * virtual_aperture, axis=(-1, -2)) + pattern = numpy.nansum(raw, axis=(0, 1)) + + scan_step = numpy.array(meta.scan_step) * 1e10 # m to A + + real_fig, real_ax = pyplot.subplots() + real_ax.set_xlabel(r"x [$\mathrm{\AA}$]") + real_ax.set_ylabel(r"y [$\mathrm{\AA}$]") + real_img = real_ax.imshow(virtual_img, extent=(-0.5 * scan_step[0], (virtual_img.shape[1] + 0.5) * scan_step[0], (virtual_img.shape[0] + 0.5) * scan_step[1], -0.5 * scan_step[1])) + + recip_fig, recip_ax = pyplot.subplots() + recip_ax.invert_xaxis() # to match sign convention + recip_ax.set_aspect('equal') + recip_ax.set_xlabel(r"$\theta_x$ [mrad]") + recip_ax.set_ylabel(r"$\theta_y$ [mrad]") + recip_img = recip_ax.pcolormesh(kx, ky, pattern, cmap='magma', norm=get_pattern_norm(pattern)) + + real_img.set_picker(True) + real_img.set_animated(True) + + rect = None + #bg = real_fig.canvas.copy_from_bbox(real_ax.bbox) # type: ignore + drag_start: t.Optional[t.Tuple[float, float]] = None + + def draw_artists(): + real_ax.draw_artist(real_img) + if rect is not None: + real_ax.draw_artist(rect) + + real_fig.canvas.blit(real_ax.bbox) + + def draw(event=None): + #nonlocal bg + #bg = real_fig.canvas.copy_from_bbox(real_ax.bbox) # type: ignore + #real_fig.canvas.restore_region(bg) # type: ignore + draw_artists() + + def on_pick(event: PickEvent): + nonlocal rect, drag_start + if event.mouseevent.button != MouseButton.LEFT: + return + if event.artist is not real_img: + return + + drag_start = tuple(real_ax.transData.inverted().transform((event.mouseevent.x, event.mouseevent.y))) + rect = Rectangle(drag_start, 1., 1., fc='none', ec='red', transform=real_ax.transData) + + def on_release(event: MouseEvent): + nonlocal drag_start + if not event.button == MouseButton.LEFT: + return + drag_start = None + + draw() + if rect is not None: + sum_in_rect(rect) + + def on_move(event: MouseEvent): + nonlocal rect + if drag_start is None or event.inaxes is None or event.button != MouseButton.LEFT: + return + if event.x is 
None or event.y is None or rect is None: + return + pt: t.Tuple[float, float] = tuple(real_ax.transData.inverted().transform((event.x, event.y))) + + rect.set_width(abs(pt[0] - drag_start[0])) + rect.set_height(abs(pt[1] - drag_start[1])) + rect.set_xy((min(pt[0], drag_start[0]), min(pt[1], drag_start[1]))) + draw() + + def sum_in_rect(rect: Rectangle): + nonlocal pattern + + bbox = rect.get_bbox() + min_x, min_y = numpy.array(bbox.min) / scan_step + max_x, max_y = numpy.array(bbox.max) / scan_step + + pattern = numpy.nansum(raw[ + math.floor(min_y):math.ceil(max_y), + math.floor(min_x):math.ceil(max_x), + ], axis=(0, 1)) + recip_img.set_array(pattern) + recip_img.set_norm(get_pattern_norm(pattern)) + recip_fig.canvas.draw() + + real_fig.canvas.mpl_connect('button_release_event', on_release) # type: ignore + real_fig.canvas.mpl_connect('motion_notify_event', on_move) # type: ignore + real_fig.canvas.mpl_connect('draw_event', draw) # type: ignore + real_fig.canvas.mpl_connect('pick_event', on_pick) # type: ignore + + pyplot.show() \ No newline at end of file diff --git a/phaser/extras/extract_params.py b/phaser/extras/extract_params.py new file mode 100755 index 0000000..996eab7 --- /dev/null +++ b/phaser/extras/extract_params.py @@ -0,0 +1,173 @@ +#!/usr/bin/env python3 + +from pathlib import Path +import typing as t + +import click +import numpy +from numpy.typing import NDArray +import h5py +from scipy.io import loadmat + + +def get_scanrot(probe_pos: NDArray[numpy.floating]): + pos = probe_pos + n = int(numpy.sqrt(pos.shape[-1])) + if not n**2 == pos.shape[-1]: + raise ValueError("Only works on square probe grid") + pos = pos.reshape(2, n, n) + #xdiff, ydiff = pos[:, -1, 0] - pos[:, 0, 0] + (ydiffs, xdiffs) = (pos[:, 1:, 0].T - pos[:, 0, 0]).T / numpy.arange(1, n) + angles = -numpy.arctan2(ydiffs, xdiffs) + scan_rot = numpy.mean(angles[len(angles)//2:]) + return scan_rot + + +def plot_convergence(probe, conv): + from matplotlib import pyplot + from matplotlib.patches import Circle + fig, ax = pyplot.subplots() + probe = numpy.fft.fftshift(probe) + ax.imshow(numpy.abs(probe)) + ax.add_artist(Circle((len(probe)//2, len(probe)//2), radius=conv, color='red', fill=False)) + pyplot.show() + + +def get_convergence(probe): + """Return the bright field radius, in px, for the given probe.""" + probe = numpy.fft.fft2(probe) + #from matplotlib import pyplot + #pyplot.imshow(numpy.abs(probe)) + #pyplot.show() + mask = numpy.abs(probe) > 1e-3 + conv = numpy.sqrt(numpy.sum(mask) / numpy.pi) + print(f"pix_convergence_angle: {conv} (convergence angle in px)") + #plot_convergence(probe, conv) + return conv + + +def from_niter(f: h5py.File): + print("Loading as output parameters") + params = t.cast(h5py.Group, f['par/p']) + pix_size = 1e3/t.cast(h5py.Dataset, params['z'])[()].flat[0] + print(f"diffraction_pix_size: {pix_size} (mrad, size of each diffraction pixel)") + + probe = t.cast(h5py.Dataset, params['probe_initial'])[()] + probe = probe['real'] + 1.j * probe['imag'] + conv = get_convergence(probe) + print(f"mrad_convergence_angle: {pix_size * conv} (convergence angle in mrad)") + + #import pdb; pdb.set_trace() + orig_probe_pos = t.cast(h5py.Dataset, params['positions_0'])[()] + final_probe_pos = t.cast(h5py.Dataset, params['positions'])[()] + #probe_pos = params['positions'][()] # final probe positions + print(f"scanrot_det_offset: {get_scanrot(orig_probe_pos) * 180./numpy.pi} (deg, scan rotation)") + print(f" final positions: {get_scanrot(final_probe_pos) * 180./numpy.pi} (deg, scan rotation)") + 
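Note: both tools above build a virtual detector the same way, by masking an
annulus in diffraction space and summing over the detector axes. A minimal
sketch of that pattern on a toy 4D-STEM array (shapes and radii are made up):

    import numpy

    rng = numpy.random.default_rng(0)
    raw = rng.poisson(1.0, (32, 32, 64, 64)).astype(numpy.float32)  # (scan_y, scan_x, ky, kx)
    inner, outer = 0.0, 10.0  # annulus radii, in detector pixels here

    ky = numpy.arange(raw.shape[-2], dtype=numpy.float32) - raw.shape[-2] / 2.
    kx = numpy.arange(raw.shape[-1], dtype=numpy.float32) - raw.shape[-1] / 2.
    kyy, kxx = numpy.meshgrid(ky, kx, indexing='ij')
    k2 = kyy**2 + kxx**2

    mask = (k2 >= inner**2) & (k2 <= outer**2)          # annular virtual aperture
    virtual_img = numpy.sum(raw * mask, axis=(-1, -2))  # one intensity per scan position
    print(virtual_img.shape)  # (32, 32)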
diff --git a/phaser/extras/extract_params.py b/phaser/extras/extract_params.py
new file mode 100755
index 0000000..996eab7
--- /dev/null
+++ b/phaser/extras/extract_params.py
@@ -0,0 +1,173 @@
+#!/usr/bin/env python3
+
+from pathlib import Path
+import typing as t
+
+import click
+import numpy
+from numpy.typing import NDArray
+import h5py
+from scipy.io import loadmat
+
+
+def get_scanrot(probe_pos: NDArray[numpy.floating]):
+    pos = probe_pos
+    n = int(numpy.sqrt(pos.shape[-1]))
+    if not n**2 == pos.shape[-1]:
+        raise ValueError("Only works on square probe grid")
+    pos = pos.reshape(2, n, n)
+    #xdiff, ydiff = pos[:, -1, 0] - pos[:, 0, 0]
+    (ydiffs, xdiffs) = (pos[:, 1:, 0].T - pos[:, 0, 0]).T / numpy.arange(1, n)
+    angles = -numpy.arctan2(ydiffs, xdiffs)
+    scan_rot = numpy.mean(angles[len(angles)//2:])
+    return scan_rot
+
+
+def plot_convergence(probe, conv):
+    from matplotlib import pyplot
+    from matplotlib.patches import Circle
+    fig, ax = pyplot.subplots()
+    probe = numpy.fft.fftshift(probe)
+    ax.imshow(numpy.abs(probe))
+    ax.add_artist(Circle((len(probe)//2, len(probe)//2), radius=conv, color='red', fill=False))
+    pyplot.show()
+
+
+def get_convergence(probe):
+    """Return the bright field radius, in px, for the given probe."""
+    probe = numpy.fft.fft2(probe)
+    #from matplotlib import pyplot
+    #pyplot.imshow(numpy.abs(probe))
+    #pyplot.show()
+    mask = numpy.abs(probe) > 1e-3
+    conv = numpy.sqrt(numpy.sum(mask) / numpy.pi)
+    print(f"pix_convergence_angle: {conv} (convergence angle in px)")
+    #plot_convergence(probe, conv)
+    return conv
+
+
+def from_niter(f: h5py.File):
+    print("Loading as output parameters")
+    params = t.cast(h5py.Group, f['par/p'])
+    pix_size = 1e3/t.cast(h5py.Dataset, params['z'])[()].flat[0]
+    print(f"diffraction_pix_size: {pix_size} (mrad, size of each diffraction pixel)")
+
+    probe = t.cast(h5py.Dataset, params['probe_initial'])[()]
+    probe = probe['real'] + 1.j * probe['imag']
+    conv = get_convergence(probe)
+    print(f"mrad_convergence_angle: {pix_size * conv} (convergence angle in mrad)")
+
+    #import pdb; pdb.set_trace()
+    orig_probe_pos = t.cast(h5py.Dataset, params['positions_0'])[()]
+    final_probe_pos = t.cast(h5py.Dataset, params['positions'])[()]
+    #probe_pos = params['positions'][()]  # final probe positions
+    print(f"scanrot_det_offset: {get_scanrot(orig_probe_pos) * 180./numpy.pi} (deg, scan rotation)")
+    print(f"   final positions: {get_scanrot(final_probe_pos) * 180./numpy.pi} (deg, scan rotation)")
+
+    plot_probe_positions(orig_probe_pos)
+    plot_probe_positions(final_probe_pos)
+
+
+def from_probe_positions(f: h5py.File):
+    print("Loading as probe positions")
+    probe_pos = t.cast(h5py.Dataset, f['probe_positions_0'])[()]
+    print(f"scanrot_det_offset: {get_scanrot(probe_pos) * 180./numpy.pi} (deg, scan rotation)")
+    plot_probe_positions(probe_pos)
+
+
+def plot_probe_positions(probe_pos):
+    from matplotlib import pyplot
+    pyplot.scatter(probe_pos[0], probe_pos[1], c=numpy.arange(probe_pos.shape[1]))
+    pyplot.show()
+
+
+def plot_probe_diff(probes: h5py.Dataset, i: int, n: int):
+    from matplotlib import pyplot
+    from matplotlib.backend_bases import KeyEvent
+
+    np = int(numpy.sqrt(probes.shape[0]))
+    if not np**2 == probes.shape[0]:
+        raise ValueError(f"Only works on square probe grid (n={probes.shape[0]})")
+
+    x = np//2
+    y = np//2
+
+    #fig, axs = pyplot.subplots(n, n, sharex=True, sharey=True)
+    fig, ax = pyplot.subplots()
+    probe = probes[x + y*np][()]
+    img = ax.imshow(probe)
+
+    def update():
+        nonlocal probe
+        print(f"\rpos: ({x}, {y})   ", end='')
+        probe = probes[x + y*np][()]
+        img.set_data(probe)
+        fig.canvas.draw_idle()
+
+    def key_press(event: KeyEvent):
+        nonlocal x, y
+        #print(f"Key released: {event.key}")
+        if event.key == 'left':
+            if x > 0:
+                x -= 1
+        elif event.key == 'right':
+            if x < np - 1:
+                x += 1
+        elif event.key == 'up':
+            if y > 0:
+                y -= 1
+        elif event.key == 'down':
+            if y < np - 1:
+                y += 1
+        else:
+            return
+        update()
+
+    fig.canvas.mpl_connect('key_press_event', key_press)
+    pyplot.show()
+    print()
+
+    #i = numpy.arange(n)
+    #(x, y) = numpy.meshgrid(i, i)
+    #idxs = (x + y*np).reshape(-1)
+    #probes = probes[idxs][()]
+    #for (i, ax) in enumerate(axs.T.flat):
+    #    ax.imshow(probes[i])
+
+
+@click.command()
+@click.argument('path', type=click.Path(exists=True, dir_okay=False))
+def extract_params(path: t.Union[str, Path]):
+    """Extract/infer reconstruction parameters from PtychoShelves files."""
+    path = Path(path)
+    print(f"Loading params from '{path}'...")
+    f = None
+    if path.suffix == '.mat':
+        try:
+            f = loadmat(path)
+        except NotImplementedError:
+            pass
+    if f is None:
+        f = h5py.File(path)
+
+    if 'par' in f:
+        assert isinstance(f, h5py.Group)
+        from_niter(f)
+    elif 'probe_positions_0' in f:
+        assert isinstance(f, h5py.Group)
+        from_probe_positions(f)
+    elif 'probe' in f:
+        probe = f['probe']
+        get_convergence(probe)
+    elif 'dp' in f:
+        print("Checking probe movement...")
+        probes = f['dp']  # shape: (nprobes, nk, nk)
+        assert isinstance(probes, h5py.Dataset)
+        plot_probe_diff(probes, 0, 5)
+        #raise ValueError("No useful info in 'data_dp.mat'")
+    else:
+        #import pdb; pdb.set_trace()
+        raise ValueError(f"Unknown file structure. Keys: {list(f.keys())}")
+
+
+if __name__ == '__main__':
+    extract_params()
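Note: a toy check of the scan-rotation estimate in `get_scanrot` above. The
grid size, angle, and axis conventions below are made up for illustration; the
displacement-averaging step is the same as in the patch.

    import numpy

    # toy probe positions on a 4x4 grid, rotated by 10 degrees
    n = 4
    theta = numpy.deg2rad(10.)
    gy, gx = numpy.meshgrid(numpy.arange(n), numpy.arange(n), indexing='ij')
    rot = numpy.array([[numpy.cos(theta), -numpy.sin(theta)],
                       [numpy.sin(theta),  numpy.cos(theta)]])
    pos = rot @ numpy.stack([gx.ravel().astype(float), gy.ravel().astype(float)])

    # same idea as get_scanrot: average the angle of the per-row displacements
    pos = pos.reshape(2, n, n)
    (ydiffs, xdiffs) = (pos[:, 1:, 0].T - pos[:, 0, 0]).T / numpy.arange(1, n)
    angles = -numpy.arctan2(ydiffs, xdiffs)
    print(numpy.degrees(numpy.mean(angles[len(angles)//2:])))  # ~10.0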
diff --git a/phaser/extras/main.py b/phaser/extras/main.py
new file mode 100644
index 0000000..29344af
--- /dev/null
+++ b/phaser/extras/main.py
@@ -0,0 +1,43 @@
+import typing as t
+
+import click
+
+
+class MainCommand(click.MultiCommand):
+    def __init__(self, commands: t.Union[t.Iterable[str], t.Dict[str, t.Union[str, t.Tuple[str, str]]]], **kwargs):
+        super().__init__(**kwargs)
+        self.commands: t.Dict[str, t.Union[str, t.Tuple[str, str]]]
+        if isinstance(commands, dict):
+            self.commands = commands
+        else:
+            self.commands = dict((v, v) for v in commands)
+
+    def list_commands(self, ctx: click.Context):
+        return list(self.commands.keys())
+
+    def get_command(self, ctx: click.Context, cmd_name: str) -> t.Optional[click.Command]:
+        name = cmd_name.lower()
+        val = (self.commands.get(name) or
+               self.commands.get(name.replace('-', '_')))
+        if val is None:
+            return None
+        if isinstance(val, tuple):
+            (module, func) = val
+        else:
+            module = val
+            func = val
+        mod = __import__(f"{__package__}.{module}", None, None, [func])
+        return getattr(mod, func)
+
+
+@click.command(cls=MainCommand, commands=dict((v, v) for v in
+    ('prepare', 'run', 'view_raw', 'view_prepared', 'view_output',
+     'process_metadata', 'extract_params', 'to_csv', 'calc_drift', 'calc_tilt')
+))
+def main():
+    """LeBeau group ptychography utilities."""
+    ...
+
+
+if __name__ == '__main__':
+    main()
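Note: `MainCommand` above imports each subcommand's module only when that
subcommand is actually invoked, which keeps startup fast. A minimal standalone
sketch of the same pattern; the package name `mytools` and command list are
hypothetical:

    import typing as t
    import click

    class LazyGroup(click.Group):
        """Resolve subcommands by importing 'mytools.<name>' on first use."""
        def __init__(self, lazy_commands: t.Iterable[str] = (), **kwargs):
            super().__init__(**kwargs)
            self.lazy_commands = list(lazy_commands)

        def list_commands(self, ctx: click.Context):
            return sorted(self.lazy_commands)

        def get_command(self, ctx: click.Context, name: str):
            if name not in self.lazy_commands:
                return None
            # import mytools.<name> and grab the click.Command of the same name
            mod = __import__(f"mytools.{name}", None, None, [name])
            return getattr(mod, name)

    @click.command(cls=LazyGroup, lazy_commands=['prepare', 'calc_drift'])
    def cli():
        """Example CLI with lazily-imported subcommands."""

    if __name__ == '__main__':
        cli()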
+ """ + + if isinstance(files, (str, Path)): + files = (files,) + + param_paths: t.List[Path] = [] + param_sets: t.List[ParamMetaSet] = [] + metadatas: t.List[Metadata] = [] + + # process supplied files, sorting them into parameters and metadata + it = chain.from_iterable(_glob(str(f), recursive=True) for f in files) if glob else files + + def _process_filelist(path: t.Union[str, Path]) -> t.Iterable[Path]: + path = Path(path) + if path.suffix.lower() != '.txt': + yield path + return + print(f"Loading filelist '{path}'...") + with open(path, 'r') as f: + lines = list(f) + + yield from (Path(line.strip()) for line in lines) + + # .txt files are treated as a list of files to load + it = chain.from_iterable(map(_process_filelist, it)) + + for path in it: + print(f"Loading '{path}'...") + meta = None + # json, might be metadata file + if path.suffix.lower() == '.json': + meta = _try_parse_metadata(path) + if isinstance(meta, Metadata): + metadatas.append(meta) + print(f"Loaded '{path}' as '{meta.file_type}'.") + continue + + with open(path, 'r') as f: + objs = list(yaml.load_all(f, Loader)) + print(f"Loaded '{path}' as reconstruction(s).") + try: + # ugly, fix later + sets = list(map(ParamMetaSet.parse_obj, objs)) + param_sets.extend(sets) + param_paths.extend((path.parent,) * len(sets)) + except ValidationError as e: + if len(objs) == 1 and meta is not None: + print(f"Could not parse '{path}' as metadata or reconstruction params.", file=sys.stderr) + print(f"As metadata:\n{meta}", file=sys.stderr) + print(f"\nAs reconstruction:\n{e}", file=sys.stderr) + else: + print(f"Could not parse file '{path}' as reconstruction params:", file=sys.stderr) + print(e, file=sys.stderr) + sys.exit(1) + + # keep track of what we've saved + save_record = SaveRecord() + saved_paths = [] + + if len(param_sets) == 0: + print("No parameter sets to process.") + else: + print(f"Processing {len(param_sets)} parameter set(s).") + if len(metadatas) > 1: + print(f"{len(metadatas)} dataset(s) per parameter set") + + for (param_set, path) in zip(param_sets, param_paths): + if len(metadatas) > 0: + param_set = param_set.with_metadata(metadatas) + + i = -1 + for (i, reconstruction) in enumerate(param_set.iter(save_record, path=path, sparse=sparse)): + path = reconstruction.name + '.json' + with open(path, 'w') as f: + print(f"Saving reconstruction #{i+1} to '{path}'...") + f.write(reconstruction.json(indent=4, exclude={'engines': {'__all__': {'slices', 'init_slices'}}})) + saved_paths.append(path) + + print(f"Saved {i+1} reconstruction(s)") + + if out_file is not None: + print(f"Writing list of reconstructions to '{out_file}'") + with open(out_file, 'w') as f: + for path in saved_paths: + print(path, file=f) + + +if __name__ == '__main__': + prepare() diff --git a/phaser/extras/process_metadata.py b/phaser/extras/process_metadata.py new file mode 100644 index 0000000..c28b24f --- /dev/null +++ b/phaser/extras/process_metadata.py @@ -0,0 +1,240 @@ + +import sys +from pathlib import Path +from abc import ABC, abstractmethod +from dataclasses import dataclass, field +import typing as t + +import click +from click.exceptions import Exit +from rich.prompt import Prompt +from rich.theme import Theme +from rich.console import Console + +from ptycho_lebeau.metadata import EmpadMetadata, UnscannedMetadata + + +T = t.TypeVar('T') + + +def default_console() -> Console: + return Console(theme=Theme({ + 'warning': 'bold yellow', + 'error': 'bold red', + 'info': 'blue', + 'prompt.invalid': 'bold red', + })) + + +def prompt_ask(name, 
diff --git a/phaser/extras/process_metadata.py b/phaser/extras/process_metadata.py
new file mode 100644
index 0000000..c28b24f
--- /dev/null
+++ b/phaser/extras/process_metadata.py
@@ -0,0 +1,240 @@
+
+import sys
+from pathlib import Path
+from abc import ABC, abstractmethod
+from dataclasses import dataclass, field
+import typing as t
+
+import click
+from click.exceptions import Exit
+from rich.prompt import Prompt
+from rich.theme import Theme
+from rich.console import Console
+
+from ptycho_lebeau.metadata import EmpadMetadata, UnscannedMetadata
+
+
+T = t.TypeVar('T')
+
+
+def default_console() -> Console:
+    return Console(theme=Theme({
+        'warning': 'bold yellow',
+        'error': 'bold red',
+        'info': 'blue',
+        'prompt.invalid': 'bold red',
+    }))
+
+
+def prompt_ask(name, unit=None, default=None, default_str: str = "", console=None, validate=float, err="Please enter a number") -> t.Tuple[str, t.Any]:
+    if not console:
+        console = default_console()
+    while True:
+        str_val = Prompt.ask(f"{name} \\[{unit}]", default=default_str, console=console)
+        if str_val == default_str:
+            return (default_str, default)
+        try:
+            val = validate(str_val)
+            return (str_val, val)
+        except Exception:
+            console.print(f"[prompt.invalid]{err}")
+
+
+@dataclass
+class Parameter(t.Generic[T], ABC):
+    name: str
+    default_val: t.Optional[T]
+
+    @abstractmethod
+    def ask(self, metadata: EmpadMetadata, console: t.Optional[Console] = None) -> T:
+        ...
+
+
+@dataclass
+class OptStringParam(Parameter[t.Optional[str]]):
+    def ask(self, metadata: EmpadMetadata, console: t.Optional[Console] = None) -> t.Optional[str]:
+        val = Prompt.ask(self.name, default=self.default_val, console=console)
+        if val is not None:
+            self.default_val = val
+        return val
+
+
+@dataclass
+class FloatParam(Parameter[float]):
+    unit: t.Optional[str]
+    default_str: str
+    conv_factor: t.Optional[float] = None
+
+    def ask(self, metadata: EmpadMetadata, console: t.Optional[Console] = None) -> float:
+        (val_str, val) = prompt_ask(self.name, self.unit, self.default_val, self.default_str, console, validate=float)
+        self.default_str = val_str
+        self.default_val = val
+        if self.conv_factor:
+            return val * self.conv_factor
+        return val
+
+
+def default_params() -> t.Sequence[t.Tuple[str, Parameter]]:
+    return [
+        ('conv_angle', FloatParam("Convergence angle", 18.9, "mrad", "18.9")),
+        ('defocus', FloatParam("Defocus (CW is +)", 0.0, "nm", "0.0", 1e-9)),  # nm to m
+        ('beam_current', FloatParam("Approx. beam current", 30.0, "pA", "30.0", 1e-12)),  # pA to A
+        ('diff_step', CameraParam(18.8 / 24., 0.4575)),
+        ('adu', AduParam(578, 200000.)),  # calibrated 2022-12-14
+        ('author', OptStringParam("Author", None)),
+    ]
+
+
+@dataclass(init=False)
+class CameraParam(Parameter[float]):
+    """Maps f"{camera_length:0.3f}" to diff_step values"""
+    used: t.Dict[str, float] = field(default_factory=dict)
+    """Camera length associated with default_val (m)"""
+    default_camera_length: float = 0.4575
+
+    def __init__(self, diff_step: float, camera_length: float):
+        self.name = 'Diffraction pixel spacing'
+        self.default_val = diff_step
+        self.used = {}
+        self.default_camera_length = camera_length
+
+    def ask(self, metadata: EmpadMetadata, console: t.Optional[Console] = None) -> float:
+        if not console:
+            console = default_console()
+        camera_length = metadata.camera_length
+        camera_length_str = f"{camera_length:0.3f}"
+        default = self.used.get(camera_length_str, None)
+        if default is None and self.default_val is not None:
+            # diff_step scales inversely with camera length
+            default = self.default_val * self.default_camera_length / camera_length
+        default_str = f"{default:.03f}" if default is not None else ""
+
+        console.print(f"Camera length: {metadata.camera_length * 1e3 :.0f} mm")
+        (val_str, val) = prompt_ask(self.name, "mrad/px", default, default_str, console)
+        self.used[camera_length_str] = val
+        return val
+
+
+@dataclass(init=False)
+class AduParam(Parameter[float]):
+    """Maps f"{voltage:.0f}" to ADU values"""
+    used: t.Dict[str, float]
+    """Voltage associated with default_val (V)"""
+    default_voltage: float
+
+    def __init__(self, default: float, default_voltage: float):
+        self.name = 'Single-electron intensity'
+        self.default_val = default
+        self.used = {}
+        self.default_voltage = default_voltage
+
+    def ask(self, metadata: EmpadMetadata, console: t.Optional[Console] = None) -> float:
+        voltage = metadata.voltage
+        voltage_str = f"{voltage:.0f}"
+        default = self.used.get(voltage_str, None)
+        if default is None and self.default_val is not None:
+            # ADU should scale linearly with voltage
+            default = self.default_val * voltage / self.default_voltage
+        default_str = f"{default:.0f}" if default is not None else ""
+
+        (val_str, val) = prompt_ask(self.name, "ADU", default, default_str, console)
+        self.used[voltage_str] = val
+        return val
+
+
+def process_dir(path, params: t.Optional[t.Sequence[t.Tuple[str, Parameter]]] = None,
+                console: t.Optional[Console] = None, prompt: bool = True):
+    path = Path(path)
+
+    if console is None:
+        console = default_console()
+    if params is None:
+        params = default_params()
+
+    for raw_file in path.rglob('**/*.xml'):
+        raw_dir = raw_file.parent
+        try:
+            [metadata_path, *extra] = list(filter(lambda f: not f.name.startswith('.'), raw_dir.glob('*.xml')))
+        except ValueError:
+            continue
+        if len(extra):
+            console.print(f"[warning]Skipping dir with multiple metadata files: '{raw_dir}'.")
+            continue
+
+        output_path = metadata_path.with_suffix('.json')
+        if output_path.exists():
+            # skip already created
+            continue
+
+        try:
+            metadata = EmpadMetadata.from_xml(metadata_path)
+            console.print(f"Loaded '{metadata_path}'.")
+        except UnscannedMetadata:
+            console.print(f"[info]Skipping metadata '{metadata_path}' (not a raster scan dataset).")
+            continue
+        except Exception:
+            console.print(f"[error]Couldn't parse XML metadata '{metadata_path}'. Skipping.")
+            console.print_exception()
+            continue
+
+        if prompt:
+            for (k, param) in params:
+                val = param.ask(metadata, console)
+                setattr(metadata, k, val)
+
+            metadata.notes = Prompt.ask("Notes", default=None, console=console)
+
+        with open(output_path, 'w') as f:
+            f.write(metadata.json(indent=4, exclude={'path'}))
+
+        console.print(f"Wrote metadata to '{output_path}'.")
+
+    console.print("Processed all files.")
+
+
+@click.command()
+@click.option('--watch/--no-watch', default=False, help="Watch folder for changes (Linux only)")
+@click.option('--prompt/--no-prompt', default=True, help="Prompt for additional metadata.")
+@click.argument('folder', type=click.Path(exists=True, file_okay=False), required=False)
+def process_metadata(folder: t.Union[str, Path, None], watch: bool = False, prompt: bool = True):
+    """
+    Process metadata for all the raw datasets contained in FOLDER.
+    """
+    console = default_console()
+    params = default_params()
+
+    if folder is None:
+        folder = Path('.')
+    path = Path(folder).resolve()
+
+    console.print(f"Processing dir {path}")
+    process_dir(path, params, console, prompt)
+
+    if not watch:
+        return
+
+    if 'linux' not in sys.platform:
+        console.print("[error]--watch is supported on Linux only.")
+        raise Exit(1)
+    try:
+        from inotify.adapters import InotifyTree
+        from inotify.constants import IN_CLOSE
+    except ImportError:
+        console.print("[error]Couldn't import inotify.\nInstall it for --watch support.")
+        raise Exit(1)
+
+    notifier = InotifyTree(str(path), mask=IN_CLOSE)
+    console.print("Watching for experiments...")
+
+    gen = notifier.event_gen(yield_nones=False)
+    for (_, event_types, _path, filename) in t.cast(t.Iterator[t.Tuple[t.Any, t.List[str], t.Any, str]], gen):
+        if not (any(t in event_types for t in ('IN_CLOSE_WRITE', 'IN_CLOSE_NOWRITE'))
+                and Path(filename).match('scan*.raw')):
+            continue
+
+        print("Found changes, reprocessing.")
+        process_dir(path, params, console)
+
+
+if __name__ == '__main__':
+    process_metadata()
\ No newline at end of file
From 35b49ff21cd8d0531fcd6d6abd5fe925b895d141 Mon Sep 17 00:00:00 2001
From: subangstrom
Date: Fri, 15 Aug 2025 16:11:33 -0400
Subject: [PATCH 2/8] cli changes

---
 phaser/{extras => cli_tools}/calc_drift.py       | 0
 phaser/{extras => cli_tools}/calc_tilt.py        | 0
 phaser/{extras => cli_tools}/extract_params.py   | 0
 phaser/{extras => cli_tools}/main.py             | 0
 phaser/{extras => cli_tools}/prepare.py          | 0
 phaser/{extras => cli_tools}/process_metadata.py | 2 +-
 phaser/main.py                                   | 4 ++++
 7 files changed, 5 insertions(+), 1 deletion(-)
 rename phaser/{extras => cli_tools}/calc_drift.py (100%)
 rename phaser/{extras => cli_tools}/calc_tilt.py (100%)
 rename phaser/{extras => cli_tools}/extract_params.py (100%)
 rename phaser/{extras => cli_tools}/main.py (100%)
 rename phaser/{extras => cli_tools}/prepare.py (100%)
 rename phaser/{extras => cli_tools}/process_metadata.py (99%)

diff --git a/phaser/extras/calc_drift.py b/phaser/cli_tools/calc_drift.py
similarity index 100%
rename from phaser/extras/calc_drift.py
rename to phaser/cli_tools/calc_drift.py
diff --git a/phaser/extras/calc_tilt.py b/phaser/cli_tools/calc_tilt.py
similarity index 100%
rename from phaser/extras/calc_tilt.py
rename to phaser/cli_tools/calc_tilt.py
diff --git a/phaser/extras/extract_params.py b/phaser/cli_tools/extract_params.py
similarity index 100%
rename from phaser/extras/extract_params.py
rename to phaser/cli_tools/extract_params.py
diff --git a/phaser/extras/main.py b/phaser/cli_tools/main.py
similarity index 100%
rename from phaser/extras/main.py
rename to phaser/cli_tools/main.py
diff --git a/phaser/extras/prepare.py b/phaser/cli_tools/prepare.py
similarity index 100%
rename from phaser/extras/prepare.py
rename to phaser/cli_tools/prepare.py
diff --git a/phaser/extras/process_metadata.py b/phaser/cli_tools/process_metadata.py
similarity index 99%
rename from phaser/extras/process_metadata.py
rename to phaser/cli_tools/process_metadata.py
index c28b24f..27b52be 100644
--- a/phaser/extras/process_metadata.py
+++ b/phaser/cli_tools/process_metadata.py
@@ -11,7 +11,7 @@ from rich.theme import Theme
 from rich.console import Console
 
-from ptycho_lebeau.metadata import EmpadMetadata, UnscannedMetadata
+from hooks.io.empad import EmpadMetadata
 
 T = t.TypeVar('T')
diff --git a/phaser/main.py b/phaser/main.py
index 03125f4..0250fd9 100644
--- a/phaser/main.py
+++ b/phaser/main.py
@@ -87,6 +87,10 @@ def worker(url: str, quiet: bool = False):
     run_worker(url, quiet=quiet)
 
 
+@cli.command('process-metadata', help="Process metadata for all the raw datasets contained in FOLDER.")
+@click.argument('args', nargs=-1, type=click.UNPROCESSED)
+
+
 if __name__ == '__main__':
     cli()
\ No newline at end of file

From 844f3926f151682cfd1df6745689cf635725c93f Mon Sep 17 00:00:00 2001
From: "jmlebeau@gmail.com"
Date: Thu, 21 Aug 2025 22:16:30 -0400
Subject: [PATCH 3/8] setting up the command structure and process metadata
 now working

---
 phaser/cli_tools/__init__.py         |  10 +
 phaser/cli_tools/metadata.py         | 490 +++++++++++++++++++++++++++
 phaser/cli_tools/process_metadata.py |   2 +-
 phaser/main.py                       |  49 ++-
 4 files changed, 547 insertions(+), 4 deletions(-)
 create mode 100644 phaser/cli_tools/__init__.py
 create mode 100644 phaser/cli_tools/metadata.py

diff --git a/phaser/cli_tools/__init__.py b/phaser/cli_tools/__init__.py
new file mode 100644
index 0000000..ba9d1bf
--- /dev/null
+++ b/phaser/cli_tools/__init__.py
@@ -0,0 +1,10 @@
+import click
+from .process_metadata import process_metadata
+
+@click.group()
+def tools():
+    """Toolbox of utilities."""
+    pass
+
+# register subcommands here
+tools.add_command(process_metadata)
\ No newline at end of file
diff --git a/phaser/cli_tools/metadata.py b/phaser/cli_tools/metadata.py
new file mode 100644
index 0000000..e67e40b
--- /dev/null
+++ b/phaser/cli_tools/metadata.py
@@ -0,0 +1,490 @@
+from __future__ import annotations
+
+import abc
+from datetime import datetime, timedelta
+from pathlib import Path
+import csv
+import typing as t
+
+from pydantic import BaseModel, Field, validator, root_validator
+import numpy
+from lxml import etree
+
+
+def parse_version(s: str) -> t.Tuple[int, ...]:
+    def to_int(seg: str) -> int:
+        if not seg.isdigit():
+            raise ValueError()
+        return int(seg)
+
+    try:
+        return tuple(map(to_int, s.split('.')))
+    except ValueError:
+        raise ValueError(f"Invalid version string '{s}'") from None
+
+
+def _convert_to_si(cls, values: t.Dict[str, t.Any]) -> t.Dict[str, t.Any]:
+    """
+    Convert units from old metadata format to new metadata format (all SI units).
+    """
+
+    if values.get('defocus') is not None:
+        values['defocus'] *= 1e-9  # nm to m
+
+    values['scan_step'] = tuple(v * 1e-10 for v in values['scan_step'])  # A to m
+
+    if values.get('beam_current') is not None:
+        values['beam_current'] *= 1e-12  # pA to A
+
+    if values.get('exposure_time') is not None:
+        values['exposure_time'] *= 1e-3  # ms to s
+    if values.get('post_exposure_time') is not None:
+        values['post_exposure_time'] *= 1e-3  # ms to s
+
+    # update version
+    values['version'] = "2.0"
+
+    return values
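+# For example (hypothetical values), a version-1 record
+#     {'version': '1.0', 'defocus': 50.0, 'scan_step': (0.2, 0.2)}   # nm, A/px
+# becomes, after _convert_to_si:
+#     {'version': '2.0', 'defocus': 5e-08, 'scan_step': (2e-11, 2e-11)}   # m, m/px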
+
+
+if not t.TYPE_CHECKING:
+    class Metadata(abc.ABC):
+        ...
+else:
+    class Metadata(abc.ABC):
+        @property
+        @abc.abstractmethod
+        def file_type(self) -> str:
+            ...
+
+        @property
+        @abc.abstractmethod
+        def name(self) -> str:
+            ...
+
+        @property
+        @abc.abstractmethod
+        def path(self) -> t.Optional[Path]:
+            ...
+
+        @property
+        @abc.abstractmethod
+        def voltage(self) -> float:
+            ...
+
+        @property
+        @abc.abstractmethod
+        def conv_angle(self) -> float:
+            ...
+
+        @property
+        @abc.abstractmethod
+        def defocus(self) -> float:
+            ...
+
+        @property
+        @abc.abstractmethod
+        def diff_step(self) -> float:
+            ...
+
+        @property
+        @abc.abstractmethod
+        def scan_rotation(self) -> float:
+            ...
+
+        @property
+        @abc.abstractmethod
+        def scan_shape(self) -> t.Tuple[int, int]:
+            ...
+
+        @property
+        @abc.abstractmethod
+        def scan_step(self) -> t.Tuple[float, float]:
+            ...
+
+        @property
+        @abc.abstractmethod
+        def scan_correction(self) -> t.Optional[t.Tuple[t.Tuple[float, float], t.Tuple[float, float]]]:
+            ...
+
+        @property
+        @abc.abstractmethod
+        def scan_positions(self) -> t.Optional[t.List[t.Tuple[float, float]]]:
+            ...
+
+        @property
+        @abc.abstractmethod
+        def raw_filename(self) -> t.Optional[str]:
+            ...
+
+        @property
+        @abc.abstractmethod
+        def crop(self) -> t.Optional[t.Tuple[int, int, int, int]]:
+            ...
+
+        @abc.abstractmethod
+        def is_simulated(self) -> bool:
+            ...
+
+
+class PyMultislicerMetadata(BaseModel, Metadata):
+    file_type: t.Literal['pyMultislicer_metadata'] = 'pyMultislicer_metadata'
+
+    name: str
+    """Experiment name."""
+    version: str = "2.0"
+    """Metadata version"""
+
+    @root_validator(pre=False)
+    def _validate_version(cls, values):
+        version = parse_version(values['version'])
+        if version > (2, 0):
+            raise ValueError(f"Unsupported metadata version '{values['version']}'")
+        if version < (2, 0):
+            return _convert_to_si(cls, values)
+        return values
+
+    path: t.Optional[Path] = Field(default=None, exclude=True)
+    """Current path to experimental folder."""
+
+    @classmethod
+    def parse_obj(cls, obj: t.Any, path: t.Union[str, Path, None] = None) -> 'PyMultislicerMetadata':
+        # ugly hack
+        meta: PyMultislicerMetadata = BaseModel.parse_obj.__func__(cls, obj)  # type: ignore
+        meta.path = Path(path).parent if path is not None else None
+        return meta
+
+    @classmethod
+    def parse_file(cls, path: t.Union[str, Path], *, content_type: t.Optional[str] = None,
+                   encoding: str = 'utf8', proto: t.Optional[str] = None, allow_pickle: bool = False) -> PyMultislicerMetadata:
+        # ugly hack
+        meta: PyMultislicerMetadata = BaseModel.parse_file.__func__(cls, path, content_type=content_type, encoding=encoding,  # type: ignore
+                                                                    proto=proto, allow_pickle=allow_pickle)
+        meta.path = Path(path).parent
+        return meta
+
+    raw_filename: t.Optional[str]
+    """Raw 4DSTEM data filename."""
+
+    voltage: float
+    """Accelerating voltage (V)."""
+
+    conv_angle: t.Optional[float] = None
+    """Convergence angle (mrad)."""
+    defocus: t.Optional[float] = None
+    """Defocus (m). Positive is overfocus."""
+    diff_step: t.Optional[float] = None
+    """Diffraction pixel size (mrad/px)."""
+
+    scan_rotation: float
+    """Scan rotation (degrees)."""
+    scan_shape: t.Tuple[int, int]
+    """Scan shape (x, y)."""
+    scan_fov: t.Tuple[float, float]
+    """Scan field of view (m)."""
+    scan_step: t.Tuple[float, float]
+    """Scan step (m/px)."""
+
+    scan_correction: t.Optional[t.Tuple[t.Tuple[float, float], t.Tuple[float, float]]] = None
+    """Scan correction matrix, [x', y'] = scan_correction @ [x, y]"""
+
+    scan_positions: t.Optional[t.List[t.Tuple[float, float]]] = None
+    """
+    Scan position override (m).
+    Should be specified as a 1d list of (x, y) positions, in scan order. `scan_correction` is applied to these positions (if present).
+    """
+
+    def is_simulated(self) -> t.Literal[True]:
+        return True
+
+    @property
+    def crop(self) -> None:
+        """Region scan is valid within, (min_x, max_x, min_y, max_y). Matlab-style slicing (1-indexed, inclusive)."""
+        return None
+
+
+class UnscannedMetadata(BaseException):
+    ...
+
+
+class EmpadMetadata(Metadata, BaseModel):
+    file_type: t.Literal['empad_metadata'] = 'empad_metadata'
+
+    name: str
+    """Experiment name."""
+    version: str = "2.0"
+    """Metadata version"""
+
+    @root_validator(pre=False)
+    def _validate_version(cls, values):
+        version = parse_version(values['version'])
+        if version > (2, 0):
+            raise ValueError(f"Unsupported metadata version '{values['version']}'")
+        if version < (2, 0):
+            return _convert_to_si(cls, values)
+        return values
+
+    raw_filename: t.Optional[str]
+    """Raw 4DSTEM data filename."""
+    orig_path: t.Optional[Path] = None
+    """Original path to experimental folder."""
+    path: t.Optional[Path] = Field(default=None, exclude=True)
+    """Current path to experimental folder."""
+    author: t.Optional[str] = None
+    """Dataset author"""
+    time: str
+    """Acquisition time, formatted according to ISO 8601."""
+    time_unix: float
+    """Acquisition time, in Unix time (seconds since epoch)."""
+
+    def get_datetime(self) -> datetime:
+        return datetime.fromtimestamp(self.time_unix)
+
+    bg_unix: t.Optional[float] = None
+    """Background acquisition time, in Unix time (seconds since epoch)."""
+
+    has_bg: bool = False
+    """True if the background acquisition is <12 hours out of date."""
+
+    voltage: float
+    """Accelerating voltage (V)."""
+    conv_angle: t.Optional[float] = None
+    """Convergence angle (mrad)."""
+    defocus: t.Optional[float] = None
+    """Defocus (m). Positive is overfocus."""
+
+    camera_length: float
+    """Camera length (m)."""
+    diff_step: t.Optional[float] = None
+    """Est. diffraction pixel size (mrad/px)."""
+
+    scan_rotation: float
+    """Scan rotation (degrees)."""
+    scan_shape: t.Tuple[int, int]
+    """Scan shape (x, y)."""
+    scan_fov: t.Tuple[float, float]
+    """Scan field of view (m)."""
+    scan_step: t.Tuple[float, float]
+    """Scan step (m/px)."""
+
+    exposure_time: float
+    """Pixel exposure time (s)."""
+    post_exposure_time: float
+    """Pixel post-exposure time (s)."""
+    beam_current: t.Optional[float] = None
+    """Approx. beam current (A)."""
+    adu: t.Optional[float] = None
+    """Single-electron intensity (data units)."""
+
+    scan_correction: t.Optional[t.Tuple[t.Tuple[float, float], t.Tuple[float, float]]] = None
+    """Scan correction matrix, [x', y'] = scan_correction @ [x, y]"""
+
+    scan_positions: t.Optional[t.List[t.Tuple[float, float]]] = None
+    """
+    Scan position override (m).
+    Should be specified as a 1d list of (x, y) positions, in scan order. `scan_correction` is applied to these positions (if present).
+    """
+
+    notes: t.Optional[str] = None
+
+    crop: t.Optional[t.Tuple[int, int, int, int]] = None
+    """Region scan is valid within, (min_x, max_x, min_y, max_y). Matlab-style slicing (1-indexed, inclusive)."""
+
+    def is_simulated(self) -> t.Literal[False]:
+        return False
+
+    @classmethod
+    def parse_obj(cls, obj: t.Any, path: t.Union[str, Path, None] = None) -> 'EmpadMetadata':
+        # ugly hack
+        meta: EmpadMetadata = BaseModel.parse_obj.__func__(cls, obj)  # type: ignore
+        meta.path = Path(path).parent if path is not None else None
+        return meta
+
+    @classmethod
+    def parse_file(cls, path: t.Union[str, Path], *, content_type: t.Optional[str] = None,
+                   encoding: str = 'utf8', proto: t.Optional[str] = None, allow_pickle: bool = False) -> EmpadMetadata:
+        # ugly hack
+        meta: EmpadMetadata = BaseModel.parse_file.__func__(cls, path, content_type=content_type, encoding=encoding,  # type: ignore
+                                                            proto=proto, allow_pickle=allow_pickle)
+        meta.path = Path(path).parent
+        return meta
+
+    class Config:
+        allow_population_by_field_name = True
+        extra = 'forbid'
+
+        json_encoders = {
+            # encode empty paths as empty string (for passing through to matlab)
+            Path: lambda p: "" if p == Path("") else str(p)
+        }
+
+    @staticmethod
+    def from_xml(xml_path: t.Union[str, Path]) -> EmpadMetadata:
+        orig_path = Path(xml_path).parent
+        xml = t.cast(etree._ElementTree, etree.parse(str(xml_path)))  # type: ignore
+
+        def get(root, tag: str) -> etree._Element:
+            elem = root.find(tag)  # type: ignore
+            if elem is None:
+                raise ValueError(f"Couldn't find tag '{tag}' in XML metadata.")
+            return elem
+
+        def try_get(root, tag: str) -> t.Optional[etree._Element]:
+            return root.find(tag)  # type: ignore
+
+        root: etree._Element = xml.getroot()
+        timestamp = get(root, 'timestamp')
+        raw_filename = get(root, 'raw_file').attrib['filename']
+        time_unix = float(timestamp.attrib['timestamp'])
+
+        scan_params = get(root, 'scan_parameters')
+        if try_get(scan_params, 'series_count') is not None:
+            # not a scanned dataset
+            raise UnscannedMetadata()
+
+        background = try_get(root, 'background_image')
+        exposure_time = float(get(root, 'exposure_time').text) * 1e-3  # ms to s
+        post_exposure_time = float(get(root, 'post_exposure_time').text) * 1e-3  # ms to s
+
+        bg_unix = None if background is None else background.attrib.get('timestamp', None)
+        bg_unix = None if bg_unix is None else float(bg_unix)
+        # True if background is <12 hours out of date
+        has_bg = False if bg_unix is None else abs(time_unix - bg_unix) < 60*60*12
+
+        iom = get(root, 'iom_measurements')
+        scan_rotation = float(get(iom, 'scan_rotation').text) * 180. / numpy.pi
+        camera_length = float(get(iom, 'nominal_camera_length').text)  # m
+        voltage = float(get(iom, 'high_voltage').text)  # V
+
+        scan_params = get(root, "scan_parameters[@mode='acquire']")
+        scan_shape = (get(scan_params, 'scan_resolution_x'), get(scan_params, 'scan_resolution_y'))
+        scan_shape = tuple(map(lambda elem: int(elem.text), scan_shape))
+        scan_size = float(get(scan_params, 'scan_size').text)
+
+        if abs(scan_size - 1.) > 1e-5:
+            print("Warning: scan_size != 1 has not been tested.")
+
+        # all in m
+        fov = get(iom, 'full_scan_field_of_view')
+        scale_factor = float(get(fov, 'scale_factor').text)
+        fov = tuple(map(lambda elem: float(elem.text) * scan_size / scale_factor, (get(fov, 'x'), get(fov, 'y'))))
+        scan_step = tuple(map(lambda v: v / max(scan_shape), fov))
+
+        return EmpadMetadata(
+            name=root.attrib['name'],
+            orig_path=orig_path,
+            raw_filename=raw_filename,
+            time=timestamp.attrib['isoformat'],
+            time_unix=time_unix,
+            bg_unix=bg_unix,
+            has_bg=has_bg,
+            voltage=voltage,
+            camera_length=camera_length,
+            scan_rotation=scan_rotation,
+            scan_shape=t.cast(t.Tuple[int, int], scan_shape),
+            scan_fov=t.cast(t.Tuple[float, float], fov),
+            scan_step=t.cast(t.Tuple[float, float], scan_step),
+            exposure_time=exposure_time,
+            post_exposure_time=post_exposure_time,
+        )
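+# Rough arithmetic behind from_xml's scan_step (hypothetical numbers): a
+# 2e-8 m (20 nm) field of view at scale_factor 1.0 and scan_size 1.0,
+# rastered at 256x256 px, gives scan_step = 2e-8 / 256 ~ 7.8e-11 m/px
+# (~0.78 A/px).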
(mrad)', 'conv_angle'), + ('Camera length (mm)', lambda meta: meta.camera_length * 1e3), + ('Diff. pixel size (mrad/px)', 'diff_step'), + ('Path', lambda meta: str(meta.path.absolute()) if meta.path is not None else ""), + ('Notes', 'notes'), +] + +def to_csv(path: t.Union[str, Path, t.TextIO], metadata: t.Iterable[EmpadMetadata]): + if isinstance(path, (str, Path)): + close = True + f = open(path, 'w', encoding='utf-8') + else: + close = False + f = path + + try: + writer = csv.writer(f, dialect='excel', delimiter=',', quoting=csv.QUOTE_NONNUMERIC) + writer.writerow(name for (name, fmt) in _CSV_FMTS) + writer.writerows(( + getattr(meta, fmt, "") if isinstance(fmt, str) else fmt(meta) + for (name, fmt) in _CSV_FMTS + ) for meta in metadata + ) + finally: + if close: + f.close() + + +# out object key, column to search for, parse +_CSV_PARSE_COLS: t.Dict[str, t.Union[str, t.Tuple[str, t.Callable[[str], t.Any]]]] = { + 'name': 'name', +} + + +def from_csv(f: t.TextIO) -> t.Sequence[Metadata]: + raise NotImplementedError() diff --git a/phaser/cli_tools/process_metadata.py b/phaser/cli_tools/process_metadata.py index 27b52be..be51a10 100644 --- a/phaser/cli_tools/process_metadata.py +++ b/phaser/cli_tools/process_metadata.py @@ -11,7 +11,7 @@ from rich.theme import Theme from rich.console import Console -from hooks.io.empad import EmpadMetadata +from .metadata import EmpadMetadata, UnscannedMetadata T = t.TypeVar('T') diff --git a/phaser/main.py b/phaser/main.py index 0250fd9..6aed60f 100644 --- a/phaser/main.py +++ b/phaser/main.py @@ -1,7 +1,7 @@ from pathlib import Path import sys import typing as t - +from .cli_tools import tools import click @click.group() @@ -87,8 +87,51 @@ def worker(url: str, quiet: bool = False): run_worker(url, quiet=quiet) -@cli.command('process-metadata', help="Process metadata for all the raw datasets contained in FOLDER.") -@click.argument('args', nargs=-1, type=click.UNPROCESSED) +# @cli.command(cls=MainCommand, commands=dict((v, v) for v in +# ('prepare', 'run', 'view_raw', 'view_prepared', 'view_output', +# 'process_metadata', 'extract_params', 'to_csv', 'calc_drift', 'calc_tilt') +# )) + +cli.add_command(tools) + +# # --- Action group: tools ----------------------------------------------------- +# @cli.group(invoke_without_command=False) # require a subcommand +# @click.pass_context +# def tools(ctx: click.Context): +# """Toolbox of utility subcommands.""" + # With invoke_without_command=False, Click will show usage if no subcommand. 
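+#
+# Example invocation (hypothetical entry-point and command names) once the
+# group is registered via `cli.add_command(tools)` above:
+#
+#     $ phaser tools process-metadata ./data
+#     $ phaser tools calc-drift ./data/meta.json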
+ +# @tools.add_command(process_metadata) + +# @tools.command('prepare') + +# @tools.command('view_raw') + +# # @tools.command('view_prepared') + +# # @tools.command('view_output') + +# @tools.command('process_metadata') + +# @tools.command('extract_params') +# @tools.command('to_csv') +# @tools.command('calc_drift') +# @tools.command('calc_tilt') + + +# @tools.command("align-data") +# @click.argument("input", type=click.Path(exists=True, path_type=Path)) +# @click.argument("output", type=click.Path(path_type=Path)) +# @click.option( +# "--method", +# type=click.Choice(["fft", "xcorr"], case_sensitive=False), +# default="fft", +# show_default=True, +# help="Alignment method.", +# ) +# def align_data_cmd(input: Path, output: Path, method: str): +# """Align image/data files and write the result.""" +# align_data(input, output, method) From 845927f9cd42a319761ad84dd2b32649e70e9a0c Mon Sep 17 00:00:00 2001 From: "jmlebeau@gmail.com" Date: Thu, 21 Aug 2025 22:44:33 -0400 Subject: [PATCH 4/8] calc drift and tilt now working --- phaser/cli_tools/__init__.py | 9 +- phaser/cli_tools/calc_drift.py | 2 +- phaser/cli_tools/calc_tilt.py | 2 +- phaser/cli_tools/models.py | 291 +++++++ phaser/cli_tools/params.py | 1401 ++++++++++++++++++++++++++++++++ phaser/cli_tools/prepare.py | 6 +- phaser/cli_tools/util.py | 37 + phaser/main.py | 45 +- 8 files changed, 1743 insertions(+), 50 deletions(-) create mode 100644 phaser/cli_tools/models.py create mode 100644 phaser/cli_tools/params.py create mode 100644 phaser/cli_tools/util.py diff --git a/phaser/cli_tools/__init__.py b/phaser/cli_tools/__init__.py index ba9d1bf..5320d96 100644 --- a/phaser/cli_tools/__init__.py +++ b/phaser/cli_tools/__init__.py @@ -1,5 +1,9 @@ import click from .process_metadata import process_metadata +from .prepare import prepare + +from .calc_tilt import calc_tilt +from .calc_drift import calc_drift @click.group() def tools(): @@ -7,4 +11,7 @@ def tools(): pass # register subcommands here -tools.add_command(process_metadata) \ No newline at end of file +tools.add_command(process_metadata) +tools.add_command(prepare) +tools.add_command(calc_tilt) +tools.add_command(calc_drift) diff --git a/phaser/cli_tools/calc_drift.py b/phaser/cli_tools/calc_drift.py index 69e1e9b..ed5c350 100644 --- a/phaser/cli_tools/calc_drift.py +++ b/phaser/cli_tools/calc_drift.py @@ -18,7 +18,7 @@ from rich.prompt import Prompt, FloatPrompt, Confirm #from rich.pretty import Pretty -from .raw import load_4d +from phaser.io.empad import load_4d from .metadata import AnyMetadata diff --git a/phaser/cli_tools/calc_tilt.py b/phaser/cli_tools/calc_tilt.py index 5739c81..0ab851a 100644 --- a/phaser/cli_tools/calc_tilt.py +++ b/phaser/cli_tools/calc_tilt.py @@ -11,7 +11,7 @@ from rich.console import Console from rich.prompt import Prompt, FloatPrompt -from .raw import load_4d +from phaser.io.empad import load_4d from .metadata import AnyMetadata diff --git a/phaser/cli_tools/models.py b/phaser/cli_tools/models.py new file mode 100644 index 0000000..74a8fa6 --- /dev/null +++ b/phaser/cli_tools/models.py @@ -0,0 +1,291 @@ +from __future__ import annotations +import math +import typing as t +from pathlib import Path + +from pydantic import BaseModel, PrivateAttr, Field, validator, root_validator +from pydantic.main import ModelMetaclass +from pydantic.types import NonNegativeInt, ConstrainedList +from pydantic.typing import get_sub_types, is_none_type +from pydantic.generics import GenericModel +from pydantic.fields import ModelField + +Num = t.TypeVar('Num', 
bound=t.Union[int, float]) +T = t.TypeVar('T') +U = t.TypeVar('U') + +def _fix_field_allow_none(field: ModelField): + if field.allow_none: + return + + sub_fields = [field] if field.sub_fields is None else field.sub_fields + for sub_field in sub_fields: + if sub_field.allow_none: + field.allow_none = True + break + + for subtype in get_sub_types(sub_field.type_): + if is_none_type(subtype) or (isinstance(subtype, type) and + issubclass(subtype, WrapperModel) and + subtype.__fields__['__root__'].allow_none): + field.allow_none = sub_field.allow_none = True + break + else: + continue + break + + +class ModelMeta(ModelMetaclass): + def __new__(cls, name, bases, namespace, **kwargs): + ty: BaseModel = super(ModelMeta, cls).__new__(cls, name, bases, namespace, **kwargs) + # workaround for pydantic assuming inner types don't accept None + for field in ty.__fields__.values(): + _fix_field_allow_none(field) + return ty + + +class ModelConfig(BaseModel, metaclass=ModelMeta): + class Config: + allow_population_by_field_name = True + extra = 'forbid' + allow_mutation = False + frozen = True + + json_encoders = { + # encode empty paths as empty string (for passing through to matlab) + Path: lambda p: "" if p == Path("") else str(p) + } + + +class WrapperModel(GenericModel): + __root__: t.Any + _type_params: t.Tuple[type, ...] = () + + @validator('__root__', pre=True) + def validator(cls, value): + if isinstance(value, WrapperModel): # type: ignore + return value.__root__ + return value + + def __class_getitem__(cls, params: t.Union[type, t.Tuple[type, ...]]) -> type: + new_cls = t.cast(t.Type[WrapperModel], super(WrapperModel, cls).__class_getitem__(params)) + new_cls._type_params = (*cls._type_params, *params) if isinstance(params, tuple) else (*cls._type_params, params) + + if not new_cls.__concrete__: + return new_cls + + # workaround for pydantic assuming inner types don't accept None + _fix_field_allow_none(new_cls.__fields__['__root__']) + + return new_cls + + def __init__(self, val: t.Any = None, **kwargs: t.Any) -> None: + if val is not None: + kwargs.update(__root__=val) + return super().__init__(**kwargs) + + def __repr_str__(self, join_str: str) -> str: + return repr(self.__root__) + + def dict(self, **kwargs) -> t.Dict[str, t.Any]: + d = super().dict(**kwargs) + if isinstance(d['__root__'], dict): + return d['__root__'] + return d + + +if t.TYPE_CHECKING: + InitNonNegativeInt: t.TypeAlias = NonNegativeInt + InitNum: t.TypeAlias = Num +else: + InitNonNegativeInt: t.TypeAlias = t.Optional[NonNegativeInt] + InitNum: t.TypeAlias = t.Optional[Num] + + +class Range(ModelConfig, GenericModel, t.Generic[Num]): + start: Num + end: Num + + # n and step are optional but are set by validation, so we need some TYPE_CHECKING hackery + n: InitNonNegativeInt = Field(default=None) + step: InitNum = Field(default=None) + + @root_validator(pre=False, skip_on_failure=True) + def _validate(cls, values): + if values.get('step') is None: + if values.get('n') is not None and values['n'] > 1: + if not isinstance(values['start'], float) and (values['end'] - values['start']) % (values['n'] - 1): + raise ValueError("Range must be evenly divisible by 'n'") + values['step'] = (values['end'] - values['start']) / (values['n'] - 1) + else: + values['step'] = 1. 
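+            # for integer ranges, the division above yields an integral float; cast back to int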
+            if isinstance(values['start'], int):
+                values['step'] = int(values['step'])
+        else:
+            if values.get('n') is not None:
+                raise ValueError("Either 'n' or 'step' may be specified, but not both.")
+
+        if values.get('n') is None:
+            if values['start'] > values['end']:
+                values['n'] = 0
+            else:
+                values['n'] = 1 + math.ceil((values['end'] - values['start']) / values['step'] - 1e-6)
+
+        return values
+
+    def __len__(self) -> int:
+        return self.n
+
+    def __iter__(self) -> t.Iterator[Num]:
+        if self.n == 0:
+            return
+        val: Num = self.start
+        for _ in range(self.n-1):
+            yield val
+            val = t.cast(Num, val + self.step)
+        yield self.end
+
+
+class OptionalRange(ModelConfig, BaseModel):
+    start: t.Optional[float]
+    end: t.Optional[float]
+
+    step: t.Optional[float]
+    n: t.Optional[int]
+
+    @root_validator(pre=False, skip_on_failure=True)
+    def _validate(cls, values):
+        if values['step'] is None and values['n'] is None:
+            raise ValueError("Either 'n' or 'step' must be specified.")
+        if values['step'] is not None and values['n'] is not None:
+            raise ValueError("Either 'n' or 'step' may be specified, but not both.")
+
+        return values
+
+    def to_range(self, start: float, end: float) -> Range[float]:
+        return Range[float].parse_obj({
+            'start': self.start or start,
+            'end': self.end or end,
+            'step': self.step,
+            'n': self.n,
+        })
+
+
+class ListNotEmpty(ConstrainedList, t.Generic[T]):
+    min_items = 1
+
+
+class ValueOrRange(WrapperModel, t.Generic[Num]):
+    __root__: t.Union[Num, Range[Num]]
+
+    def __len__(self) -> int:
+        if isinstance(self.__root__, (int, float)):
+            return 1
+        return self.__root__.__len__()
+
+    def __iter__(self) -> t.Iterator[Num]:
+        if isinstance(self.__root__, (int, float)):
+            yield self.__root__
+        else:
+            yield from self.__root__
+
+
+class ListOrNone(WrapperModel, t.Generic[T]):
+    __root__: t.Union[None, t.List[T]]
+
+    def __len__(self) -> int:
+        if self.__root__ is None:
+            return 1
+        return self.__root__.__len__()
+
+    def __iter__(self) -> t.Iterator[T]:
+        if self.__root__ is None:
+            return
+        yield from self.__root__
+
+
+class ValueOrList(WrapperModel, t.Generic[T]):
+    __root__: t.Union[T, ListNotEmpty[T]]
+    _is_T: bool = PrivateAttr()
+
+    def _init_private_attributes(self):
+        if isinstance(self._type_params, tuple):
+            ty = t.Any if len(self._type_params) == 0 else self._type_params[0]
+        else:
+            ty = self._type_params
+
+        field = self.__fields__['__root__']
+        subfield = field._create_sub_type(type_=ty, name=field.name + '_0')
+        subfield.allow_none = field.allow_none
+        # check if __root__ parses as T without error
+        self._is_T = subfield.validate(self.__root__, {}, loc='')[1] is None
+
+        super()._init_private_attributes()
+
+    def map(self, f: t.Callable[[T], U], ty: t.Type[U]) -> ValueOrList[U]:
+        if self._is_T:
+            inner = f(t.cast(T, self.__root__))
+        else:
+            inner = list(map(f, t.cast(t.List[T], self.__root__)))
+        return ValueOrList[ty].parse_obj(inner)
+
+    def __len__(self) -> int:
+        if self._is_T:
+            return 1
+        return t.cast(t.List[T], self.__root__).__len__()
+
+    def __iter__(self) -> t.Iterator[T]:
+        if self._is_T:
+            yield t.cast(T, self.__root__)
+        else:
+            yield from t.cast(t.List[T], self.__root__)
+
+
+class ValueListOrNone(WrapperModel, t.Generic[T]):
+    __root__: t.Union[None, T, t.List[T]]
+    _is_T: bool = PrivateAttr()
+
+    def _init_private_attributes(self):
+        if self.__root__ is None:
+            self._is_T = False
+        else:
+            if isinstance(self._type_params, tuple):
+                ty = t.Any if len(self._type_params) == 0 else self._type_params[0]
+            else:
+                ty = self._type_params
+
+            field = 
self.__fields__['__root__'] + subfield = ModelField(name=field.name, type_=ty, class_validators=None, model_config=field.model_config) + # check if __root__ parses as T without error + self._is_T = subfield.validate(self.__root__, {}, loc='')[1] is None + super()._init_private_attributes() + + def __len__(self) -> int: + if self.__root__ is None: + return 0 + if self._is_T: + return 1 + return t.cast(t.List[T], self.__root__).__len__() + + def __iter__(self) -> t.Iterator[T]: + if self.__root__ is None: + return + if self._is_T: + yield t.cast(T, self.__root__) + else: + yield from t.cast(t.List[T], self.__root__) + + +class ValueListOrRange(WrapperModel, t.Generic[Num]): + __root__: t.Union[Num, Range[Num], t.List[Num]] + + def __len__(self) -> int: + if isinstance(self.__root__, (int, float)): + return 1 + return self.__root__.__len__() + + def __iter__(self) -> t.Iterator[Num]: + if isinstance(self.__root__, (int, float)): + yield self.__root__ + else: + yield from self.__root__ diff --git a/phaser/cli_tools/params.py b/phaser/cli_tools/params.py new file mode 100644 index 0000000..eba3c1c --- /dev/null +++ b/phaser/cli_tools/params.py @@ -0,0 +1,1401 @@ +from __future__ import annotations +import math +import os +import sys +from pathlib import Path +from dataclasses import dataclass, field +from itertools import product +import typing as t + +import numpy +from pydantic import Field, validator, root_validator, PrivateAttr +from pydantic.types import NonNegativeInt, PositiveInt + +from .models import ModelConfig, ValueOrList, WrapperModel +from .metadata import Metadata, AnyMetadata + + +class FormatError(Exception): + def __init__(self, msg: str): + self.msg = msg + + def __str__(self) -> str: + return self.msg + + +def _iter_dict(params: t.Dict[str, t.Any], keys: t.Iterable[str], sparse: bool = False) -> t.Iterable[t.Dict[str, t.Any]]: + """ + Return copies of 'params', iterating through each key in `keys`. + If `sparse`, sparse (plus-shaped) combinations of iterators are returned. + The first value of each iterator is used as the default + Otherwise, dense (grid-shaped, cartesian product) combinations are returned. 
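+
+    For example (illustrative values), with ``params={'a': [1, 2], 'b': [10, 20]}``
+    and ``keys=('a', 'b')``, dense mode yields the four combinations
+    ``{'a': 1, 'b': 10}``, ``{'a': 1, 'b': 20}``, ``{'a': 2, 'b': 10}``,
+    ``{'a': 2, 'b': 20}``, while sparse mode yields the center point
+    ``{'a': 1, 'b': 10}`` followed by the single-key perturbations
+    ``{'a': 2, 'b': 10}`` and ``{'a': 1, 'b': 20}``.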
+ """ + # filter keys which aren't in params + keys = tuple(k for k in keys if k in params) + vals = tuple(params[k] for k in keys) + + # dense (cartesian product) + if not sparse: + # dense (cartesian product) + for update_vals in product(*vals): + d = params.copy() + d.update(zip(keys, update_vals)) + yield d + + return + + # TODO sparse update must be done with deep keys + if math.prod(len(v) for v in vals) == 0: + # empty, return + return + + # construct iterator objects for every non-missing pair + # propagate 'sparse' to child iterators + iters = tuple(v.iter(sparse=True) if hasattr(v, 'iter') else iter(v) for v in vals) + + # use the first member of each iterator to define the center point + base_vals = tuple((k, next(it)) for (k, it) in zip(keys, iters)) + params = params.copy() + params.update(base_vals) + # start by yielding center point + yield params + + # perturb one key at a time + for (update_key, it) in zip(keys, iters): + for update_val in it: + d = params.copy() + d[update_key] = update_val + yield d + + +@dataclass +class SaveRecord: + """Stores which (and how many) files have already been saved, and avoids duplicates.""" + paths: t.Dict[str, int] = field(default_factory=dict) + names: t.Dict[str, int] = field(default_factory=dict) + i: int = 0 + + def deduplicate_path(self, base_path: str) -> str: + if base_path in self.paths: + # handle duplicate paths + self.paths[base_path] += 1 + base_path += f'_{self.paths[base_path]}' + else: + self.paths[base_path] = 0 + return base_path + + def deduplicate_name(self, base_name: str) -> str: + if base_name in self.names: + # handle duplicate names + self.names[base_name] += 1 + base_name += f'_{self.names[base_name]}' + else: + self.names[base_name] = 0 + return base_name + + +class QueueParams(ModelConfig): + lockfile: bool = False + """If true, writes a lock file which prevents multiple reconstructions from running.""" + + +class DetectorParams(ModelConfig): + name: t.Literal['empad', 'empad_lebeau'] = 'empad_lebeau' + """Detector type""" + check_2_detpos: t.Literal[None] = None + data_prefix: str = "" + binning: bool = False + upsampling: t.Union[NonNegativeInt, bool] = False + burst_frames: PositiveInt = 1 + + circ_mask: t.Optional[int] = None + """If specified, applies a circular mask of the radius to each diffraction pattern.""" + + bg_sub: t.Union[bool, float] = False + """Subtract background intensity from diffraction patterns.""" + + crop: t.Optional[t.Tuple[int, int, int, int]] = None + """ + Scan dimensions to crop to (min_x, max_x, min_y, max_y). + Matlab-style slicing, so [1, 128, 1, 128] == an entire 128x128 scan + """ + + step: t.Optional[t.Tuple[int, int]] = None + """ + Only use every n x m (x, y) scan positions. Useful for testing reconstruction at multiples of step size. + """ + + tile: t.Optional[t.Tuple[int, int]] = None + """ + Tile the scan n x m (x, y) times prior to reconstruction. Useful for increasing the size of simulated data. + + Care must be taken to ensure the scan is periodic + """ + + fill_nan: bool = False + """ + Whether to check for and fill NaN values in raw dataset + """ + + sim: bool = False + """Whether data is simulated (single-electron intensity)""" + + beam_dose: t.Optional[float] = None + """Total electron dose (in e/A^2) to scale simulated data by. Overrides `beam_current` if specified.""" + beam_current: float = 30. 
+    """Beam current (in pA) to scale simulated data by."""
+
+    psf_sigma: t.Optional[float] = None
+    """
+    Apply Gaussian point spread function to the raw data before reconstruction.
+    Mainly useful for simulating a non-ideal detector.
+    """
+
+    poisson: bool = False
+    """
+    Apply Poisson noise to the raw data before reconstruction.
+    If using this with simulated data, make sure `beam_current`
+    is set correctly.
+    """
+
+
+class PrepareParams(ModelConfig):
+    data_preparator: t.Literal['matlab_aps'] = 'matlab_aps'
+    auto_prepare_data: bool = True
+    """If true, prepare dataset from raw measurements"""
+    force_preparation_data: bool = False
+    """Prepare dataset from raw measurements even if prepared data exists"""
+    store_prepared_data: bool = True
+    """Store prepared data to HDF5"""
+    prepare_data_function: t.Literal[""] = ""
+    #auto_center_data: bool = False
+    #"""Try to automatically center cbed center-of-mass"""
+
+
+class RasterScanParams(ModelConfig):
+    type: t.Literal['raster'] = 'raster'
+
+    nx: PositiveInt
+    """Number of scan positions in x."""
+    ny: PositiveInt
+    """Number of scan positions in y"""
+
+    step: float = Field(None)
+    """Scan step size (angstroms)"""
+    step_size_x: float = Field(None)
+    """Scan X step size (angstroms)"""
+    step_size_y: float = Field(None)
+    """Scan Y step size (angstroms)"""
+
+    @root_validator
+    def validate_scan_step(cls, values):
+        if values.get('step') is not None:
+            if values.get('step_size_x') is None:
+                values['step_size_x'] = values['step']
+            if values.get('step_size_y') is None:
+                values['step_size_y'] = values['step']
+        elif values.get('step_size_x') is not None:
+            values['step'] = values['step_size_x']
+            if values.get('step_size_y') is None:
+                values['step_size_y'] = values['step_size_x']
+        else:
+            raise ValueError("Missing parameter 'step'/'step_size_x'")
+        return values
+
+    custom_flip: t.Tuple[bool, bool, bool] = (False, False, False)
+    """Custom data flip (left-right, up-down, transpose)."""
+
+    step_randn_offset: float = 0.
+    """Random offset to apply to raster positions (relative to step size?)"""
+
+    roi: t.Optional[t.Tuple[int, int, int, int]] = None
+    """
+    Real-space dimensions to crop to (min_x, max_x, min_y, max_y).
+    Currently broken
+    """
+
+
+class CustomScanParams(ModelConfig):
+    type: t.Literal['custom_GPU'] = 'custom_GPU'
+
+    custom_positions_source: Path
+    """Niter.mat file to load positions from. Can be specified relative to base_path."""
+
+
+class ListScanParams(ModelConfig):
+    type: t.Literal['list'] = 'list'
+
+    scan_positions: t.List[t.Tuple[float, float]]
+    """List of scan positions."""
+
+    def is_default(self) -> bool:
+        return False
+
+    def apply_metadata(self, metadata: t.Optional[Metadata] = None) -> ListScanParams:
+        return self
+
+    def iter(self, sparse: bool = False) -> t.Iterator[ListScanParams]:
+        yield self
+
+    def __iter__(self) -> t.Iterator[ListScanParams]:
+        return self.iter()
+
+
+class ModelParams(ModelConfig):
+    object_type: t.Literal['rand', 'amplitude'] = 'rand'
+
+    probe_alpha_max: float
+    """Model probe convergence angle (mrad)"""
+    probe_df: float
+    """Model probe defocus (angstroms, overfocus is negative)."""
+    probe_C3: float = 0.
+    """Model probe C3 (angstrom)"""
+    probe_C5: float = 0.
+    """Model probe C5 (angstrom)"""
+    probe_C7: float = 0.
+    """Model probe C7 (angstrom)"""
+    probe_f_a2: float = 0.
+    probe_theta_a2: float = 0.
+    probe_f_a3: float = 0.
+    probe_theta_a3: float = 0.
+    probe_f_c3: float = 0.
+    probe_theta_c3: float = 0.
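+
+# A minimal usage sketch (hypothetical values): a 21.4 mrad probe model with
+# 100 angstroms of overfocus (note the sign convention on 'probe_df'):
+#
+#     model = ModelParams(probe_alpha_max=21.4, probe_df=-100.0)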
+ + +class ModelParamSet(ModelParams): + probe_alpha_max: ValueOrList[float] + probe_df: ValueOrList[float] + + _iterable_keys = ('probe_df', 'probe_alpha_max') + _len: t.Optional[int] = PrivateAttr(None) + + def __len__(self) -> int: + if self._len is None: + self._len = math.prod(len(getattr(self, k) or (None,)) for k in self._iterable_keys) + return self._len + + def iter(self, sparse: bool = False) -> t.Iterator[ModelParams]: + params = dict(filter(lambda t: t[0] in self.__fields_set__, self.__dict__.items())) + + for d in _iter_dict(params, self._iterable_keys, sparse): + yield ModelParams.parse_obj(d) + + def __iter__(self) -> t.Iterator[ModelParams]: + return self.iter() + + +class IOParams(ModelConfig): + default_mask_file: str = "" + default_mask_type: t.Literal['binary', 'indices'] = 'binary' + file_compression: NonNegativeInt = 0 + data_compression: PositiveInt = 3 + load_prep_pos: bool = False + """Load positions from prepared data, ignoring metadata""" + + +class SlicesInterp(ModelConfig): + old_n: PositiveInt + """# of layers in the last engine to interpolate from.""" + + new_n: PositiveInt + """# of layers in the new engine to interpolate to.""" + + def layers(self) -> t.List[float]: + return list(numpy.linspace(1, self.old_n, self.new_n, endpoint=True)) + + +class Slices(ModelConfig): + n: PositiveInt + """Number of layers. 1 for single-slice ptychography.""" + delta_z: float = Field(None) + """Slice thickness (in angstroms).""" + thickness: float = Field(None) + """Total object thickness (in angstroms).""" + + def delta_zs(self) -> t.List[float]: + return [self.delta_z] * self.n + + def zs(self) -> t.List[float]: + return [self.delta_z * i for i in range(self.n)] + + @root_validator + def validate_slices(cls, values: t.Dict[str, t.Any]) -> t.Dict[str, t.Any]: + if values.get('delta_z') is not None: + if values.get('thickness') is not None: + raise ValueError("'delta_z' and 'thickness' can't both be specified.") + values['thickness'] = values['delta_z'] * values['n'] + #values['delta_z'] = ValueOrList[float].parse_obj([values['delta_z']] * values['n']) + elif values.get('thickness') is not None: + #values['delta_z'] = ValueOrList[float].parse_obj([values['thickness'] / values['n']] * values['n']) + values['delta_z'] = values['thickness'] / values['n'] + else: + raise ValueError("Either 'delta_z' or 'thickness' must be specified.") + + return values + + +class SlicesSet(ModelConfig): + n: ValueOrList[PositiveInt] + """Number of layers. 
1 for single-slice ptychography.""" + delta_z: ValueOrList[float] = Field(None) + """Slice thickness (in angstroms).""" + thickness: ValueOrList[float] = Field(None) + """Total thickness (in angstroms).""" + + @root_validator + def validate_slices(cls, values: t.Dict[str, t.Any]) -> t.Dict[str, t.Any]: + if values.get('delta_z') is None and values.get('thickness') is None: + raise ValueError("Either 'delta_z' or 'thickness' must be specified.") + elif values.get('delta_z') is not None and values.get('thickness') is not None: + raise ValueError("'delta_z' and 'thickness' can't both be specified.") + return values + + _iterable_keys = ('n', 'delta_z', 'thickness') + _len: t.Optional[int] = PrivateAttr(None) + + def __len__(self) -> int: + if self._len is None: + self._len = math.prod(len(getattr(self, k) or (None,)) for k in self._iterable_keys) + return self._len + + def iter(self, sparse: bool = False) -> t.Iterator[Slices]: + params = dict(filter(lambda t: t[0] in self.__fields_set__, self.__dict__.items())) + for d in _iter_dict(params, self._iterable_keys, sparse): + yield Slices.parse_obj(d) + + def __iter__(self) -> t.Iterator[Slices]: + return self.iter() + + +class EngineParams(ModelConfig): + name: t.Literal['GPU_MS'] = 'GPU_MS' + """Reconstruction engine. GPU_MS = GPU multislice""" + fout: str = "{method}_{opt_errmetric}_p{probe_modes}_g{grouping}_step{i}" + """Output path. Can be specified relative to base_path.""" + use_gpu: bool = True + """Use GPU for reconstruction.""" + keep_on_gpu: bool = True + """Keep data + projections on GPU""" + compress_data: bool = False + """Use online memory compression""" + gpu_id: t.Optional[int] = None + """GPU id to use""" + check_gpu_load: bool = True + """Check GPU memory before starting engine""" + + number_iterations: PositiveInt = 200 + """Number of iterations for selected method""" + + save_results_every: t.Optional[PositiveInt] = 10 + """Save partial results every n iterations.""" + plot_results_every: t.Optional[PositiveInt] = None + """Plot partial results every n iterations.""" + + auto_center_data: bool = False + """Center diffraction patterns using average center of mass.""" + + slices: t.Optional[Slices] = None + + delta_z: t.List[float] = Field(None) + """List of slice z""" + + @validator('delta_z') + def validate_delta_z(cls, v, values: t.Dict[str, t.Any], **kwargs) -> t.Any: + if v is not None: + return v + if values.get('slices') is not None: + return values['slices'].delta_zs() + raise ValueError("Either 'slices' or 'delta_z' must be specified.") + + # multislice options + regularize_layers: float = Field(1., ge=0., le=1.) + """Apply regularization on the reconstructed object layers. 0 = no regularization, 0.01 = weak regularization.""" + init_layer_preprocess: t.Literal['all', 'avg', 'avg1', 'interp'] = 'all' + """ + How to preprocess layers from previous reconstruction step. + 'all' (default): No pre-processing + 'avg': Average layers together + 'avg1': Average layers and keep one (the rest initialized with 'init_layer_append_mode') + 'interp': Interpolate old object layers to 'init_slices'/'init_layer_interp'. + """ + + init_slices: t.Optional[SlicesInterp] = None + """ + Slices to interpolate with. Required when 'init_layer_preprocess: interp'. + """ + + init_layer_interp: t.List[float] = Field(default_factory=list) + """List of slice z positions. Initialized from 'init_slices' in python.""" + + init_layer_scaling_factor: float = 1. + """Scaling factor applied to phase of previous object layers. 
Useful when 'delta_z' is changed.""" + + @validator('init_layer_interp') + def validate_init_layer_interp(cls, v, values: t.Dict[str, t.Any], **kwargs) -> t.Any: + if v is None or len(v) == 0: + if values.get('init_slices') is not None: + return values['init_slices'].layers() + elif values.get('init_layer_preprocess') == 'interp': + raise ValueError("'init_slices' or 'init_layer_interp' required when 'init_layer_preprocess: interp'") + return v + + tilt_x: float = 0. + """ + Propagator x tilt (in mrad). + + Positive causes a shift rightwards (+x) when moving into the sample. + Unrelated to `sample_rotation_angles` and `apply_tilted_plane_correction`. + """ + tilt_y: float = 0. + """ + Propagator y tilt (in mrad). + + Positive causes a shift downwards (+y) when moving into the sample. + Unrelated to `sample_rotation_angles` and `apply_tilted_plane_correction`. + """ + + init_layer_append_mode: t.Literal['vac', 'edge', 'avg'] = 'vac' + """ + How to initialize extra object layers. + 'vac' (default): Add vacuum layers + 'edge': Copy edge layers + 'avg': Copy average of layers + """ + + preshift_ML_probe: bool = False + """ + If true, the provided probe is at the center of the object. If false, it is at the top of object. + + Doesn't work in combination with ``sample_rotation_angles``. + """ + + asize_presolve: t.Optional[t.Tuple[PositiveInt, PositiveInt]] = None + """Crop data to get low resolution estimate for next engine""" + align_shared_objects: bool = False + """Whether to align shared objects""" + + method: t.Literal['MLs', 'MLc', 'DM', 'ePIE', 'hPIE'] = 'MLs' + """Optimization method. MLs = maximum-likelihood sparse, MLc = maximum-likelihood compact, DM = difference map""" + + opt_errmetric: t.Literal['L1', 'poisson'] = 'L1' + """Optimization likelihood metric""" + grouping: PositiveInt = 64 + """Size of processed blocks. Memory/efficiency trade-off, but smaller may lead to faster convergence for MLs""" + probe_modes: PositiveInt = 8 + """Number of coherent probe modes""" + object_modes: PositiveInt = 1 + """Number of coherent object modes""" + object_change_start: PositiveInt = 1 + """Start updating object potential at this iteration.""" + probe_change_start: PositiveInt = 20 + """Start updating probe at this iteration.""" + probe_position_search: PositiveInt = 50 + """Iteration number to start probe position update at""" + + reg_mu: float = 0. + """Object smoothness regularization constant. 0 for no regularization""" + delta: float = 0. + """Press values to zero out of the illumination area in the object, usually 1e-2 is enough""" + positivity_constraint_object: float = 0. + """Enforce weak positivity in object. 1e-2 should be enough.""" + amplitude_threshold_object: float = 1.5 + """Clamp object amplitude. Set to 'inf' to disable.""" + + apply_multimodal_update: bool = False + """If true, update all probe modes. If false, only update first probe mode.""" + probe_backpropagate: float = 0. + """Backpropagation distance for the probe mask, 0 == apply in the object plane""" + probe_support_radius: t.Union[None, float] = None + """Normalized radius of circular support.""" + probe_support_fft: bool = False + """Assume that there is not illumination intensity out of the central FZP zone.""" + probe_support_tem: bool = False + """Limit the reconstructed probe support based on initial_probe. 
Added by ZC.""" + + # Orthogonal Probe Relaxation + variable_probe: bool = False + """Enable OPR (vary probe modes throughout a single scan).""" + variable_probe_modes: PositiveInt = 1 + """Number of OPR probe modes.""" + variable_probe_smooth: NonNegativeInt = 0 + """Order of polynomial fit used to smooth spatial evolution of OPR modes. 0 = no smoothing""" + variable_intensity: bool = False + """Whether to account for changes in probe intensity.""" + + # PIE / ML reconstruction parameters + beta_object: float = Field(1., gt=0., le=1.) + """Object step size, <= 1""" + beta_probe: float = Field(1., gt=0., le=1.) + """Probe step size, <= 1""" + delta_p: float = 0.1 + """LSQ damping constant""" + beta_LSQ: float = Field(0.9, gt=0., le=1.) + """Least-squares step size. Should be ~0.5 for noisy data, ~0.9 for clean data.""" + + # MLc reconstruction + momentum: float = 0. + """Add momentum to the MLc method""" + accelerated_gradients_start: int = 1000000 + """Iteration number to start Nesterov gradient acceleration at""" + + # DM reconstruction parameters + pfft_relaxation: float = 0.05 + """Relaxation in the Fourier domain""" + probe_regularization: float = 0.1 + """Weight factor for the probe update (inertia)""" + + apply_relaxed_position_constraint: bool = True + """ + When true, slowly relax probe positions towards the affine/geometry model. + When false, update random probe error and geometry model independently. + """ + max_pos_update_shift: float = 0.1 + """Maximum position update allowed each iteration (px).""" + probe_position_error_max: float = 2. + """Maximum random position error (px). Set to 0 to disable random position correction.""" + probe_position_search_momentum: float = 0. + """Momentum acceleration for probe position update.""" + + save_images: t.List[t.Literal[ + 'obj_ph', 'obj_ph_sum', 'obj_ph_stack', + 'obj_mag', 'obj_mag_sum', 'obj_mag_stack', + 'probe_mag', 'probe', + ]] = Field(default_factory=lambda: [ + 'obj_ph_sum', 'obj_ph_stack', 'obj_mag_sum', 'obj_mag_stack', 'probe', 'probe_mag' + ]) + """Intermediate results to save as TIFF images.""" + + +class PlotParams(ModelConfig): + prepared_data: bool = False + interval: t.Optional[int] = None + log_scale: t.Tuple[bool, bool] = (False, False) + realaxes: bool = True + remove_phase_ramp: bool = False + fov_box: bool = False + fov_box_color: str = 'r' + positions: bool = True + mask_bool: bool = True + windowautopos: bool = True + obj_apod: bool = False + prop_obj: float = 0. + """meters""" + show_layers: bool = True + show_layers_stack: bool = False + object_spectrum: t.Optional[bool] = False + probe_spectrum: t.Optional[bool] = False + conjugate: bool = False + horz_fact: float = 2.5 + FP_maskdim: float = 180e-6 + calc_FSC: bool = False + show_FSC: bool = False + residua: bool = False + + +class SaveParams(ModelConfig): + external: bool = True + """Save in external matlab session""" + store_images: bool = False + """Save images of final reconstructions (to '$base_path/analysis/online/ptycho').""" + store_images_intermediate: bool = False + """Save images of reconstructions after each engine (to '$base_path/analysis/online/ptycho').""" + store_images_ids: t.List[PositiveInt] = Field(default_factory=lambda: [1, 2, 3, 4]) + """IDs of images to be stored. 
1 = obj amplitude, 2 = obj phase, 3 = probes, 4 = errors, 5 = probes spectrum, 6 = object spectrum""" + store_images_format: t.Literal['png', 'jpg'] = 'png' + store_images_dpi: PositiveInt = 300 + exclude: t.List[str] = Field(default_factory=lambda: ['fmag', 'fmask', 'illum_sum']) + """Variables to exclude from saved output""" + save_reconstructions_intermediate: bool = False + save_reconstructions: bool = False + output_file: t.Literal['h5', 'mat'] = 'h5' + + +class Params(ModelConfig): + name: str + """Name of reconstruction. Used by python only.""" + + file_type: t.Literal['recons_params'] = 'recons_params' + + engine: t.Literal['fold_slice'] = 'fold_slice' + """Reconstruction engine to use. For future use""" + + # Display/output params + verbose_level: int = 2 + """0-1 for loops, 2-3 for testing, >=4 for debugging.""" + use_display: bool = False + """Whether to display plots""" + scan_number: ValueOrList[int] = ValueOrList.parse_obj(1) + """Scan number for shared scans?""" + + # Geometry + z: t.Literal[1] = 1 + """Distance from object to detector.""" + asize: t.Tuple[PositiveInt, PositiveInt] = (128, 128) + """Diffraction pattern size in px (y, x).""" + ctr: t.Optional[t.Tuple[PositiveInt, PositiveInt]] = None + """Diffraction pattern center coordinates (y, x).""" + + beam_source: t.Literal['electron'] = 'electron' + """Beam source.""" + d_alpha: float + """Diffraction pixel size (mrad)""" + prop_regime: t.Literal['farfield', 'nearfield'] = 'farfield' + """Wave propagator to use. Nearfield = Fresnel, Farfield (default) = Fraunhofer""" + focus_to_sample_distance: t.Optional[float] = None + """Focus to sample distance, used for near-field propagator.""" + energy: float = 200. + """Beam energy (in keV).""" + + apply_tilted_plane_correction: t.Literal['', 'propagation', 'diffraction'] = '' + """ + If enabled, applies 'sample_rotation_angles' to the propagator ('propagation') or raw data ('diffraction') + """ + sample_rotation_angles: t.Tuple[float, float, float] = (0., 0., 0.) + """ + Sample mistilt in [X, Y, Z (rotation)], in degrees. + + Applied if ``apply_tilted_plane_correction = 'propagation'``. + """ + + + auto_center_data: bool = True + """Auto center cbed patterns. Crops diffraction pixels accordingly""" + + #thickness: float + #"""Object thickness (in angstroms).""" + + #n_layers: PositiveInt + #"""Multislice layers. 1 for single slice""" + + affine_angle: t.Optional[float] = 0. + """Angle (in degrees) to rotate probe positions by.""" + affine_matrix: t.List[t.List[float]] = [] + """Affine matrix to apply to probe positions. Generated from `affine_angle` if not specified.""" + + @root_validator + def validate_transform(cls, values): + if len(values['affine_matrix']) == 0: + if values.get('affine_angle') is None: + raise ValueError("Either affine_matrix or affine_angle must be specified.") + a = values['affine_angle'] * math.pi / 180. + values['affine_matrix'] = [[math.cos(a), math.sin(a)], [-math.sin(a), math.cos(a)]] + elif len(values['affine_matrix']) != 2 or len(values['affine_matrix'][0]) != 2 or len(values['affine_matrix'][1]) != 2: + raise ValueError("Invalid shape for 'affine_matrix'. 
Expected a 2x2 matrix.")
+        return values
+
+    @root_validator(pre=False)
+    def validate_paths(cls, params):
+        # expand paths
+        for k in ('base_path', 'ptycho_matlab_path', 'cSAXS_matlab_path'):
+            params[k] = params[k].expanduser()
+
+        # if base_path is relative, make it absolute
+        params['base_path'] = params['base_path'].absolute()
+
+        # make certain paths relative to base_path if specified
+        for k in ('prepare_data_path', 'save_path', 'raw_data_path', 'initial_probe_file', 'initial_iterate_object_file'):
+            if params[k] is None:
+                params[k] = Path()
+            else:
+                params[k] = params[k].expanduser()
+                if not params[k].is_absolute():
+                    params[k] = params['base_path'] / params[k]
+
+        if 'scan' in params and isinstance(params['scan'], CustomScanParams):
+            path = params['scan'].custom_positions_source.expanduser()
+            if not path.is_absolute():
+                path = params['base_path'] / path
+            params['scan'] = CustomScanParams(custom_positions_source=path)
+
+        return params
+
+    @root_validator
+    def update_engine_defaults(cls, values):
+        # update engines using default values from all_engines
+        engines = t.cast(t.List[EngineParams], values['engines'])
+        all_engines = t.cast(EngineParams, values['all_engines'])
+        default_fields = all_engines.dict(exclude_unset=True)
+        default_fields['delta_z'] = all_engines.delta_z
+        default_fields['init_layer_interp'] = all_engines.init_layer_interp
+
+        def update_engine(engine: EngineParams) -> EngineParams:
+            d = default_fields.copy()
+            d.update(engine.dict(exclude_unset=True))
+            engine = EngineParams(**d)
+            if engine.delta_z is None:  # type: ignore
+                raise ValueError("'slices' or 'delta_z' must be specified for all engines.")
+            return engine
+
+        values['engines'] = list(map(update_engine, engines))
+
+        for (i, engine) in enumerate(values['engines']):
+            try:
+                fout = engine.fout.format(i=i+1, **engine.dict())
+                fout = Path(fout).expanduser()
+                if not fout.is_absolute():
+                    fout = values['base_path'] / fout
+                fout = str(fout) + os.sep
+                object.__setattr__(engine, 'fout', fout)
+            except KeyError as e:
+                raise FormatError(f"Invalid format string in 'fout' (unknown key {e})") from None
+            except Exception as e:
+                raise FormatError("Invalid format string in 'fout'") from e
+        #for engine in engines:
+            #if engine.delta_z is None and 'thickness' in values and 'n_layers' in values:
+            #    # hack to get around Engine immutability
+            #    object.__setattr__(engine, 'delta_z', [values['thickness'] / values['n_layers']] * values['n_layers'])
+            #if engine.probe_modes is None:
+            #    object.__setattr__(engine, 'probe_modes', values['probe_modes'])
+            #values['engines'] = engines
+        return values
+
+    src_metadata: t.Literal['none'] = 'none'
+    """Not currently used"""
+
+    queue: QueueParams = QueueParams.parse_obj({})
+
+    detector: DetectorParams = DetectorParams.parse_obj({})
+
+    prepare: PrepareParams = PrepareParams.parse_obj({})
+
+    all_engines: EngineParams = Field(default_factory=lambda: EngineParams.parse_obj({}))
+
+    engines: t.List[EngineParams] = Field(default_factory=lambda: [EngineParams.parse_obj({})])
+
+    src_positions: t.Literal['matlab_pos', 'load_from_file'] = 'matlab_pos'
+    """Where to get scan positions from"""
+    positions_file: str = ""
+    """When src_positions=load_from_file, position file to load from. Formatted with the scan number"""
+
+    scan: t.Union[RasterScanParams, CustomScanParams, ListScanParams] = Field(..., discriminator='type')
+    """Scan settings"""
+
+    prefix: str = ""
+    """Prefix for outputs. 
If empty, scan number is used""" + suffix: str = 'ML_recon' + """Suffix for reconstruction outputs""" + scan_string_format: str = '%01d' + """Matlab format string used to format scan number""" + + base_path: Path = Path("") + """Directory to reconstruct into.""" + ptycho_matlab_path: Path = Path("") + """Path to fold_slice/ptycho folder.""" + cSAXS_matlab_path: Path = Path("") + """Path to csolver. Shouldn't be necessary.""" + raw_data_path: Path = Field(None) + """Path to get raw data from. Can be specified relative to base_path.""" + raw_data_filename: str = "" + """Raw data filename. Defaults to 'scan_x%d_y%d.raw'.""" + + prepare_data_filename: str = "" + """Filename to write prepared data (as HDF5) to.""" + + save_path: Path = Field(None) + """ + Filename to save analysis to. Can be specified relative to base_path. + Defaults, in matlab, to '{base_path}/analysis/S00000-00999/S{scan_number:05}/'. + """ + specfile: t.Literal[''] = '' + """Metadata filename. Currently unused""" + + prepare_data_path: Path = Field(None) + """ + Filename to write prepared data into. Can be specified relative to base_path. + Defaults to save_path. + """ + + io: IOParams = IOParams.parse_obj({}) + + model_object: bool = True + """If true, model initial object using 'object_type'. If false, load from file ('initial_iterate_object_file').""" + model_probe: bool = True + """Whether to model probe or load it from a file""" + + model: ModelParams + """Probe & Object model settings""" + + initial_iterate_object_file: t.Optional[Path] = None + """ + File to load initial object from. Used only if model_object=false. + Can be specified relative to base_path. + """ + multiple_layers_obj: bool = True + """Whether 'initial_iterate_object_file' is multi- or single-slice""" + + initial_probe_file: t.Optional[Path] = None + """ + File to load initial probe from. Used only if model_probe=false. + Can be specified relative to base_path. + """ + normalize_init_probe: bool = True + """Whether to normalize initial probe. Should be disabled when loading an existing reconstruction.""" + crop_pad_init_probe: bool = False + """Whether to crop/pad (True) or interpolate (False) real-space probe to match the reconstruction size.""" + probe_file_propagation: float = 0. + """Distance to propagate the probe from the initial position (meters)""" + + share_probe: bool = False + """Whether to share probes between scans.""" + share_object: bool = False + """Whether to share object between scans.""" + + mode_start_pow: t.Union[float, t.List[float]] = 0.02 + """Normalized intensity to start higher probe modes at.""" + mode_start: t.Literal['rand', 'herm', 'hermver', 'hermhor'] = 'herm' + """Higher mode probe initialization.""" + ortho_probes: bool = True + """Orthogonalize probe modes after each engine""" + ortho_probe_modes: bool = False + """Orthogonalize probe modes after each iteration""" + + plot: PlotParams = PlotParams.parse_obj({}) + save: SaveParams = SaveParams.parse_obj({}) + + +class EngineParamSet(EngineParams): + number_iterations: ValueOrList[PositiveInt] = ValueOrList.parse_obj(200) + + fout: str = "{method}_{opt_errmetric}_p{probe_modes}_g{grouping}_step{i}" + """Output path. Format string, can be specified relative to base_path.""" + + method: ValueOrList[t.Literal['MLs', 'MLc', 'DM', 'ePIE', 'hPIE']] = ValueOrList.parse_obj('MLs') + """Optimization method. 
MLs = maximum-likelihood sparse, MLc = maximum-likelihood compact, DM = difference map""" + + grouping: ValueOrList[PositiveInt] = ValueOrList.parse_obj(64) + """Size of processed blocks. Memory/efficiency trade-off, but smaller may lead to faster convergence for MLs""" + + object_change_start: ValueOrList[PositiveInt] = ValueOrList.parse_obj(1) + """Start updating object potential at this iteration.""" + probe_change_start: ValueOrList[PositiveInt] = ValueOrList.parse_obj(20) + """Start updating probe wavefunctions at this iteration.""" + probe_position_search: ValueOrList[PositiveInt] = ValueOrList.parse_obj(50) + """Start updating probe positions at this iteration.""" + + # Orthogonal Probe Relaxation + variable_probe: ValueOrList[bool] = ValueOrList.parse_obj(False) + """Enable OPR (vary probe modes throughout a single scan).""" + variable_probe_modes: ValueOrList[PositiveInt] = ValueOrList.parse_obj(1) + """Number of OPR probe modes.""" + variable_probe_smooth: ValueOrList[NonNegativeInt] = ValueOrList.parse_obj(0) + """Order of polynomial fit used to smooth spatial evolution of OPR modes. 0 = no smoothing""" + variable_intensity: ValueOrList[bool] = ValueOrList.parse_obj(False) + """Whether to account for changes in probe intensity.""" + + beta_object: ValueOrList[float] = ValueOrList.parse_obj(1.0) + """Object step size, <= 1""" + beta_probe: ValueOrList[float] = ValueOrList.parse_obj(1.0) + """Probe step size, <= 1""" + delta_p: ValueOrList[float] = ValueOrList.parse_obj(0.1) + """LSQ damping constant""" + beta_LSQ: ValueOrList[float] = ValueOrList.parse_obj(0.9) + """Least-squares step size. Should be ~0.5 for noisy data, ~0.9 for clean data.""" + + reg_mu: ValueOrList[float] = ValueOrList.parse_obj(0.0) + """Object smoothness regularization constant. 0 for no regularization""" + delta: ValueOrList[float] = ValueOrList.parse_obj(0.0) + """Press values to zero out of the illumination area in the object, usually 1e-2 is enough""" + positivity_constraint_object: ValueOrList[float] = ValueOrList.parse_obj(0.0) + """Enforce weak positivity in object. 1e-2 should be enough.""" + amplitude_threshold_object: ValueOrList[float] = ValueOrList.parse_obj(1.5) + """Clamp object amplitude. Set to 'inf' to disable.""" + + probe_modes: ValueOrList[PositiveInt] = ValueOrList.parse_obj(8) + """Number of coherent probe modes.""" + object_modes: ValueOrList[PositiveInt] = ValueOrList.parse_obj(8) + """Number of coherent object modes.""" + regularize_layers: ValueOrList[float] = ValueOrList.parse_obj(1.) + """Apply regularization on the reconstructed object layers. 0 = no regularization, 0.01 = weak regularization.""" + + asize_presolve: ValueOrList[t.Optional[t.Tuple[PositiveInt, PositiveInt]]] = ValueOrList[t.Optional[t.Tuple[PositiveInt, PositiveInt]]].parse_obj(None) + """Crop data to get low resolution estimate for next engine""" + + slices: t.Optional[SlicesSet] = None + + delta_z: t.Optional[t.List[float]] = None + + @validator('delta_z') + def validate_delta_z(cls, v, values: t.Dict[str, t.Any], **kwargs) -> t.Any: + return v + + init_layer_preprocess: ValueOrList[t.Literal['all', 'avg', 'avg1', 'interp']] = ValueOrList.parse_obj('all') + """ + How to preprocess layers from previous reconstruction step. + 'all' (default): No pre-processing + 'avg': Average layers together + 'avg1': Average layers and keep one (the rest initialized with 'init_layer_append_mode') + 'interp': Interpolate old object layers to 'init_slices'/'init_layer_interp'. 
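+
+    For example (illustrative), interpolating a previous 4-layer engine onto an
+    8-layer engine uses 'init_layer_preprocess: interp' with
+    'init_slices: {old_n: 4, new_n: 8}', which places the new layers at
+    numpy.linspace(1, 4, 8).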
+    """
+
+    init_slices: t.Optional[SlicesInterp] = None
+    """
+    Slices to interpolate with. Required when 'init_layer_preprocess: interp'.
+    """
+
+    init_layer_interp: t.List[float] = Field(default_factory=list)
+    """List of slice z positions. Initialized from 'init_slices' in python."""
+
+    @validator('init_layer_interp')
+    def validate_init_layer_interp(cls, v, values: t.Dict[str, t.Any], **kwargs) -> t.Any:
+        return v
+
+    init_layer_scaling_factor: ValueOrList[float] = ValueOrList.parse_obj(1.)
+    """Scaling factor applied to phase of previous object layers. Useful when 'delta_z' is changed."""
+
+    tilt_x: ValueOrList[float] = ValueOrList.parse_obj(0.)
+    """
+    Propagator x tilt (in mrad).
+
+    Positive causes a shift rightwards (+x) when moving into the sample.
+    Unrelated to `sample_rotation_angles` and `apply_tilted_plane_correction`.
+    """
+    tilt_y: ValueOrList[float] = ValueOrList.parse_obj(0.)
+    """
+    Propagator y tilt (in mrad).
+
+    Positive causes a shift downwards (+y) when moving into the sample.
+    Unrelated to `sample_rotation_angles` and `apply_tilted_plane_correction`.
+    """
+
+    _iterable_keys = (
+        'slices', 'number_iterations', 'method', 'grouping', 'probe_modes', 'object_modes', 'object_change_start', 'probe_change_start',
+        'probe_position_search', 'beta_object', 'beta_probe', 'delta_p', 'beta_LSQ', 'regularize_layers', 'asize_presolve',
+        'reg_mu', 'delta', 'positivity_constraint_object', 'amplitude_threshold_object',
+        'init_layer_preprocess', 'init_layer_scaling_factor', 'tilt_x', 'tilt_y',
+        'variable_probe', 'variable_probe_modes', 'variable_probe_smooth', 'variable_intensity',
+    )
+    _len: t.Optional[int] = PrivateAttr(None)
+
+    def __len__(self) -> int:
+        if self._len is None:
+            self._len = math.prod(len(getattr(self, k) or (None,)) for k in self._iterable_keys)
+        return self._len
+
+    def iter(self, sparse: bool = False) -> t.Iterator[EngineParams]:
+        params = dict(filter(lambda t: t[0] in self.__fields_set__, self.__dict__.items()))
+        for d in _iter_dict(params, self._iterable_keys, sparse):
+            yield EngineParams.parse_obj(d)
+
+    def __iter__(self) -> t.Iterator[EngineParams]:
+        return self.iter()
+
+
+class RasterScanParamSet(RasterScanParams):
+    step: ValueOrList[float] = Field(None)
+    """Scan step size (angstroms)"""
+    step_size_x: float = Field(None)
+    """Scan X step size (angstroms)"""
+    step_size_y: float = Field(None)
+    """Scan Y step size (angstroms)"""
+
+    nx: PositiveInt
+    """Number of scan positions in x"""
+    ny: PositiveInt
+    """Number of scan positions in y"""
+
+    custom_flip: ValueOrList[t.Tuple[bool, bool, bool]] = ValueOrList[t.Tuple[bool, bool, bool]].parse_obj((False, False, False))
+    """Custom data flip (left-right, up-down, transpose)."""
+
+    step_randn_offset: ValueOrList[float] = ValueOrList.parse_obj(0.)
+ """Random offset to apply to raster positions (relative to step size?)""" + + _iterable_keys = ('step', 'custom_flip', 'step_randn_offset') + _len: t.Optional[int] = PrivateAttr(None) + + def __len__(self) -> int: + if self._len is None: + self._len = math.prod(len(getattr(self, k) or (None,)) for k in self._iterable_keys) + return self._len + + def iter(self, sparse: bool = False) -> t.Iterator[RasterScanParams]: + params = dict(filter(lambda t: t[0] in self.__fields_set__, self.__dict__.items())) + for d in _iter_dict(params, self._iterable_keys, sparse): + yield RasterScanParams.parse_obj(d) + + def __iter__(self) -> t.Iterator[RasterScanParams]: + return self.iter() + + +class CustomScanParamSet(CustomScanParams): + custom_positions_source: ValueOrList[Path] + """Niter.mat file to load positions from. Can be specified relative to base_path.""" + + _iterable_keys = ('custom_positions_source',) + _len: t.Optional[int] = PrivateAttr(None) + + def __len__(self) -> int: + if self._len is None: + self._len = math.prod(len(getattr(self, k) or (None,)) for k in self._iterable_keys) + return self._len + + def iter(self, sparse: bool = False) -> t.Iterator[CustomScanParams]: + params = dict(filter(lambda t: t[0] in self.__fields_set__, self.__dict__.items())) + for d in _iter_dict(params, self._iterable_keys, sparse): + yield CustomScanParams.parse_obj(d) + + def __iter__(self) -> t.Iterator[CustomScanParams]: + return self.iter() + + def apply_metadata(self, metadata: t.Optional[Metadata] = None) -> CustomScanParamSet: + return self + + def is_default(self) -> bool: + return False + + +class DetectorParamSet(DetectorParams): + upsampling: ValueOrList[t.Union[NonNegativeInt, bool]] = ValueOrList.parse_obj(False) + + bg_sub: ValueOrList[t.Union[bool, float]] = ValueOrList.parse_obj(False) + """Subtract background intensity from diffraction patterns.""" + + crop: ValueOrList[t.Optional[t.Tuple[int, int, int, int]]] = ValueOrList[t.Optional[t.Tuple[int, int, int, int]]].parse_obj(None) + """ + Scan dimensions to crop to (min_x, max_x, min_y, max_y). + Matlab-style slicing, so [1, 128, 1, 128] == an entire 128x128 scan + """ + + step: ValueOrList[t.Optional[t.Tuple[int, int]]] = ValueOrList[t.Optional[t.Tuple[int, int]]].parse_obj(None) + """ + Only use every n x m scan positions (x, y). Useful for testing reconstruction at multiples of step size. + """ + + tile: ValueOrList[t.Optional[t.Tuple[int, int]]] = ValueOrList[t.Optional[t.Tuple[int, int]]].parse_obj(None) + """ + Tile the scan n x m times prior to reconstruction (x, y). Useful for increasing the size of simulated data. + + Care must be taken to ensure the scan is periodic + """ + + circ_mask: ValueOrList[t.Optional[int]] = ValueOrList[t.Optional[int]].parse_obj(None) + """If specified, applies a circular mask of the radius to each diffraction pattern.""" + + psf_sigma: ValueOrList[t.Optional[float]] = ValueOrList[t.Optional[float]].parse_obj(None) + """ + Apply Gaussian point spread function to the raw data before reconstruction. + Mainly useful for simulating a non-ideal detector. + """ + + beam_dose: ValueOrList[t.Optional[float]] = ValueOrList[t.Optional[float]].parse_obj(None) + """Total electron dose (in e/A^2) to scale simulated data by. Overrides `beam_current` if specified.""" + + beam_current: ValueOrList[float] = ValueOrList[float].parse_obj(30.) 
+ """Beam current (in pA) to scale simulated data by.""" + + poisson: ValueOrList[bool] = ValueOrList[bool].parse_obj(False) + """ + Apply Poisson noise to the raw data before reconstruction. + If using this with simulated data, make sure `beam_current` + is set correctly. + """ + + _iterable_keys = ('upsampling', 'bg_sub', 'crop', 'step', 'tile', 'circ_mask', 'psf_sigma', 'poisson', 'beam_dose', 'beam_current') + _len: t.Optional[int] = PrivateAttr(None) + + def __len__(self) -> int: + if self._len is None: + self._len = math.prod(len(getattr(self, k, (None,))) for k in self._iterable_keys) + return self._len + + def iter(self, sparse: bool = False) -> t.Iterator[DetectorParams]: + params = dict(filter(lambda t: t[0] in self.__fields_set__, self.__dict__.items())) + for d in _iter_dict(params, self._iterable_keys, sparse): + yield DetectorParams.parse_obj(d) + + def __iter__(self) -> t.Iterator[DetectorParams]: + return self.iter() + + def apply_metadata(self, metadata: t.Optional[Metadata] = None) -> DetectorParamSet: + from_meta: t.Dict[str, t.Any] = {} + if metadata is not None: + from_meta['sim'] = metadata.is_simulated() + if metadata.crop is not None: + from_meta['crop'] = ValueOrList[t.Optional[t.Tuple[int, int, int, int]]].parse_obj(metadata.crop) + + return DetectorParamSet.parse_obj({**from_meta, **{k: self.__dict__[k] for k in self.__fields_set__}}) + + +class ParamSet(Params): + name: str + """Name of reconstruction. Used by python only. May be a format string""" + + file_type: t.Literal['recons_param_set'] = 'recons_param_set' + + base_path: Path = Path("{name}") + """Directory to reconstruct into. May be a Python format string, which is expanded for each reconstruction.""" + + sample_rotation_angles: ValueOrList[t.Tuple[float, float, float]] = ValueOrList[t.Tuple[float, float, float]].parse_obj((0., 0., 0.)) + """ + Sample mistilt in [X, Y, Z (rotation)], in degrees. + + Applied if ``apply_tilted_plane_correction = 'propagation'``. + """ + + energy: ValueOrList[float] = ValueOrList.parse_obj(200.) + """Beam energy (in keV).""" + + asize: ValueOrList[t.Tuple[PositiveInt, PositiveInt]] = ValueOrList[t.Tuple[PositiveInt, PositiveInt]].parse_obj((128, 128)) + """Diffraction pattern size in px (y, x).""" + + detector: DetectorParamSet = DetectorParamSet.parse_obj({}) + """Detector settings""" + + scan: t.Union[RasterScanParamSet, CustomScanParamSet, ListScanParams] + """Scan settings""" + + model: ModelParamSet + """Probe & Object model settings""" + + prepare_data_path: t.Optional[Path] = None + save_path: t.Optional[Path] = None + raw_data_path: t.Optional[Path] = None + + all_engines: EngineParamSet = Field(default_factory = lambda: EngineParamSet.parse_obj({})) + + engines: t.List[EngineParamSet] = Field(default_factory = lambda: [EngineParamSet.parse_obj({})]) + + d_alpha: ValueOrList[float] = Field(None) + """Diffraction pixel size (mrad).""" + + affine_angle: ValueOrList[t.Optional[float]] = ValueOrList[t.Optional[float]].parse_obj(None) + """Angle (in degrees) to rotate probe positions by.""" + + affine_matrix: ValueOrList[t.List[t.List[float]]] = ValueOrList.parse_obj([[]]) + """Affine matrix to apply to probe positions. 
Generated from `affine_angle` if not specified."""
+
+    sparse: bool = False
+    """Whether to return a sparse or dense combination of parameters."""
+
+    @root_validator
+    def validate_paths(cls, params):
+        return params
+
+    @root_validator
+    def update_engine_defaults(cls, values):
+        return values
+
+    @root_validator
+    def validate_transform(cls, values):
+        return values
+
+    _iterable_keys = (
+        'energy', 'scan', 'detector', 'd_alpha', 'affine_angle', 'affine_matrix',
+        'all_engines', 'model', 'asize', 'sample_rotation_angles',
+    )
+    _len: t.Optional[int] = PrivateAttr(None)
+
+    def __len__(self) -> int:
+        if self._len is None:
+            self._len = (math.prod(len(getattr(self, k) or (None,)) for k in self._iterable_keys) *
+                         math.prod(len(engine) for engine in self.engines))
+        return self._len
+
+    def iter(self,
+             _save_record: t.Optional[SaveRecord] = None, metadata: t.Optional[Metadata] = None,
+             sparse: t.Optional[bool] = None
+    ) -> t.Iterator[Params]:
+        # keep a list of used paths, names, and a counter of iterations.
+        if _save_record is None:
+            _save_record = SaveRecord()
+        sparse = sparse if sparse is not None else self.sparse
+
+        params = dict(filter(lambda t: t[0] in self.__fields_set__, self.__dict__.items()))
+        params.pop('file_type', None)
+        params.pop('sparse', None)
+
+        # If keys are missing from `params`, they are replaced with a sentinel which gets filtered out later
+        for d in _iter_dict(params, self._iterable_keys, sparse):
+            # apply values for this run (skipping those which are _missing)
+            for engines in product(*self.engines):
+                d.update(engines=list(engines))
+
+                # format run-specific name and base_path
+                try:
+                    name = self.name.format(i=_save_record.i, meta=metadata, sparse=sparse, **d)
+                    d['name'] = _save_record.deduplicate_name(name)
+                    path = str(self.base_path).format(i=_save_record.i, meta=metadata, **d)
+                    d['base_path'] = _save_record.deduplicate_path(path)
+                except KeyError as e:
+                    raise FormatError(f"Invalid format string (unknown key {e})") from None
+                except Exception as e:
+                    raise FormatError("Invalid format string") from e
+
+                # finally, yield a parsed object
+                yield Params.parse_obj(d)
+                _save_record.i += 1
+
+    def __iter__(self) -> t.Iterator[Params]:
+        return self.iter()
+
+
+class RasterScanParamMetaSet(RasterScanParamSet):
+    nx: t.Optional[PositiveInt] = None
+    """Number of scan positions in x. Required if metadata not given."""
+    ny: t.Optional[PositiveInt] = None
+    """Number of scan positions in y.
Required if metadata not given.""" + + @root_validator + def validate_scan_step(cls, values): + # save validation for after we have any metadata + return values + + def apply_metadata(self, metadata: t.Optional[Metadata] = None) -> RasterScanParamSet: + from_meta: t.Dict[str, t.Any] = {} + if metadata is not None: + from_meta.update( + nx=metadata.scan_shape[0], + ny=metadata.scan_shape[1], + step=ValueOrList.parse_obj(metadata.scan_step[0]*1e10), # m to A + step_size_x=metadata.scan_step[0]*1e10, # m to A + step_size_y=metadata.scan_step[1]*1e10, # m to A + ) + else: + if self.nx is None or self.ny is None: + raise ValueError("Scan 'nx' and 'ny' must be specified in parameters or metadata.") + if self.step is None and (self.step_size_x is None or self.step_size_y is None): + raise ValueError("Scan 'step' or 'step_size_x' and 'step_size_y' must be specified in parameters or metadata.") + + d = {**from_meta, **{k: self.__dict__[k] for k in self.__fields_set__}} + return RasterScanParamSet.parse_obj({k: v.__root__ if isinstance(v, WrapperModel) else v for (k, v) in d.items()}) + + def is_default(self) -> bool: + return not len(self.__fields_set__) + + +class ModelParamMetaSet(ModelParamSet): + probe_alpha_max: t.Optional[ValueOrList[float]] = None + """Model probe convergence angle (mrad). Required if metadata not given.""" + probe_df: t.Optional[ValueOrList[float]] = None + """Model probe defocus (angstroms, overfocus is negative). Required if metadata not given.""" + + def apply_metadata(self, metadata: t.Optional[Metadata] = None) -> ModelParamSet: + from_metadata: t.Dict[str, t.Any] = {} + if metadata is not None: + if metadata.conv_angle is not None: + from_metadata['probe_alpha_max'] = ValueOrList[float].parse_obj(metadata.conv_angle) + if metadata.defocus is not None: + # sign convention opposite, m -> angstrom + from_metadata['probe_df'] = ValueOrList[float].parse_obj(-1e10 * metadata.defocus) + + new = {**from_metadata, **{k: self.__dict__[k] for k in self.__fields_set__}} + + if new.get('probe_alpha_max') is None: + raise ValueError("'probe_alpha_max' must be specified in parameters or metadata.") + if new.get('probe_df') is None: + raise ValueError("'probe_df' must be specified in parameters or metadata.") + + return ModelParamSet.parse_obj({k: v.__root__ if isinstance(v, WrapperModel) else v for (k, v) in new.items()}) + + +class ParamMetaSet(ParamSet): + file_type: t.Literal['recons_param_meta'] = 'recons_param_meta' + + meta: ValueOrList[t.Union[AnyMetadata, Path, None]] = ValueOrList[t.Union[AnyMetadata, Path, None]].parse_obj(None) + """ + Metadata object or path(s) to metadata files. If a relative path is specified, + it will be resolved relative to the path of the reconstruction file. + """ + + @validator('meta', pre=False) + def validate_meta_path(cls, meta: ValueOrList[t.Union[AnyMetadata, Path, None]]) -> ValueOrList[t.Union[AnyMetadata, Path, None]]: + if meta is None: + return ValueOrList[t.Union[AnyMetadata, Path, None]].parse_obj(None) + return meta.map(lambda p: p.expanduser() if isinstance(p, Path) else p, t.Union[AnyMetadata, Path, None]) + + energy: t.Optional[ValueOrList[float]] = None + """Beam energy (in keV). Required if metadata not given.""" + raw_data_path: t.Optional[Path] = None + """Path to raw data. Can be specified relative to 'base_path'.""" + raw_data_filename: t.Optional[str] = None + """Raw data filename. Matlab format string. 
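+    Expanded with the scan dimensions, e.g. 'scan_x128_y128.raw' for a 128x128 scan.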
Defaults to 'scan_x%d_y%d.raw' if metadata not given.""" + d_alpha: t.Optional[ValueOrList[float]] = None + """Diffraction pixel size (mrad). Required if metadata not given.""" + + scan: t.Union[RasterScanParamMetaSet, CustomScanParamSet, ListScanParams] = RasterScanParamMetaSet.parse_obj({}) + """Scan settings""" + + model: ModelParamMetaSet = ModelParamMetaSet.parse_obj({}) + """Probe & Object model settings""" + + def with_metadata(self, metadata: t.Union[None, Metadata, t.Sequence[Metadata]]) -> ParamMetaSet: + if 'meta' in self.__fields_set__: + print("Overriding metadata specified in ParamMetaSet.", file=sys.stderr) + return self.copy(update={'meta': ValueOrList[t.Union[AnyMetadata, Path, None]].parse_obj(metadata)}) + + def apply_metadata(self, metadata: t.Optional[Metadata] = None) -> ParamSet: + from_meta = {} + if metadata is not None: + from_meta.update( + energy=ValueOrList.parse_obj(metadata.voltage * 1e-3), # V to kV + raw_data_path=metadata.path.resolve(), # type: ignore + raw_data_filename=metadata.raw_filename, + ) + if metadata.diff_step is not None: + from_meta['d_alpha'] = ValueOrList.parse_obj(metadata.diff_step) + + if hasattr(metadata, 'scan_correction') and metadata.scan_correction is not None: + a = metadata.scan_rotation * math.pi / 180. + rot = numpy.array([[math.cos(a), -math.sin(a)], [math.sin(a), math.cos(a)]]) + # apply rotation after scan correction (is this right?) + # also, flip to [y, x] coordinates for ptycho shelves + from_meta['affine_matrix'] = ValueOrList.parse_obj((rot @ numpy.array(metadata.scan_correction)[::-1, ::-1]).tolist()) + from_meta['affine_angle'] = 0. + else: + from_meta['affine_angle'] = ValueOrList.parse_obj(metadata.scan_rotation) + else: + if self.d_alpha is None: + raise ValueError("'d_alpha' must be specified in parameters or metadata.") + if self.energy is None: + raise ValueError("'energy' must be specified in parameters or metadata.") + + if metadata is not None and hasattr(metadata, 'scan_positions') and metadata.scan_positions is not None \ + and self.scan.is_default(): + # scale m -> A, flip 180 degrees, transpose + positions = [(-y*1e10, -x*1e10) for (x, y) in metadata.scan_positions] + + # center around 0, 0 + xshift = (min([p[0] for p in positions]) + max([p[0] for p in positions]))/2. + yshift = (min([p[1] for p in positions]) + max([p[1] for p in positions]))/2. + positions = [(x - xshift, y - yshift) for (x, y) in positions] + + from_meta['scan'] = ListScanParams(scan_positions=positions) + else: + from_meta['scan'] = self.scan.apply_metadata(metadata) + + from_meta.update( + model = self.model.apply_metadata(metadata), + detector = self.detector.apply_metadata(metadata), + ) + + fields = self.__fields_set__ - {'meta', 'scan', 'model', 'detector'} + d = {**from_meta, **{k: self.__dict__[k] for k in fields}} + return ParamSet.parse_obj({k: v.__root__ if isinstance(v, WrapperModel) else v for (k, v) in d.items()}) + + def iter( + self, _save_record: t.Optional[SaveRecord] = None, + path: t.Optional[Path] = None, sparse: t.Optional[bool] = None + ) -> t.Iterator[Params]: + """ + Iterate through the contained Params objects. Metadata files are loaded relative to `path`. + """ + # keep a list of used paths, names, and a counter of iterations. 
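+        # (a single record is shared across all metadata files, so generated names and paths stay unique)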
+        if _save_record is None:
+            _save_record = SaveRecord()
+
+        for meta in self.meta:
+            if isinstance(meta, (str, Path)):
+                # normalize to a Path first (a bare `str` has no `is_absolute`)
+                meta = Path(meta)
+                if not meta.is_absolute() and path is not None:
+                    meta = path / meta
+                print(f"Loading '{meta}' as metadata...")
+                meta = AnyMetadata.parse_file(meta)
+
+            yield from self.apply_metadata(meta).iter(_save_record, meta, sparse)
+
+
+if __name__ == '__main__':
+    import yaml
+
+    with open('test.yaml', 'r') as f:
+        obj = yaml.safe_load(f)
+    params = Params.parse_obj(obj)
+    print(params.json())
diff --git a/phaser/cli_tools/prepare.py b/phaser/cli_tools/prepare.py
index 01f47d2..fc2760f 100644
--- a/phaser/cli_tools/prepare.py
+++ b/phaser/cli_tools/prepare.py
@@ -14,9 +14,9 @@
 except ImportError:
     from yaml import Loader
 
-from ptycho_lebeau.params import ParamMetaSet, SaveRecord
-from ptycho_lebeau.metadata import Metadata, AnyMetadata
-from ptycho_lebeau.util import handle_exception
+from .params import ParamMetaSet, SaveRecord
+from .metadata import Metadata, AnyMetadata
+from .util import handle_exception
 
 
 def _try_parse_metadata(path: Path) -> t.Union[Metadata, ValidationError, None]:
diff --git a/phaser/cli_tools/util.py b/phaser/cli_tools/util.py
new file mode 100644
index 0000000..d8620f3
--- /dev/null
+++ b/phaser/cli_tools/util.py
@@ -0,0 +1,37 @@
+from __future__ import annotations
+
+import functools
+import sys
+import math
+import traceback
+import typing as t
+
+import click
+
+if t.TYPE_CHECKING:
+    P = t.ParamSpec('P')
+    U = t.TypeVar('U')
+
+def handle_exception(f: t.Callable[P, U]) -> t.Callable[P, U]:
+    """Catch any exception, print a stack trace up to the current frame, and exit."""
+    @functools.wraps(f)
+    def wrapper(*args: P.args, **kwargs: P.kwargs) -> U:
+        try:
+            return f(*args, **kwargs)
+        except Exception:
+            # get parent traceback
+            (ty, e, tb) = sys.exc_info()
+            # get parent frame
+            tb = tb.tb_next  # type: ignore
+            traceback.print_exception(ty, e, tb)
+            raise click.exceptions.Exit(1)
+
+    return wrapper
+
+
+def electron_wavelength(kv: float = 200.) -> float:
+    """Return the wavelength (in angstroms) of an electron with the given kinetic energy (in keV)."""
+    # relativistic momentum: (pc)^2 = KE^2 + 2*KE*RE
+    rest_energy = 510.99895000  # keV
+    hc = 12.3984197  # keV-angstrom
+    return hc / math.sqrt(kv**2 + 2*kv*rest_energy)
\ No newline at end of file
diff --git a/phaser/main.py b/phaser/main.py
index 6aed60f..13be794 100644
--- a/phaser/main.py
+++ b/phaser/main.py
@@ -87,52 +87,9 @@ def worker(url: str, quiet: bool = False):
     run_worker(url, quiet=quiet)
 
 
-# @cli.command(cls=MainCommand, commands=dict((v, v) for v in
-#     ('prepare', 'run', 'view_raw', 'view_prepared', 'view_output',
-#      'process_metadata', 'extract_params', 'to_csv', 'calc_drift', 'calc_tilt')
-# ))
-
+# Add CLI tool subcommands
 cli.add_command(tools)
 
-# # --- Action group: tools -----------------------------------------------------
-# @cli.group(invoke_without_command=False)  # require a subcommand
-# @click.pass_context
-# def tools(ctx: click.Context):
-#     """Toolbox of utility subcommands."""
-    # With invoke_without_command=False, Click will show usage if no subcommand.
- -# @tools.add_command(process_metadata) - -# @tools.command('prepare') - -# @tools.command('view_raw') - -# # @tools.command('view_prepared') - -# # @tools.command('view_output') - -# @tools.command('process_metadata') - -# @tools.command('extract_params') -# @tools.command('to_csv') -# @tools.command('calc_drift') -# @tools.command('calc_tilt') - - -# @tools.command("align-data") -# @click.argument("input", type=click.Path(exists=True, path_type=Path)) -# @click.argument("output", type=click.Path(path_type=Path)) -# @click.option( -# "--method", -# type=click.Choice(["fft", "xcorr"], case_sensitive=False), -# default="fft", -# show_default=True, -# help="Alignment method.", -# ) -# def align_data_cmd(input: Path, output: Path, method: str): -# """Align image/data files and write the result.""" -# align_data(input, output, method) - if __name__ == '__main__': From 6a412758f3fdd58ef0fe467ce0620644e91fd92e Mon Sep 17 00:00:00 2001 From: subangstrom Date: Mon, 25 Aug 2025 16:21:01 -0400 Subject: [PATCH 5/8] added other commands --- phaser/cli_tools/__init__.py | 5 + phaser/cli_tools/analysis.py | 208 ++++++++++++++++++++++++++++++ phaser/cli_tools/pprint_xml.py | 17 +++ phaser/cli_tools/view_output.py | 93 +++++++++++++ phaser/cli_tools/view_prepared.py | 112 ++++++++++++++++ phaser/cli_tools/view_raw.py | 168 ++++++++++++++++++++++++ 6 files changed, 603 insertions(+) create mode 100644 phaser/cli_tools/analysis.py create mode 100755 phaser/cli_tools/pprint_xml.py create mode 100755 phaser/cli_tools/view_output.py create mode 100644 phaser/cli_tools/view_prepared.py create mode 100755 phaser/cli_tools/view_raw.py diff --git a/phaser/cli_tools/__init__.py b/phaser/cli_tools/__init__.py index 5320d96..1350699 100644 --- a/phaser/cli_tools/__init__.py +++ b/phaser/cli_tools/__init__.py @@ -4,6 +4,9 @@ from .calc_tilt import calc_tilt from .calc_drift import calc_drift +from .view_raw import view_raw +from .view_prepared import view_prepared +from .view_output import view_output @click.group() def tools(): @@ -15,3 +18,5 @@ def tools(): tools.add_command(prepare) tools.add_command(calc_tilt) tools.add_command(calc_drift) +tools.add_command(view_raw) + diff --git a/phaser/cli_tools/analysis.py b/phaser/cli_tools/analysis.py new file mode 100644 index 0000000..4b554e7 --- /dev/null +++ b/phaser/cli_tools/analysis.py @@ -0,0 +1,208 @@ +""" +Code for analysis of ptychographic reconstructions +""" + +from dataclasses import dataclass +import typing as t + +import numpy +from numpy.typing import NDArray, ArrayLike + +ScalarT = t.TypeVar('ScalarT', bound=numpy.generic) + + +@dataclass +class FRCResult: + """Fourier Ring Correlation result""" + + freq: NDArray[numpy.float64] + """List of frequencies in k-space (units of 1/px)""" + + corr: NDArray[numpy.float64] + """Fourier ring correlations corresponding to `freq`/`r`""" + + r: NDArray[numpy.float64] + """List of radii in k-space (proportional to `freq`, units of reciprocal space px)""" + + r_count: NDArray[numpy.int64] + """Count of pixels contained in each radius bin""" + + single_image: bool = False + """ + Whether the FRC was calculated from a single image checkerboard. + This affects the final resolution calculation. 
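+    (When True, `resolution` rescales frequencies by sqrt(2) to undo the checkerboard subsampling.)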
+ """ + + def half_bit_threshold(self) -> NDArray[numpy.float64]: + """Return the half-bit FRC threshold""" + snr = 2**0.5 - 1 + return self.snr_threshold(snr) + + def const_threshold(self, corr: float = 0.5) -> NDArray[numpy.float64]: + """Return the FRC threshold for a constant correlation `corr`.""" + return numpy.full_like(self.freq, corr) + + def snr_threshold(self, snr: float) -> NDArray[numpy.float64]: + """Return the FRC threshold for the given signal-to-noise ratio `snr`.""" + n = 1/numpy.sqrt(self.r_count) + + return (snr + (2*numpy.sqrt(snr) + 1)*n) / (1 + snr + 2*numpy.sqrt(snr)*n) + + def intersect(self, threshold: ArrayLike) -> t.Tuple[float, float]: + """ + Return the (freq, corr) point where the FRC crosses `threshold`. + """ + diff = self.corr - numpy.array(threshold, dtype=float) + lastdiff = numpy.roll(diff, shift=1) + try: + # find negative zero crossing + i = numpy.nonzero((diff[1:] <= 0) & (lastdiff[1:] > 0))[0][0] + except IndexError: + raise ValueError("No crossing found when evaluating FRC resolution") + + # calculate linear intersection between `frc.corr` and `threshold` + # slope of line + m = (diff[i+1] - diff[i]) / (self.freq[i+1] - self.freq[i]) + # frequency difference to reach x-axis + x_d = -diff[i]/m + # slope of self line + m_corr = (self.corr[i+1] - self.corr[i]) / (self.freq[i+1] - self.freq[i]) + # (x, y) + return (float(self.freq[i] + x_d), float(self.corr[i] + x_d * m_corr)) + + def resolution(self, threshold: ArrayLike, pixel_size: float = 1.0) -> float: + """ + Return the FRC resolution using threshold `threshold`. + + `pixel_size` is the pixel size in units/px. + Returns a value in the same units as `pixel_size`. + """ + freq, _ = self.intersect(threshold) + if self.single_image: + # upscale frequency (to undo our downsampling) + freq *= float(numpy.sqrt(2)) + return pixel_size / freq + + +def fourier_ring_correlate(img: NDArray[numpy.generic], *, r_spacing: float = 1.0, + inscribed: bool = True, hann_window: bool = True) -> FRCResult: + """ + Fourier-ring correlate a single image `img`. + + This uses `checkerboard_split` to split an image up into two pairs, + performs FRC on each, and averages the results. + """ + subimages = checkerboard_split(img) + # pair up along diagonals + pair1, pair2 = (subimages[0], subimages[3]), (subimages[1], subimages[2]) + result1 = fourier_ring_correlate_pair(*pair1, r_spacing=r_spacing, inscribed=inscribed, hann_window=hann_window) + result2 = fourier_ring_correlate_pair(*pair2, r_spacing=r_spacing, inscribed=inscribed, hann_window=hann_window) + + return FRCResult( + freq=result1.freq, corr=(result1.corr + result2.corr) / 2., r=result1.r, r_count=result1.r_count, + single_image=True + ) + + +def fourier_ring_correlate_pair(img1: NDArray[numpy.generic], img2: NDArray[numpy.generic], *, + r_spacing: float = 1.0, inscribed: bool = True, hann_window: bool = True) -> FRCResult: + """ + Fourier-ring correlate `img1` and `img2`. + + Parameters: + `r_spacing`: The spacing of ring radii to return correlations at. + `inscribed`: If specified, frequencies are returned only to the edge of k-space (1/2 px^-1). + Otherwise, frequencies are returned to the corner of k-space. + `hann_window`: If specified, a Hann window is applied to each image before cross-correlating. 
+ """ + + # convert images to float + img1f = img1.astype(numpy.float64) + img2f = img2.astype(numpy.float64) + + # apply hann window if desired + if hann_window: + from skimage.filters import window + win = numpy.asarray(window('hann', img1f.shape), dtype=numpy.float64) + img1f *= win + img2f *= win + + # take FFTs, and find magnitudes + fft1 = numpy.fft.fftshift(numpy.fft.fft2(img1f), axes=(-2, -1)) + fft2 = numpy.fft.fftshift(numpy.fft.fft2(img2f), axes=(-2, -1)).conj() + fft1_mag = fft1.real**2 + fft1.imag**2 + fft2_mag = fft2.real**2 + fft2.imag**2 + + # calculate radii in reciprocal space + y, x = numpy.indices(img1.shape[-2:]) + c_y, c_x = tuple(int(s//2) for s in img1.shape[-2:]) + r = numpy.sqrt((y - c_y)**2 + (x - c_x)**2) + + # assign radii to bins + r_i = numpy.floor(r / r_spacing).astype(numpy.int_).ravel() + # sum correlation and magnitude along each bin + real_count = numpy.bincount(r_i, (fft1 * fft2).real.ravel()) + mag1_count = numpy.bincount(r_i, fft1_mag.ravel()) + mag2_count = numpy.bincount(r_i, fft2_mag.ravel()) + r_count = numpy.bincount(r_i) + + with numpy.errstate(invalid='ignore', divide='ignore'): + # compute correlations + vals = numpy.abs(real_count) / numpy.sqrt(mag1_count * mag2_count) + # 0/0 => 1.0, 1/0 => 1.0 + vals = numpy.nan_to_num(vals, posinf=1.0, nan=1.0) + + # calculate radii and frequencies sampled at + rs = numpy.linspace(0., len(vals) * r_spacing, len(vals), endpoint=False) + freq = rs / numpy.sqrt(numpy.prod(img1.shape[-2:])) + + if inscribed: + # crop to inner radii + n_r = int(numpy.floor(min(c_y, c_x) / r_spacing)) + return FRCResult(freq=freq[:n_r], corr=numpy.abs(vals[:n_r]), r=rs[:n_r], r_count=r_count[:n_r]) + + return FRCResult(freq=freq, corr=numpy.abs(vals), r=rs, r_count=r_count) + + +def checkerboard_split(img: NDArray[ScalarT]) -> t.Tuple[NDArray[ScalarT], NDArray[ScalarT], NDArray[ScalarT], NDArray[ScalarT]]: + """ + Split image into 4 interlaced subimages (upper left, upper right, lower left, lower right). + + This is useful for single-image Fourier Ring Correlation + """ + shape = img.shape[-2:] + # crop image divisible by two + shape = tuple(s - s%2 for s in shape) + img = img[..., *(slice(0, s) for s in shape)] + + return tuple( + img[..., slice(row_start, shape[0], 2), slice(col_start, shape[1], 2)] + for row_start in (0, 1) + for col_start in (0, 1) + ) # type: ignore + + +def fourier_correlate(img1: NDArray[ScalarT], img2: NDArray[ScalarT], *, hann_window: bool = True) -> NDArray[numpy.complex128]: + """ + Cross-correlate `img1` and `img2` in Fourier space. + + If `hann_window`, a Hann window is applied to each image before cross-correlating. + """ + + img1f = img1.astype(numpy.float64) + img2f = img2.astype(numpy.float64) + + if hann_window: + from skimage.filters import window + win = numpy.asarray(window('hann', img1f.shape), dtype=numpy.float64) + img1f *= win + img2f *= win + + fft1 = numpy.fft.fftshift(numpy.fft.fft2(img1f), axes=(-2, -1)) + fft2 = numpy.fft.fftshift(numpy.fft.fft2(img2f), axes=(-2, -1)).conj() + fft1_mag = fft1.real**2 + fft1.imag**2 + fft2_mag = fft2.real**2 + fft2.imag**2 + + with numpy.errstate(divide='ignore', invalid='ignore'): + corr = fft1 * fft2 / numpy.sqrt(fft1_mag * fft2_mag) + return numpy.nan_to_num(corr, nan=1., posinf=0.) 
diff --git a/phaser/cli_tools/pprint_xml.py b/phaser/cli_tools/pprint_xml.py
new file mode 100755
index 0000000..2214aeb
--- /dev/null
+++ b/phaser/cli_tools/pprint_xml.py
@@ -0,0 +1,17 @@
+#!/usr/bin/env python3
+
+import sys
+import lxml.etree as etree
+
+if __name__ == '__main__':
+    if len(sys.argv) > 1:
+        fname = sys.argv[1]
+        f = open(fname, 'r')
+    else:
+        f = sys.stdin
+
+    try:
+        x = etree.parse(f, None)
+        sys.stdout.write(etree.tostring(x, pretty_print=True, encoding=str))  # type: ignore
+    except KeyboardInterrupt:
+        pass
diff --git a/phaser/cli_tools/view_output.py b/phaser/cli_tools/view_output.py
new file mode 100755
index 0000000..678d108
--- /dev/null
+++ b/phaser/cli_tools/view_output.py
@@ -0,0 +1,93 @@
+#!/usr/bin/env python3
+
+from pathlib import Path
+import typing as t
+
+import click
+import h5py
+import numpy
+from scipy.io import loadmat
+from scipy.linalg import lstsq
+from matplotlib import pyplot
+from matplotlib.widgets import Slider
+
+
+def remove_phase_ramp(data: numpy.ndarray) -> numpy.ndarray:
+    output = numpy.empty_like(data)
+
+    (yy, xx) = (arr.flatten() for arr in numpy.indices(data.shape[1:], dtype=float))
+    pts = numpy.stack((numpy.ones_like(xx), xx, yy), axis=-1)
+
+    for i in range(data.shape[0]):
+        layer = data[i]
+        p, residues, rank, singular = lstsq(pts, layer.flatten())
+        output[i] = layer - (p @ pts.T).reshape(layer.shape)
+
+    return output
+
+
+@click.command()
+@click.argument('path', type=click.Path(exists=True, dir_okay=False))
+@click.option('--phase/--no-phase', 'show_phase', default=True, help="Plot object phase.")
+@click.option('--error/--no-error', 'show_error', default=False, help="Plot error progression.")
+def view_output(path: t.Union[str, Path], show_phase: bool = True, show_error: bool = False):
+    """
+    View PtychoShelves output from a Niter.mat file.
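+
+    Handles both MATLAB v7 files (via scipy.io.loadmat) and v7.3/HDF5 files (via h5py).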
+ """ + print(f"Loading '{path}'...") + try: + f = loadmat(path) + params: numpy.ndarray = t.cast(numpy.ndarray, f['p'])[0, 0] + roi_y, roi_x = params['object_ROI'].flat + roi_xx, roi_yy = numpy.meshgrid(roi_x.flatten(), roi_y.flatten(), indexing='ij') + + error: numpy.ndarray = t.cast(numpy.ndarray, f['outputs']['fourier_error_out'])[0, 0] + + obj: numpy.ndarray = numpy.atleast_3d(f['object'])[roi_yy, roi_xx] + obj = numpy.moveaxis(obj, -1, 0) + print(f"Loaded from loadmat.") + except NotImplementedError: + f = t.cast(h5py.Group, h5py.File(path)) + error = numpy.asarray(f['outputs/fourier_error_out']) + + obj_refs = t.cast(h5py.Dataset, f['outputs/object_roi'])[:, 0] + obj: numpy.ndarray = numpy.stack([t.cast(numpy.ndarray, t.cast(h5py.Dataset, f[h5py.h5r.get_name(ref, f.id)])[()]) for ref in obj_refs], axis=0) + obj = obj['real'] + 1.j * obj['imag'] + print(f"Loaded from HDF5.") + + #obj_mag = t.cast(numpy.ndarray, numpy.abs(obj)) + + if show_phase: + obj_phase = t.cast(numpy.ndarray, numpy.angle(obj)) + obj_phase = numpy.unwrap(numpy.unwrap(obj_phase, axis=-1), axis=-2) + obj_phase = remove_phase_ramp(obj_phase) + + fig, (img_ax, slider_ax) = pyplot.subplots(nrows=2, gridspec_kw={'height_ratios': [12, 1]}) + #vmin, vmax = numpy.min(phase[:, 200:-200, 200:-200]), numpy.max(phase[:, 200:-200, 200:-200]) + vmin, vmax = numpy.nanquantile(obj_phase, 0.02), numpy.nanquantile(obj_phase, 0.98) + img = img_ax.imshow(obj_phase[0], vmin=float(vmin), vmax=float(vmax), cmap='inferno') + slider = Slider(ax=slider_ax, label='Frame', valmin=0, valmax=len(obj)-1, valinit=0, valstep=1) + + fig.colorbar(img, ax=img_ax) + + def update(val): + i = int(numpy.floor(val)) + img.set(data=obj_phase[i]) + + slider.on_changed(update) + + if show_error: + error = error.flatten() + iterations = numpy.arange(error.size)[~numpy.isnan(error)] + error = error[~numpy.isnan(error)] + + fig, ax = pyplot.subplots() + ax.plot(iterations, error, '.-', color='black') + ax.set_xlabel('Iteration') + ax.set_ylabel('Fourier Error') + + pyplot.show() + + +if __name__ == '__main__': + view_output() diff --git a/phaser/cli_tools/view_prepared.py b/phaser/cli_tools/view_prepared.py new file mode 100644 index 0000000..4375b21 --- /dev/null +++ b/phaser/cli_tools/view_prepared.py @@ -0,0 +1,112 @@ +#!/usr/bin/env python3 + +import typing as t + +import click +import h5py +from scipy.io import loadmat +import numpy +from matplotlib import pyplot +from matplotlib.backend_bases import KeyEvent + + +def load_prepared_data(path: str, positions_path: t.Optional[str] = None) -> t.Tuple[numpy.ndarray, numpy.ndarray]: + """ + Shape: (Nprobes, Ny, Nx), (Nprobes, 2) + """ + + print(f"Loading '{path}'...") + f = t.cast(h5py.Group, h5py.File(path)) + + if 'measurement' in f: + data = t.cast(h5py.Dataset, f['measurement/n0/data'])[()] + data = numpy.fft.fftshift(data, axes=(-1, -2)) + positions = t.cast(h5py.Dataset, f['measurement/n0/positions'])[()] + return (data, positions) + elif 'dp' in f: + data = t.cast(h5py.Dataset, f['dp'])[()] + if positions_path is None: + raise ValueError("Old-style prepared data needs `positions_path` as well.") + positions = load_probe_positions(positions_path) + return (data, positions) + + raise ValueError(f"Couldn't find prepared patterns in '{path}'") + + +def load_probe_positions(path: str) -> numpy.ndarray: + print(f"Loading '{path}'...") + try: + f = loadmat(path) + if 'outputs' in f: + return f['outputs'][0, 0]['probe_positions_0'] + raise Exception("Can't find probe positions in .mat file.") + except 
(NotImplementedError, ValueError):
+        pass
+
+    raise NotImplementedError("Loading probe positions from HDF5 isn't implemented.")
+
+
+
+@click.command()
+@click.argument('path', type=click.Path(exists=True, dir_okay=False))
+@click.argument('positions_path', type=click.Path(exists=True, dir_okay=False), required=False)
+@click.option('--shape', type=(int, int), required=False)
+def view_prepared(path: str, positions_path: t.Optional[str] = None, shape: t.Optional[t.Tuple[int, int]] = None):
+    """
+    View prepared data from PtychoShelves .mat or .h5 files.
+
+    PATH should be a prepared HDF5 or .mat file.
+    For a .mat file, POSITIONS_PATH should be a Niter.mat file.
+    """
+    (data, positions) = load_prepared_data(path, positions_path)
+
+    if data.shape[0] != positions.shape[0]:
+        raise ValueError(f"Mismatch in # of probes: {data.shape[0]} in data vs {positions.shape[0]} probe positions.")
+
+    n_probes = data.shape[0]
+
+    stride = int(numpy.sqrt(n_probes)) if shape is None else shape[0]
+
+    pattern_fig, pattern_ax = pyplot.subplots(constrained_layout=True)
+    pos_fig, pos_ax = pyplot.subplots(constrained_layout=True)
+
+    i: int = -1
+
+    vmax = float(numpy.nanquantile(data, 0.9999))
+
+    data = numpy.swapaxes(data[..., ::-1], -1, -2)
+    img = pattern_ax.imshow(data[i], cmap='inferno', vmin=0.1*vmax, vmax=vmax)
+
+    pos_ax.scatter(positions[:, 0], positions[:, 1], s=2, c='blue')
+    pos_ax.set_aspect('equal')
+    pos = pos_ax.scatter([positions[i, 0]], [positions[i, 1]], s=4, c='red')
+
+    def update():
+        img.set_data(data[i])
+        pos.set_offsets(positions[i, None])
+
+        pos_fig.canvas.draw_idle()
+        pattern_fig.canvas.draw_idle()
+
+    def key_press(event: KeyEvent):
+        nonlocal i
+        if event.key == 'right':
+            i -= stride
+        elif event.key == 'left':
+            i += stride
+        elif event.key == 'up':
+            i -= 1
+        elif event.key == 'down':
+            i += 1
+        else:
+            return
+        i = i % n_probes
+        update()
+
+    pattern_fig.canvas.mpl_connect('key_press_event', key_press)
+    pos_fig.canvas.mpl_connect('key_press_event', key_press)
+    pyplot.show()
+
+
+if __name__ == '__main__':
+    view_prepared()
diff --git a/phaser/cli_tools/view_raw.py b/phaser/cli_tools/view_raw.py
new file mode 100755
index 0000000..2ccde65
--- /dev/null
+++ b/phaser/cli_tools/view_raw.py
@@ -0,0 +1,168 @@
+#!/usr/bin/env python3
+
+from pathlib import Path
+import typing as t
+
+import numpy
+import click
+from matplotlib import pyplot
+from matplotlib.backend_bases import KeyEvent, MouseEvent, MouseButton
+import matplotlib.path as mpath
+from matplotlib.transforms import Affine2D
+from matplotlib.patches import PathPatch
+
+from .raw import load_4d
+
+
+def save_hist(path: t.Union[str, Path]):
+    a = load_4d(Path(path))
+
+    fig = pyplot.figure(figsize=(16, 8), constrained_layout=True)
+
+    (ax1, ax2) = fig.subplots(nrows=2, sharex=True)
+    #fig.suptitle('loghist')
+    flat = a.flatten()
+    flat = flat[flat < 1e5]
+
+    vmin, vmax = numpy.nanmin(flat), numpy.nanmax(flat)
+
+    nbins = min(1024, numpy.floor(vmax - vmin).astype(int))
+
+    vals, bins = numpy.histogram(flat, bins=nbins, density=True)
+    widths = numpy.diff(bins)
+    #ax.bar(bins, numpy.log(vals))
+    ax1.set_yscale('log')
+    ax1.bar(bins[:-1], vals, width=widths, color='black', align='edge')
+    ax2.bar(bins[:-1], vals, width=widths, color='black', align='edge')
+    ax1.margins(0.)
+    ax2.margins(0.)
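+    # both panels show the same histogram: log-scaled frequency on top, linear below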
+ ax2.set_xlabel("ADU") + ax1.set_ylabel("freq") + ax2.set_ylabel("freq") + + pyplot.show() + + fig.savefig(str(Path(path).with_name('hist.png')), dpi=500) + + +@click.command() +@click.argument('path', type=click.Path(exists=True, dir_okay=False)) +@click.option('--hist/--no-hist', default=False, help="Show histogram") +@click.option('--pacbed/--no-pacbed', default=False, help="Show pacbed") +@click.option('--com/--no-com', default=False, help="Show COM X & Y") +def view_raw(path: t.Union[str, Path], hist: bool = False, pacbed: bool = False, com: bool = False): + """Visualize raw datasets (off-brand 4D STEM Explorer).""" + + a = load_4d(Path(path)) + (yy, xx) = numpy.indices(a.shape[2:]) + + filtered = a.copy() + filtered[(filtered > 1e6) | (filtered < 0.) | numpy.isnan(filtered)] = 0. + + if pacbed: + pacbed = numpy.log(filtered.sum(axis=(0, 1), dtype=numpy.float64)) + pacbed_fig, ax = pyplot.subplots() + ax.imshow(pacbed) + + if com: + com_x = numpy.sum(xx * filtered, axis=(-1, -2), dtype=numpy.float64) / numpy.sum(filtered, axis=(-1, -2), dtype=numpy.float64) + fig, ax = pyplot.subplots() + fig.suptitle('com_x') + ax.imshow(com_x) + com_y = numpy.sum(yy * filtered, axis=(-1, -2), dtype=numpy.float64) / numpy.sum(filtered, axis=(-1, -2), dtype=numpy.float64) + fig, ax = pyplot.subplots() + fig.suptitle('com_y') + ax.imshow(com_y) + + if hist: + fig, (ax1, ax2) = pyplot.subplots(nrows=2, sharex=True, constrained_layout=True) + #fig.suptitle('loghist') + flat = a.flatten() + flat = flat[flat < 1e5] + vals, bins = numpy.histogram(flat, bins=1024, density=True) + widths = numpy.diff(bins) + #ax.bar(bins, numpy.log(vals)) + ax1.set_yscale('log') + ax1.bar(bins[:-1], vals, width=widths, color='black', align='edge') + ax2.bar(bins[:-1], vals, width=widths, color='black', align='edge') + ax1.margins(0.) + ax2.margins(0.) + ax2.set_xlabel("ADU") + ax1.set_ylabel("freq") + ax2.set_ylabel("freq") + #ax2.hist(a.flatten(), bins=1024, density=True, color='black', range=(0, 5000)) #range=(0, 16383) + + #fig, axs = pyplot.subplots(n, n, sharex=True, sharey=True) + recip_fig, recip_ax = pyplot.subplots() + y = a.shape[0] // 2 + x = a.shape[1] // 2 + probe = a[y, x] + img = recip_ax.imshow(probe) + + y_c = 64. + x_c = 64. + r = 28. 
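+    # virtual annular detector: sum intensity outside radius r for an ADF-like real-space image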
+ mask = (xx - x_c)**2 + (yy - y_c)**2 > r**2 + real_space = numpy.sum(a[:, :, mask], axis=-1, dtype=numpy.float64) + + real_fig, real_ax = pyplot.subplots() + real_ax.imshow(real_space, vmin=float(numpy.quantile(real_space, 0.01)), vmax=float(numpy.quantile(real_space, 0.99))) + + crosshair = mpath.Path(numpy.array([ + [-1.5, -1.5], [1.5, -1.5], [1.5, 1.5], [-1.5, 1.5], + [-0.5, -0.5], [-0.5, 0.5], [0.5, 0.5], [0.5, -0.5] + ]), list(map(int, [mpath.Path.MOVETO, 2, 2, 2, mpath.Path.MOVETO, 2, 2, 2]))) + marker = PathPatch(crosshair, fc='red', fill=True, linestyle='None', transform=Affine2D().translate(x, y) + real_ax.transData) + real_ax.add_patch(marker) + #marker = Rectangle((x-1., y-1.), 2., 2., ec='red', fill=False, lw=1.5) + + def update(): + nonlocal probe + print(f"\rpos: ({x}, {y}) ", end='') + probe = a[y, x] + img.set_data(probe) + marker.set_transform(Affine2D().translate(x, y) + real_ax.transData) + recip_fig.canvas.draw_idle() + real_fig.canvas.draw_idle() + + def key_press(event: KeyEvent): + nonlocal x, y + #print(f"Key released: {event.key}") + if event.key == 'left': + if x > 0: + x -= 1 + elif event.key == 'right': + if x < a.shape[1] - 1: + x += 1 + elif event.key == 'up': + if y > 0: + y -= 1 + elif event.key == 'down': + if y < a.shape[0] - 1: + y += 1 + else: + print("ignored") + update() + + def mouse_click(event: MouseEvent): + nonlocal x, y + if event.button is MouseButton.LEFT \ + and event.x is not None and event.y is not None: + + (click_x, click_y) = real_ax.transData.inverted().transform(tuple(map(int, (event.x, event.y)))) + (click_x, click_y) = map(int, map(round, (click_x, click_y))) + if not 0 <= click_x < a.shape[1] or not 0 <= click_y < a.shape[0]: + return + x, y = click_x, click_y + update() + + recip_fig.canvas.mpl_connect('key_press_event', key_press) + real_fig.canvas.mpl_connect('key_press_event', key_press) + real_fig.canvas.mpl_connect('button_press_event', mouse_click) + + pyplot.show() + print("") + + +if __name__ == '__main__': + view_raw() \ No newline at end of file From 6dd9f5fe7ea6152770544049773307fc6a840949 Mon Sep 17 00:00:00 2001 From: subangstrom Date: Mon, 25 Aug 2025 16:21:22 -0400 Subject: [PATCH 6/8] cli add --- phaser/cli_tools/__init__.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/phaser/cli_tools/__init__.py b/phaser/cli_tools/__init__.py index 1350699..f91477f 100644 --- a/phaser/cli_tools/__init__.py +++ b/phaser/cli_tools/__init__.py @@ -19,4 +19,6 @@ def tools(): tools.add_command(calc_tilt) tools.add_command(calc_drift) tools.add_command(view_raw) +tools.add_command(view_prepared) +tools.add_command(view_output) From e71ecda13cda583077498f60169d982a552b59ab Mon Sep 17 00:00:00 2001 From: subangstrom Date: Mon, 25 Aug 2025 16:25:10 -0400 Subject: [PATCH 7/8] should now be all available --- phaser/cli_tools/__init__.py | 10 ++++++++ phaser/cli_tools/main.py | 43 --------------------------------- phaser/cli_tools/to_csv.py | 47 ++++++++++++++++++++++++++++++++++++ 3 files changed, 57 insertions(+), 43 deletions(-) delete mode 100644 phaser/cli_tools/main.py create mode 100644 phaser/cli_tools/to_csv.py diff --git a/phaser/cli_tools/__init__.py b/phaser/cli_tools/__init__.py index f91477f..0252aeb 100644 --- a/phaser/cli_tools/__init__.py +++ b/phaser/cli_tools/__init__.py @@ -7,6 +7,8 @@ from .view_raw import view_raw from .view_prepared import view_prepared from .view_output import view_output +from .extract_params import extract_params +from .to_csv import to_csv @click.group() def tools(): @@ -21,4 +23,12 @@ def 
tools():
 tools.add_command(view_raw)
 tools.add_command(view_prepared)
 tools.add_command(view_output)
+tools.add_command(extract_params)
+tools.add_command(to_csv)
 
diff --git a/phaser/cli_tools/main.py b/phaser/cli_tools/main.py
deleted file mode 100644
index 29344af..0000000
--- a/phaser/cli_tools/main.py
+++ /dev/null
@@ -1,43 +0,0 @@
-import typing as t
-
-import click
-
-
-class MainCommand(click.MultiCommand):
-    def __init__(self, commands: t.Union[t.Iterable[str], t.Dict[str, t.Union[str, t.Tuple[str, str]]]], **kwargs):
-        super().__init__(**kwargs)
-        self.commands: t.Dict[str, t.Union[str, t.Tuple[str, str]]]
-        if isinstance(commands, dict):
-            self.commands = commands
-        else:
-            self.commands = dict((v, v) for v in commands)
-
-    def list_commands(self, ctx: click.Context):
-        return list(self.commands.keys())
-
-    def get_command(self, ctx: click.Context, cmd_name: str) -> t.Optional[click.Command]:
-        name = cmd_name.lower()
-        val = (self.commands.get(name) or
-               self.commands.get(name.replace('-', '_')))
-        if val is None:
-            return None
-        if isinstance(val, tuple):
-            (module, func) = val
-        else:
-            module = val
-            func = val
-        mod = __import__(f"{__package__}.{module}", None, None, [func])
-        return getattr(mod, func)
-
-
-@click.command(cls=MainCommand, commands=dict((v, v) for v in
-    ('prepare', 'run', 'view_raw', 'view_prepared', 'view_output',
-     'process_metadata', 'extract_params', 'to_csv', 'calc_drift', 'calc_tilt')
-))
-def main():
-    """LeBeau group ptychography utilities."""
-    ...
-
-
-if __name__ == '__main__':
-    main()
diff --git a/phaser/cli_tools/to_csv.py b/phaser/cli_tools/to_csv.py
new file mode 100644
index 0000000..10e8148
--- /dev/null
+++ b/phaser/cli_tools/to_csv.py
@@ -0,0 +1,47 @@
+
+import sys
+from pathlib import Path
+import typing as t
+
+import click
+
+from .metadata import EmpadMetadata
+from .metadata import to_csv as _to_csv
+
+
+@click.command()
+@click.argument('path', type=click.Path(), nargs=-1, required=False)
+@click.option('--out', '-o', type=click.Path(allow_dash=True))
+def to_csv(path: t.Union[str, Path, t.Sequence[t.Union[str, Path]]], out: t.Union[str, Path, None] = None):
+    """
+    Export ptychography metadata to a CSV file.
+
+    `path` may either be a list of directories to search within, or a list of metadata JSON files.
+    If unspecified, the current directory will be searched.
+
+    The resulting CSV file is written to the `out` file, or to stdout.
+    """
+    if isinstance(path, (str, Path)):
+        paths = [path]
+    elif path is None or len(path) == 0:
+        paths = ['']
+    else:
+        paths = path
+
+    def resolve_paths(paths: t.Iterable[t.Union[str, Path]]) -> t.Iterator[Path]:
+        for path in map(Path, paths):
+            if path.is_dir():
+                yield from path.rglob('*.json')
+            else:
+                yield path
+
+    def parse_metadata(paths: t.Iterable[Path]) -> t.Iterator[EmpadMetadata]:
+        for path in paths:
+            print(f"Parsing '{path}'", file=sys.stderr)
+            try:
+                yield EmpadMetadata.parse_file(path)
+            except Exception as e:
+                print(f"Couldn't parse '{path}', may not be a metadata file. 
Skipping.", file=sys.stderr) + + f = sys.stdout if out in (None, '-') else out + _to_csv(f, parse_metadata(resolve_paths(paths))) From 731f138cb9a544a8b2889d25a9f7faaa48c0cc91 Mon Sep 17 00:00:00 2001 From: subangstrom Date: Wed, 27 Aug 2025 14:37:40 -0400 Subject: [PATCH 8/8] brought over raw reader, needs to be updated to call from phaser io --- phaser/cli_tools/raw.py | 53 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 53 insertions(+) create mode 100644 phaser/cli_tools/raw.py diff --git a/phaser/cli_tools/raw.py b/phaser/cli_tools/raw.py new file mode 100644 index 0000000..e7590c1 --- /dev/null +++ b/phaser/cli_tools/raw.py @@ -0,0 +1,53 @@ +import re +from pathlib import Path +import typing as t + +import numpy +from numpy.typing import NDArray + + +def load_4d(path: t.Union[str, Path]) -> NDArray[numpy.float32]: + """ + Load a raw EMPAD dataset into memory. Scan dimensions are inferred from the filename. + The file is loaded so the dimensions are: (scan_y, scan_x, k_y, k_x), with y decreasing downwards. + """ + path = Path(path) + match = re.search(r"x(\d+)_y(\d+)", path.name) + if match: + n_x, n_y = map(int, (match[1], match[2])) + else: + raise ValueError(f"Unable to infer probe dimensions from name {path.name}") + + a = numpy.memmap(path, dtype=numpy.float32, mode='r') + if not a.size % (130*128) == 0: + raise ValueError(f"File not divisible by 130x128 (size={a.size}).") + a.shape = (-1, 130, 128) + #a = a[:, :128, :] + + if a.shape[0] != n_x * n_y: + raise ValueError(f"Got {a.shape[0]} probes, expected {n_x}x{n_y} = {n_x * n_y}.") + a.shape = (n_y, n_x, *a.shape[1:]) + #print(a.shape) + #with open(path.parent / "scan_x128_y64.raw", 'wb') as f: + # a[:64, :, :, :].ravel().tofile(f) + #a = numpy.swapaxes(a, 0, 1) + a = a[..., 127::-1, :] # flip reciprocal y space, crop junk rows + + return a + + +def save_4d(a: NDArray[numpy.float32], folder: t.Union[str, Path, None] = None, name: str = "scan_x{x}_y{y}.raw"): + """Save a raw EMPAD dataset to `folder`.""" + assert len(a.shape) == 4 + assert a.shape[2:] == (128, 128) + n_y, n_x = a.shape[:2] + + out_shape = list(a.shape) + out_shape[2] = 130 # dead rows + + out = numpy.zeros(out_shape, dtype=numpy.float32) + out[..., 127::-1, :] = a + + path = Path(folder or Path()) / name.format(x=n_x, y=n_y) + with open(path, 'wb') as f: + out.tofile(f)