From 378468653d38d88914838084c59f807caded4fdb Mon Sep 17 00:00:00 2001 From: "Keaton J. Burns" Date: Mon, 10 Mar 2025 13:28:53 -0400 Subject: [PATCH 1/5] Add VectorMoment operator and spherical implementation --- dedalus/core/operators.py | 218 +++++++++++++++++++++++++++----------- 1 file changed, 154 insertions(+), 64 deletions(-) diff --git a/dedalus/core/operators.py b/dedalus/core/operators.py index 9a9a993d..f7b941d6 100644 --- a/dedalus/core/operators.py +++ b/dedalus/core/operators.py @@ -3832,70 +3832,8 @@ def operate(self, out): np.copyto(out.data, arg0.data) -class SphericalCurl(Curl, SphericalEllOperator): - - cs_type = coords.SphericalCoordinates - - def __init__(self, operand, index=0, out=None): - Curl.__init__(self, operand, out=out) - if index != 0: - raise ValueError("Curl only implemented along index 0.") - self.index = index - coordsys = operand.tensorsig[index] - SphericalEllOperator.__init__(self, operand, coordsys) - # FutureField requirements - self.domain = operand.domain.substitute_basis(self.input_basis, self.output_basis) - self.tensorsig = (coordsys,) + operand.tensorsig[:index] + operand.tensorsig[index+1:] - self.dtype = operand.dtype - - @staticmethod - def _output_basis(input_basis): - return input_basis.derivative_basis(1) - - def check_conditions(self): - """Check that operands are in a proper layout.""" - # Require radius to be in coefficient space - layout = self.args[0].layout - return (not layout.grid_space[self.radius_axis]) and (layout.local[self.radius_axis]) - - def enforce_conditions(self): - """Require operands to be in a proper layout.""" - # Require radius to be in coefficient space - self.args[0].require_coeff_space(self.radius_axis) - self.args[0].require_local(self.radius_axis) - - def regindex_out(self, regindex_in): - # Regorder: -, +, 0 - # - and + map to 0 - if regindex_in[0] in (0, 1): - return ((2,) + regindex_in[1:],) - # 0 maps to - and + - else: - return ((0,) + regindex_in[1:], (1,) + regindex_in[1:]) - - @CachedMethod - def radial_matrix(self, regindex_in, regindex_out, ell): - radial_basis = self.radial_basis - regtotal_in = radial_basis.regtotal(regindex_in) - regtotal_out = radial_basis.regtotal(regindex_out) - if regindex_in[1:] == regindex_out[1:]: - return self._radial_matrix(radial_basis, regindex_in[0], regindex_out[0], regtotal_in, regtotal_out, ell) - else: - raise ValueError("This should never happen") - - @staticmethod - @CachedMethod - def _radial_matrix(radial_basis, regindex_in0, regindex_out0, regtotal_in, regtotal_out, ell): - if regindex_in0 == 0 and regindex_out0 == 2: - return -1j * radial_basis.xi(+1, ell+regtotal_in+1) * radial_basis.operator_matrix('D+', ell, regtotal_in) - elif regindex_in0 == 1 and regindex_out0 == 2: - return 1j * radial_basis.xi(-1, ell+regtotal_in-1) * radial_basis.operator_matrix('D-', ell, regtotal_in) - elif regindex_in0 == 2 and regindex_out0 == 0: - return -1j * radial_basis.xi(+1, ell+regtotal_in) * radial_basis.operator_matrix('D-', ell, regtotal_in) - elif regindex_in0 == 2 and regindex_out0 == 1: - return 1j * radial_basis.xi(-1, ell+regtotal_in) * radial_basis.operator_matrix('D+', ell, regtotal_in) - else: - raise ValueError("This should never happen") +class ImaginarySphericalEllOperator(SphericalEllOperator): + """SphericalEllOperator with imaginary symbols.""" def subproblem_matrix(self, subproblem): if self.dtype == np.complex128: @@ -3975,6 +3913,158 @@ def operate(self, out): comp_out[tuple(slices)][msin_slice] += vec_out_complex.imag +class SphericalCurl(Curl, ImaginarySphericalEllOperator): + + cs_type = coords.SphericalCoordinates + + def __init__(self, operand, index=0, out=None): + Curl.__init__(self, operand, out=out) + if index != 0: + raise ValueError("Curl only implemented along index 0.") + self.index = index + coordsys = operand.tensorsig[index] + SphericalEllOperator.__init__(self, operand, coordsys) + # FutureField requirements + self.domain = operand.domain.substitute_basis(self.input_basis, self.output_basis) + self.tensorsig = (coordsys,) + operand.tensorsig[:index] + operand.tensorsig[index+1:] + self.dtype = operand.dtype + + @staticmethod + def _output_basis(input_basis): + return input_basis.derivative_basis(1) + + def check_conditions(self): + """Check that operands are in a proper layout.""" + # Require radius to be in coefficient space + layout = self.args[0].layout + return (not layout.grid_space[self.radius_axis]) and (layout.local[self.radius_axis]) + + def enforce_conditions(self): + """Require operands to be in a proper layout.""" + # Require radius to be in coefficient space + self.args[0].require_coeff_space(self.radius_axis) + self.args[0].require_local(self.radius_axis) + + def regindex_out(self, regindex_in): + # Regorder: -, +, 0 + # - and + map to 0 + if regindex_in[0] in (0, 1): + return ((2,) + regindex_in[1:],) + # 0 maps to - and + + else: + return ((0,) + regindex_in[1:], (1,) + regindex_in[1:]) + + @CachedMethod + def radial_matrix(self, regindex_in, regindex_out, ell): + radial_basis = self.radial_basis + regtotal_in = radial_basis.regtotal(regindex_in) + regtotal_out = radial_basis.regtotal(regindex_out) + if regindex_in[1:] == regindex_out[1:]: + return self._radial_matrix(radial_basis, regindex_in[0], regindex_out[0], regtotal_in, regtotal_out, ell) + else: + raise ValueError("This should never happen") + + @staticmethod + @CachedMethod + def _radial_matrix(radial_basis, regindex_in0, regindex_out0, regtotal_in, regtotal_out, ell): + if regindex_in0 == 0 and regindex_out0 == 2: + return -1j * radial_basis.xi(+1, ell+regtotal_in+1) * radial_basis.operator_matrix('D+', ell, regtotal_in) + elif regindex_in0 == 1 and regindex_out0 == 2: + return 1j * radial_basis.xi(-1, ell+regtotal_in-1) * radial_basis.operator_matrix('D-', ell, regtotal_in) + elif regindex_in0 == 2 and regindex_out0 == 0: + return -1j * radial_basis.xi(+1, ell+regtotal_in) * radial_basis.operator_matrix('D-', ell, regtotal_in) + elif regindex_in0 == 2 and regindex_out0 == 1: + return 1j * radial_basis.xi(-1, ell+regtotal_in) * radial_basis.operator_matrix('D+', ell, regtotal_in) + else: + raise ValueError("This should never happen") + + +@alias("moment") +class VectorMoment(LinearOperator, metaclass=MultiClass): + + name = 'Moment' + + @classmethod + def _check_args(cls, operand, index=0, out=None): + # Dispatch by coordinate system + if isinstance(operand, Operand): + if isinstance(operand.tensorsig[index], cls.cs_type): + return True + return False + + def new_operand(self, operand, **kw): + return VectorMoment(operand, index=self.index, **kw) + + +class SphericalVectorMoment(VectorMoment, ImaginarySphericalEllOperator): + + cs_type = coords.SphericalCoordinates + + def __init__(self, operand, index=0, out=None): + VectorMoment.__init__(self, operand, out=out) + if index != 0: + raise ValueError("Moment only implemented along index 0.") + self.index = index + coordsys = operand.tensorsig[index] + SphericalEllOperator.__init__(self, operand, coordsys) + # FutureField requirements + self.domain = operand.domain.substitute_basis(self.input_basis, self.output_basis) + self.tensorsig = operand.tensorsig + self.dtype = operand.dtype + + @staticmethod + def _output_basis(input_basis): + return input_basis + + def check_conditions(self): + """Check that operands are in a proper layout.""" + # Require radius to be in coefficient space + layout = self.args[0].layout + return (not layout.grid_space[self.radius_axis]) and (layout.local[self.radius_axis]) + + def enforce_conditions(self): + """Require operands to be in a proper layout.""" + # Require radius to be in coefficient space + self.args[0].require_coeff_space(self.radius_axis) + self.args[0].require_local(self.radius_axis) + + def regindex_out(self, regindex_in): + # Regorder: -, +, 0 + Rm, Rp, R0 = 0, 1, 2 + # - and + map to 0 + if regindex_in[0] in (Rm, Rp): + return ((R0,) + regindex_in[1:],) + # 0 maps to - and + + else: + return ((Rm,) + regindex_in[1:], (Rp,) + regindex_in[1:]) + + @CachedMethod + def radial_matrix(self, regindex_in, regindex_out, ell): + radial_basis = self.radial_basis + regtotal_in = radial_basis.regtotal(regindex_in) + regtotal_out = radial_basis.regtotal(regindex_out) + if regindex_in[1:] == regindex_out[1:]: + return self._radial_matrix(radial_basis, regindex_in[0], regindex_out[0], regtotal_in, regtotal_out, ell) + else: + raise ValueError("This should never happen") + + @staticmethod + @CachedMethod + def _radial_matrix(radial_basis, regindex_in0, regindex_out0, regtotal_in, regtotal_out, ell): + # Regorder: -, +, 0 + Rm, Rp, R0 = 0, 1, 2 + if regindex_in0 == Rm and regindex_out0 == R0: + return -1j * radial_basis.xi(+1, ell+regtotal_in+1) * radial_basis.operator_matrix('R+', ell, regtotal_in) + elif regindex_in0 == Rp and regindex_out0 == R0: + return 1j * radial_basis.xi(-1, ell+regtotal_in-1) * radial_basis.operator_matrix('R-', ell, regtotal_in) + elif regindex_in0 == R0 and regindex_out0 == Rm: + return -1j * radial_basis.xi(+1, ell+regtotal_in) * radial_basis.operator_matrix('R-', ell, regtotal_in) + elif regindex_in0 == R0 and regindex_out0 == Rp: + return 1j * radial_basis.xi(-1, ell+regtotal_in) * radial_basis.operator_matrix('R+', ell, regtotal_in) + else: + raise ValueError("This should never happen") + + @alias("lap") class Laplacian(LinearOperator, metaclass=MultiClass): From 6130dbdbd57bb30e979df017d69f7b16c397b3f2 Mon Sep 17 00:00:00 2001 From: "Keaton J. Burns" Date: Tue, 11 Mar 2025 09:21:12 -0400 Subject: [PATCH 2/5] Skip 0 exception for VectorMoment --- dedalus/core/operators.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/dedalus/core/operators.py b/dedalus/core/operators.py index f7b941d6..d61aa68c 100644 --- a/dedalus/core/operators.py +++ b/dedalus/core/operators.py @@ -3984,6 +3984,12 @@ class VectorMoment(LinearOperator, metaclass=MultiClass): name = 'Moment' + @classmethod + def _preprocess_args(cls, operand, index=0, out=None): + if operand == 0: + raise SkipDispatchException(output=0) + return [operand], {'index': index, 'out': out} + @classmethod def _check_args(cls, operand, index=0, out=None): # Dispatch by coordinate system From 77c24cdea0e4c960065c8d402404807e2a8ed224 Mon Sep 17 00:00:00 2001 From: "Keaton J. Burns" Date: Wed, 12 Mar 2025 15:26:12 -0400 Subject: [PATCH 3/5] Add HarmonicTrace operator for the ball --- dedalus/core/basis.py | 66 +++++++++++++++++++++++++++++++++++++++ dedalus/core/operators.py | 45 ++++++++++++++++++++++++++ 2 files changed, 111 insertions(+) diff --git a/dedalus/core/basis.py b/dedalus/core/basis.py index 3f6167c9..987fecf3 100644 --- a/dedalus/core/basis.py +++ b/dedalus/core/basis.py @@ -5575,6 +5575,72 @@ def _radial_matrix(basis, ell): return matrix +class HarmonicTraceBall(HarmonicTrace, SphericalEllOperator): + """Project against solid harmonics.""" + + input_coord_type = SphericalCoordinates + input_basis_type = BallBasis + + @CachedAttribute + def radial_basis(self): + return self.input_basis.radial_basis + + def regindex_out(self, regindex_in): + return (regindex_in,) + + def _output_basis(self, input_basis): + return input_basis.S2_basis() + + def operate(self, out): + """Perform operation.""" + operand = self.args[0] + input_basis = self.input_basis + output_basis = self.output_basis + radial_basis = self.radial_basis + axis = self.dist.last_axis(radial_basis) + # Set output layout + out.preset_layout(operand.layout) + # Apply operator + R = radial_basis.regularity_classes(operand.tensorsig) + slices_in = [slice(None) for i in range(self.dist.dim)] + slices_out = [slice(None) for i in range(self.dist.dim)] + for regindex, regtotal in np.ndenumerate(R): + comp_in = operand.data[regindex] + comp_out = out.data[regindex] + for ell, m_ind, ell_ind in input_basis.ell_maps(self.dist): + allowed = radial_basis.regularity_allowed(ell, regindex) + if allowed: + slices_in[axis-2] = slices_out[axis-2] = m_ind + slices_in[axis-1] = slices_out[axis-1] = ell_ind + slices_in[axis] = radial_basis.n_slice(ell) + vec_in = comp_in[tuple(slices_in)] + vec_out = comp_out[tuple(slices_out)] + A = self.radial_matrix(regindex, regindex, ell) + apply_matrix(A, vec_in, axis=axis, out=vec_out) + + def radial_matrix(self, regindex_in, regindex_out, ell): + return self._radial_matrix(self.radial_basis, ell) + + @staticmethod + @CachedMethod + def _radial_matrix(basis, ell): + n_size = basis.n_size(ell) + if basis.alpha + basis.k == 0: + # Just first mode when α+k=0 + matrix = np.zeros((1, n_size), dtype=basis.dtype) + matrix[0, 0] = 1 + else: + # Otherwise calculate with quadrature + N = basis.shape[2] * 2 # maybe necessary for dealiasing? + z0, w0 = dedalus_sphere.zernike.quadrature(3, N, k=0) + Qk = dedalus_sphere.zernike.polynomials(3, n_size, basis.alpha+basis.k, ell, z0) + Q0 = Qk[0, :] # ~r**ell + matrix = ((Q0*w0)[None, :] @ Qk.T).astype(basis.dtype) + #matrix *= basis.radius**3 + #matrix *= 4 * np.pi / np.sqrt(2) # SWSH contribution + return matrix + + class InterpolateAzimuth(FutureLockedField, operators.Interpolate): input_basis_type = (SphereBasis, BallBasis, ShellBasis, DiskBasis, AnnulusBasis) diff --git a/dedalus/core/operators.py b/dedalus/core/operators.py index d61aa68c..47c8e570 100644 --- a/dedalus/core/operators.py +++ b/dedalus/core/operators.py @@ -4071,6 +4071,51 @@ def _radial_matrix(radial_basis, regindex_in0, regindex_out0, regtotal_in, regto raise ValueError("This should never happen") +@alias("htrace") +class HarmonicTrace(LinearOperator, metaclass=MultiClass): + """Trace of harmonic part of a scalar a field on the domain boundary.""" + + name = "htrace" + + @classmethod + def _preprocess_args(cls, operand, coord): + # Handle zeros + if operand == 0: + raise SkipDispatchException(output=0) + return (operand, coord), {} + + @classmethod + def _check_args(cls, operand, coord): + # Dispatch by operand basis + if isinstance(operand, Operand): + if isinstance(coord, cls.input_coord_type): + basis = operand.domain.get_basis(coord) + if isinstance(basis, cls.input_basis_type): + return True + return False + + def __init__(self, operand, coord): + SpectralOperator.__init__(self, operand) + # Require integrand is a scalar + if coord in operand.tensorsig: + raise ValueError("Can only take harmonic trace of scalars.") + # SpectralOperator requirements + self.coord = coord + self.input_basis = operand.domain.get_basis(coord) + self.output_basis = self._output_basis(self.input_basis) + self.first_axis = self.dist.get_basis_axis(self.input_basis) + self.last_axis = self.first_axis + self.input_basis.dim - 1 + # LinearOperator requirements + self.operand = operand + # FutureField requirements + self.domain = operand.domain.substitute_basis(self.input_basis, self.output_basis) + self.tensorsig = operand.tensorsig + self.dtype = operand.dtype + + def new_operand(self, operand, **kw): + return HarmonicTrace(operand, self.coord, **kw) + + @alias("lap") class Laplacian(LinearOperator, metaclass=MultiClass): From 07599e52861bfb39db2ab8fbe3574c3adab3c8d2 Mon Sep 17 00:00:00 2001 From: "Keaton J. Burns" Date: Thu, 13 Mar 2025 17:16:29 -0400 Subject: [PATCH 4/5] Add MomentDivergence operator that maintains k --- dedalus/core/operators.py | 91 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 91 insertions(+) diff --git a/dedalus/core/operators.py b/dedalus/core/operators.py index 47c8e570..c174a2c7 100644 --- a/dedalus/core/operators.py +++ b/dedalus/core/operators.py @@ -4071,6 +4071,97 @@ def _radial_matrix(radial_basis, regindex_in0, regindex_out0, regtotal_in, regto raise ValueError("This should never happen") +@alias("momentdiv") +class MomentDivergence(LinearOperator, metaclass=MultiClass): + """Divergence of vector moment: MomentDivergence(u) = div(cross(r, u))""" + + name = 'MomentDivergence' + + @classmethod + def _preprocess_args(cls, operand, index=0, out=None): + if operand == 0: + raise SkipDispatchException(output=0) + return [operand], {'index': index, 'out': out} + + @classmethod + def _check_args(cls, operand, index=0, out=None): + # Dispatch by coordinate system + if isinstance(operand, Operand): + if isinstance(operand.tensorsig[index], cls.cs_type): + return True + return False + + def new_operand(self, operand, **kw): + return MomentDivergence(operand, index=self.index, **kw) + + +class SphericalMomentDivergence(MomentDivergence, ImaginarySphericalEllOperator): + + cs_type = coords.SphericalCoordinates + + def __init__(self, operand, index=0, out=None): + MomentDivergence.__init__(self, operand, out=out) + if index != 0: + raise ValueError("Moment only implemented along index 0.") + self.index = index + coordsys = operand.tensorsig[index] + SphericalEllOperator.__init__(self, operand, coordsys) + # FutureField requirements + self.domain = operand.domain.substitute_basis(self.input_basis, self.output_basis) + self.tensorsig = operand.tensorsig[:index] + operand.tensorsig[index+1:] + self.dtype = operand.dtype + + @staticmethod + def _output_basis(input_basis): + return input_basis + + def check_conditions(self): + """Check that operands are in a proper layout.""" + # Require radius to be in coefficient space + layout = self.args[0].layout + return (not layout.grid_space[self.radius_axis]) and (layout.local[self.radius_axis]) + + def enforce_conditions(self): + """Require operands to be in a proper layout.""" + # Require radius to be in coefficient space + self.args[0].require_coeff_space(self.radius_axis) + self.args[0].require_local(self.radius_axis) + + def regindex_out(self, regindex_in): + # Regorder: -, +, 0 + Rm, Rp, R0 = 0, 1, 2 + # 0 -> null + if regindex_in[0] == R0: + return (regindex_in[1:],) + else: + return tuple() + + @CachedMethod + def radial_matrix(self, regindex_in, regindex_out, ell): + radial_basis = self.radial_basis + regtotal_in = radial_basis.regtotal(regindex_in) + if regindex_in[0] == 2 and regindex_in[1:] == regindex_out: + return self._radial_matrix(radial_basis, regindex_in[0], regtotal_in, ell) + else: + raise ValueError("This should never happen") + + @staticmethod + @CachedMethod + def _radial_matrix(radial_basis, regindex_in0, regtotal_in, ell): + Rm = radial_basis.operator_matrix('R-', ell, regtotal_in) + Rp = radial_basis.operator_matrix('R+', ell, regtotal_in) + Dm = radial_basis.operator_matrix('D-', ell, regtotal_in+1) + Dp = radial_basis.operator_matrix('D+', ell, regtotal_in-1) + xim = radial_basis.xi(-1, ell+regtotal_in) + xip = radial_basis.xi(+1, ell+regtotal_in) + n = radial_basis.n_size(ell) + if regindex_in0 == 2: + # return 1j * xim * xip * (Dm * Rp - Dp * Rm) + return 1j * np.sqrt(ell*(ell+1)) * sparse.identity(n, radial_basis.dtype, format='csr') + else: + raise ValueError("This should never happen") + + @alias("htrace") class HarmonicTrace(LinearOperator, metaclass=MultiClass): """Trace of harmonic part of a scalar a field on the domain boundary.""" From 8aebd6a4d35c3c06f40b6582569113c70a6c4720 Mon Sep 17 00:00:00 2001 From: "Keaton J. Burns" Date: Thu, 13 Mar 2025 17:19:51 -0400 Subject: [PATCH 5/5] Typo --- dedalus/core/basis.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dedalus/core/basis.py b/dedalus/core/basis.py index 987fecf3..0ccc4f41 100644 --- a/dedalus/core/basis.py +++ b/dedalus/core/basis.py @@ -5575,7 +5575,7 @@ def _radial_matrix(basis, ell): return matrix -class HarmonicTraceBall(HarmonicTrace, SphericalEllOperator): +class HarmonicTraceBall(operators.HarmonicTrace, operators.SphericalEllOperator): """Project against solid harmonics.""" input_coord_type = SphericalCoordinates