Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 79 additions & 3 deletions phaser/engines/common/output.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,51 @@
from phaser.state import ReconsState
from phaser.plan import SaveOptions

## TODO: move these utility functions to something like utils.plotting
def interpolate_values(a1, a2, c, d, cutoff=0.5):
c = max(c, d*cutoff)
c = min(c, d)
interpolated_value1 = a1/(1-cutoff)/d*c+a1-a1/(1-cutoff)
interpolated_value2 = (a2-1)/(1-cutoff)/d * c + a2 -(a2-1)/(1-cutoff)
interpolated_value1 = max(0, min(1, interpolated_value1))
interpolated_value2 = max(0, min(1, interpolated_value2))
return (interpolated_value1, interpolated_value2)

def grad_to_rgb(angle, absolute, max_abs, offset, SL = False):
import matplotlib.colors
"""Get the rgb value for the given `angle` and the `absolute` value
Intended usage: grad = numpy.array(list(map(lambda p, q: grad_to_rgb(p, q, max_abs), deg, mag)))

Parameters
----------
angle : float
The angle in radians
absolute : float
The absolute value of the gradient
max_abs : float
The maximum value against which all others are normalized (color-wise).

Returns
-------
array_like
The rgb value as a tuple with values [0..1]
"""

# normalize angle
angle = (angle + offset) % (2 * numpy.pi)
if angle < 0:
angle += 2 * numpy.pi

if SL:
# print('specify global max_abs!!!')
return matplotlib.colors.hsv_to_rgb((angle / 2 / numpy.pi,
numpy.where(absolute / max_abs>1, 1, absolute / max_abs),
numpy.where(absolute / max_abs>1, 1, absolute / max_abs)))
else:
return matplotlib.colors.hsv_to_rgb((angle / 2 / numpy.pi,
1,
1))


def output_images(state: ReconsState, out_dir: Path, options: SaveOptions):
for ty in options.images:
Expand Down Expand Up @@ -176,6 +221,36 @@ def _save_object_mag(state: ReconsState, out_path: Path, options: SaveOptions, s
write_opts['metadata']['axes'] = 'ZYX' if stack else 'YX'
w.write(obj_mag, **write_opts)

## TODO: refactor using colorize_complex to match tilt implementation
def _plot_scan_update(state: ReconsState, out_path: Path, options: SaveOptions):
from matplotlib import pyplot
fig, ax = pyplot.subplots(figsize=(4, 4), dpi=options.plot_dpi, constrained_layout=True)
ax.set_aspect(1.)
[left, right, bottom, top] = state.object.sampling.mpl_extent()
ax.set_xlim(left, right)
ax.set_ylim(bottom, top)

scan = to_numpy(state.scan.data)

scan = to_numpy(state.scan.data)
disp = to_numpy(state.scan.data) - to_numpy(state.scan.initial_scan)
pos_y, pos_x = scan[..., 1].ravel(), scan[..., 0].ravel()
dY, dX = disp[..., 1].ravel(), disp[..., 0].ravel()

deg = -numpy.arctan2(dY, dX)
mag = numpy.sqrt(dY ** 2 + dX ** 2)
max_abs = numpy.nanmax(mag)

grad = numpy.array(list(map(lambda p, q: grad_to_rgb(p, q, max_abs, offset=0), deg, mag)))

width = 0.1
# hal = hl = 1. / width * length
headwidth = width
ax.quiver(pos_y, pos_x, dY, dX, color=grad, angles='xy', units='xy', width=width, edgecolor=None, linewidth=1, pivot='mid', scale_units='xy', scale=.1)
# ax.set_axis_off()

fig.savefig(out_path)
pyplot.close(fig)

def _plot_scan(state: ReconsState, out_path: Path, options: SaveOptions):
from matplotlib import pyplot
Expand All @@ -186,7 +261,7 @@ def _plot_scan(state: ReconsState, out_path: Path, options: SaveOptions):
ax.set_xlim(left, right)
ax.set_ylim(bottom, top)

scan = to_numpy(state.scan)
scan = to_numpy(state.scan.data)
i = numpy.arange(scan[..., 0].size)
ax.scatter(scan[..., 1].ravel(), scan[..., 0].ravel(), c=i, cmap='plasma', s=0.5, edgecolors='none')

Expand All @@ -209,7 +284,7 @@ def _plot_tilt(state: ReconsState, out_path: Path, options: SaveOptions):
ax.set_xlim(left, right)
ax.set_ylim(bottom, top)

scan = to_numpy(state.scan)
scan = to_numpy(state.scan.data)
tilt = to_numpy(state.tilt)
tilt = tilt[..., 1] + tilt[..., 0]*1.j
max_tilt = max(numpy.max(numpy.abs(tilt)), 1.0) # at least 1 mrad
Expand Down Expand Up @@ -247,7 +322,8 @@ def _plot_tilt(state: ReconsState, out_path: Path, options: SaveOptions):
'object_mag_stack': partial(_save_object_mag, stack=True),
'object_mag_sum': partial(_save_object_mag, stack=False),
'scan': _plot_scan,
'scan_update': _plot_scan_update,
'tilt': _plot_tilt,
}
# save functions with special handling of file extensions
_PLOT_FUNCS: t.Set[str] = {'scan', 'tilt'}
_PLOT_FUNCS: t.Set[str] = {'scan', 'tilt', 'scan_update'}
4 changes: 2 additions & 2 deletions phaser/engines/common/position_correction.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,8 @@ def __init__(self, args: None, props: MomentumPositionSolverProps):
self.momentum = props.momentum

def init_state(self, sim: ReconsState) -> NDArray[numpy.floating]:
xp = get_array_module(sim.scan)
return xp.zeros_like(sim.scan)
xp = get_array_module(sim.scan.data)
return xp.zeros_like(sim.scan.data)

def perform_update(
self,
Expand Down
169 changes: 157 additions & 12 deletions phaser/engines/common/regularizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,157 @@
)
from phaser.utils.image import convolve1d
from phaser.state import ReconsState
from phaser.hooks.regularization import (
from phaser.hooks.regularization import (ScanConstraintProps,
ClampObjectAmplitudeProps, LimitProbeSupportProps, NonNegObjectPhaseProps,
RegularizeLayersProps, ObjLowPassProps, GaussianProps,
CostRegularizerProps, TVRegularizerProps, UnstructuredGaussianProps
)

logger = logging.getLogger(__name__)

class ScanUpdate(t.NamedTuple):
"""
Scan update object for holding the scan constraint row index and previous position arrays
"""
previous: numpy.typing.NDArray[numpy.floating]
row_bins: t.Optional[numpy.typing.NDArray[numpy.integer]] = None
class ScanConstraint:
"""
Constraints for the scan position updates.
This per iteration regularizer takes the unconstrained position updates and applies a
weighted average of its affine, line (row) averaged, high pass filtered, or low pass filtered components.
Currently, only the affine and line averaged constraints are implemented.

See ref for details:
S. Ning, W. Xu, L. Loh, Z. Lu, M. Bosman, F. Zhang, Q. He, An integrated constrained gradient descent (iCGD) protocol to correct scan-positional errors for electron ptychography with high accuracy and precision. Ultramicroscopy 248, 113716 (2023).
"""
def __init__(self, args: None, props: ScanConstraintProps):
self.valid_kinds: t.Set[str] = {'affine', 'line', 'hpf', 'lpf'}
self.constraints: t.Dict[str, float] = {} #= {'default': 1.0}

total_constraint_weight = 0
for kind in self.valid_kinds:
if getattr(props, kind) > 0:
val = getattr(props, kind)
self.constraints[kind] = val

total_constraint_weight += val

if total_constraint_weight > 1.0:
raise ValueError("Sum of scan constraint weights cannot exceed 1.0")

self.constraints['default'] = 1-total_constraint_weight

logger.info(f"Initialized scan constraint with kinds {list(self.constraints.keys())} and weights {list(self.constraints.values())}")

def init_state(self, sim: ReconsState) -> ScanUpdate:
if 'line' in self.constraints:
if (sim.scan.metadata.get('type') != 'raster') | (sim.scan.metadata.get('rows') is None):
raise ValueError("Line scan constraint cannot be applied to scans without row metadata")
row_vals = sim.scan.metadata.get('rows')
if isinstance(row_vals, list):
row_vals = numpy.array(row_vals, dtype=numpy.integer)
state = ScanUpdate(previous=sim.scan.data.copy(), row_bins=row_vals.ravel())
else:
state = ScanUpdate(previous=sim.scan.data.copy(), row_bins=None)
return state

def apply_group(self, group: NDArray[numpy.integer], sim: ReconsState, state: ScanUpdate) -> t.Tuple[ReconsState, ScanUpdate]:
return self.apply_iter(sim, state)

def apply_iter(self, sim: ReconsState, state: ScanUpdate) -> t.Tuple[ReconsState, ScanUpdate]:
xp = get_array_module(sim.scan.data)
update = xp.zeros_like(sim.scan.data, dtype=sim.scan.data.dtype)
for kind, weight in self.constraints.items():
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this deterministic? It seems like it could apply the updates in arbitrary order, we may want to add sorted() if it matters

Copy link
Contributor Author

@michaelxu01 michaelxu01 Jan 20, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should be deterministic, since the updates are summed and applied after each kind * weight is calculated from the common unconstrained update and previous scan state. Example:
update += scan_affine(sim.scan.data, state.previous) * weight

match kind:
case 'affine':
update += _scan_affine(sim.scan.data, state.previous) * weight
case 'line':
if state.row_bins is not None:
update += _scan_line(sim.scan.data, state.previous, state.row_bins) * weight
case 'hpf':
pass
case 'lpf':
pass
case 'default':
update += _scan_default(sim.scan.data, state.previous) * weight

sim.scan.data = state.previous + update
state = ScanUpdate(previous=sim.scan.data.copy(), row_bins=state.row_bins)
return (sim, state)

# @partial(jit, donate_argnames=('pos',), cupy_fuse=True)
def _scan_default(
pos: NDArray[numpy.floating],
prev: NDArray[numpy.floating],
) -> NDArray[numpy.floating]:
"""
Pass through function for calculating the scan update from final and initial scan positions.

:param pos: N x 2 array of unconstrained updated scan positions
:type pos: NDArray[numpy.floating]
:param prev: N x 2 array of scan positions, before update (previous iteration)
:type prev: NDArray[numpy.floating]
:return: N x 2 array of updates to the scan positions
:rtype: NDArray[floating[Any]]
"""
return pos - prev

def _scan_affine(
pos: NDArray[numpy.floating],
prev: NDArray[numpy.floating],
) -> NDArray[numpy.floating]:
"""
Calculates and returns the affine component of the update between final and initial scan positions.

:param pos: N x 2 array of unconstrained updated scan positions
:type pos: NDArray[numpy.floating]
:param prev: N x 2 array of scan positions, before update (previous iteration)
:type prev: NDArray[numpy.floating]
:return: N x 2 array of updates to the scan positions (affine only)
:rtype: NDArray[floating[Any]]
"""
xp = get_array_module(pos)

disp_update = pos - prev
ones = xp.ones((pos.shape[0], 1), pos.dtype)
pos_prev = xp.concatenate([pos, ones], axis=1)
left = xp.matmul(pos_prev.T, disp_update)
right = xp.matmul(pos_prev.T, pos_prev)
A = xp.matmul(xp.linalg.inv(right), left)
constraint = xp.matmul(pos_prev, A)
center_ones = xp.ones((1, 1), pos.dtype)
center = xp.concatenate([xp.average(pos, axis = 0, keepdims=True), center_ones], axis=1, dtype=pos.dtype)
center_shift = center @ A
constraint -= center_shift
return constraint

# @partial(jit, donate_argnames=('pos',), cupy_fuse=True)
def _scan_line(
pos: NDArray[numpy.floating],
prev: NDArray[numpy.floating],
rows: NDArray[numpy.integer],
) -> NDArray[numpy.floating]:
"""
Calculates and returns a line (row) averaged update from the unconstrained final and initial scan positions.

:param pos: N x 2 array of unconstrained updated scan positions
:type pos: NDArray[numpy.floating]
:param prev: N x 2 array of scan positions, before update (previous iteration)
:type prev: NDArray[numpy.floating]
:param rows: an array of row indices corresponding to the N positions given by pos and prev
:type rows: NDArray[numpy.integer]
:return: N x 2 array of updates to the scan positions (row averaged)
:rtype: NDArray[floating[Any]]
"""
xp = get_array_module(pos)

disp_val = pos - prev

y_shifts = xp.bincount(rows, disp_val[:,0]) / xp.bincount(rows)
x_shifts = xp.bincount(rows, disp_val[:,1]) / xp.bincount(rows)
constraint = xp.stack([y_shifts[rows], x_shifts[rows]], axis=1, dtype=pos.dtype)
return constraint

class ClampObjectAmplitude:
def __init__(self, args: None, props: ClampObjectAmplitudeProps):
Expand Down Expand Up @@ -250,7 +395,7 @@ def calc_loss_group(
xp = get_array_module(sim.object.data)

cost = xp.sum(xp.abs(sim.object.data - 1.0))
cost_scale = xp.array(group.shape[-1] / prod(sim.scan.shape[:-1]), dtype=cost.dtype)
cost_scale = xp.array(group.shape[-1] / prod(sim.scan.data.shape[:-1]), dtype=cost.dtype)
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe unnecessary for this PR, but we could probably make n_pos() a method of sim.scan to avoid this repetition

return (cost * cost_scale * self.cost, state)


Expand All @@ -272,7 +417,7 @@ def calc_loss_group(

cost = xp.sum(abs2(sim.object.data - 1.0))

cost_scale = xp.array(group.shape[-1] / prod(sim.scan.shape[:-1]), dtype=cost.dtype)
cost_scale = xp.array(group.shape[-1] / prod(sim.scan.data.shape[:-1]), dtype=cost.dtype)
return (cost * cost_scale * self.cost, state) # type: ignore


Expand All @@ -293,7 +438,7 @@ def calc_loss_group(
xp = get_array_module(sim.object.data)

cost = xp.sum(xp.abs(xp.angle(sim.object.data)))
cost_scale = xp.array(group.shape[-1] / prod(sim.scan.shape[:-1]), dtype=cost.dtype)
cost_scale = xp.array(group.shape[-1] / prod(sim.scan.data.shape[:-1]), dtype=cost.dtype)
return (cost * cost_scale * self.cost, state)


Expand All @@ -319,7 +464,7 @@ def calc_loss_group(
xp.abs(fft2(xp.prod(sim.object.data, axis=0)))
)
# scale cost by fraction of the total reconstruction in the group
cost_scale = xp.array(group.shape[-1] / prod(sim.scan.shape[:-1]), dtype=cost.dtype)
cost_scale = xp.array(group.shape[-1] / prod(sim.scan.data.shape[:-1]), dtype=cost.dtype)

return (cost * cost_scale * self.cost, state)

Expand Down Expand Up @@ -351,7 +496,7 @@ def calc_loss_group(
#)
# scale cost by fraction of the total reconstruction in the group
# TODO also scale by # of pixels or similar?
cost_scale = xp.array(group.shape[-1] / prod(sim.scan.shape[:-1]), dtype=cost.dtype)
cost_scale = xp.array(group.shape[-1] / prod(sim.scan.data.shape[:-1]), dtype=cost.dtype)

return (cost * cost_scale * self.cost, state)

Expand All @@ -377,7 +522,7 @@ def calc_loss_group(
xp.sum(abs2(xp.diff(sim.object.data, axis=-2)))
)
# scale cost by fraction of the total reconstruction in the group
cost_scale = xp.array(group.shape[-1] / prod(sim.scan.shape[:-1]), dtype=cost.dtype)
cost_scale = xp.array(group.shape[-1] / prod(sim.scan.data.shape[:-1]), dtype=cost.dtype)

return (cost * cost_scale * self.cost, state) # type: ignore

Expand All @@ -403,7 +548,7 @@ def calc_loss_group(

cost = xp.sum(xp.abs(xp.diff(sim.object.data, axis=0)))
# scale cost by fraction of the total reconstruction in the group
cost_scale = xp.array(group.shape[-1] / prod(sim.scan.shape[:-1]), dtype=cost.dtype)
cost_scale = xp.array(group.shape[-1] / prod(sim.scan.data.shape[:-1]), dtype=cost.dtype)

return (cost * cost_scale * self.cost, state)

Expand All @@ -429,7 +574,7 @@ def calc_loss_group(

cost = xp.sum(abs2(xp.diff(sim.object.data, axis=0)))
# scale cost by fraction of the total reconstruction in the group
cost_scale = xp.array(group.shape[-1] / prod(sim.scan.shape[:-1]), dtype=cost.dtype)
cost_scale = xp.array(group.shape[-1] / prod(sim.scan.data.shape[:-1]), dtype=cost.dtype)

return (cost * cost_scale * self.cost, state) # type: ignore

Expand Down Expand Up @@ -519,7 +664,7 @@ def __init__(self, args: None, props: UnstructuredGaussianProps):
self.attr_path = props.attr_path

def init_state(self, sim: ReconsState) -> NDArray[numpy.floating]:
xp = get_array_module(sim.scan)
xp = get_array_module(sim.scan.data)
try:
self.getattr_nested(sim, self.attr_path)
except AttributeError as e:
Expand Down Expand Up @@ -547,8 +692,8 @@ def setattr_nested(self, obj: t.Any, attr_path: str, value: t.Any):
def apply_iter(self, sim: ReconsState, state: NDArray[numpy.floating]) -> t.Tuple[ReconsState, NDArray[numpy.floating]]:
from scipy.spatial import KDTree
obj_samp = sim.object.sampling
scan_flat = sim.scan.reshape(-1, 2)
scan_ndim = sim.scan.ndim - 1
scan_flat = sim.scan.data.reshape(-1, 2)
scan_ndim = sim.scan.data.ndim - 1

attr = self.getattr_nested(sim, self.attr_path)
vals = t.cast(NDArray[numpy.inexact], getattr(attr, 'data', attr)) # Extract raw array
Expand Down
2 changes: 1 addition & 1 deletion phaser/engines/common/simulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,7 +248,7 @@ def cutout_group(
"""Returns (probe, obj) in the cutout region"""
probes = state.probe.data

group_scan = state.scan[tuple(group)]
group_scan = state.scan.data[tuple(group)]
group_obj = state.object.sampling.get_view_at_pos(state.object.data, group_scan, probes.shape[-2:])
# group probes in real space
# shape (len(group), 1, Ny, Nx)
Expand Down
Loading