From c8133e4ff75a5ce5769cbb1f0340c17b477b54f2 Mon Sep 17 00:00:00 2001 From: "Keaton J. Burns" Date: Tue, 27 May 2025 14:47:38 +0100 Subject: [PATCH 01/29] Add array namespace option for field buffers --- dedalus/core/distributor.py | 6 +++++- dedalus/core/field.py | 23 ++++++++++++++--------- 2 files changed, 19 insertions(+), 10 deletions(-) diff --git a/dedalus/core/distributor.py b/dedalus/core/distributor.py index c4cc766f..6b2aedbd 100644 --- a/dedalus/core/distributor.py +++ b/dedalus/core/distributor.py @@ -10,6 +10,7 @@ from math import prod import numbers from weakref import WeakSet +import array_api_compat from .coords import CoordinateSystem, DirectProduct from ..tools.array import reshape_vector @@ -74,7 +75,7 @@ class Distributor: states) and the paths between them (D transforms and R transposes). """ - def __init__(self, coordsystems, comm=None, mesh=None, dtype=None): + def __init__(self, coordsystems, comm=None, mesh=None, dtype=None, array_namespace=np): # Accept single coordsys in place of tuple/list if not isinstance(coordsystems, (tuple, list)): coordsystems = (coordsystems,) @@ -115,6 +116,9 @@ def __init__(self, coordsystems, comm=None, mesh=None, dtype=None): self._build_layouts() # Keep set of weak field references self.fields = WeakSet() + # Array module + x = array_namespace.zeros(0) + self.array_namespace = array_api_compat.array_namespace(x) @CachedAttribute def cs_by_axis(self): diff --git a/dedalus/core/field.py b/dedalus/core/field.py index 415edcf6..0e7c0365 100644 --- a/dedalus/core/field.py +++ b/dedalus/core/field.py @@ -473,16 +473,19 @@ def evaluate(self): def reinitialize(self, **kw): return self - @staticmethod - def _create_buffer(buffer_size): + def _create_buffer(self, buffer_size): """Create buffer for Field data.""" - if buffer_size == 0: - # FFTW doesn't like allocating size-0 arrays - return np.zeros((0,), dtype=np.float64) + xp = self.array_namespace + if xp == np: + if buffer_size == 0: + # FFTW doesn't like allocating size-0 arrays + return np.zeros((0,), dtype=np.float64) + else: + # Use FFTW SIMD aligned allocation + alloc_doubles = buffer_size // 8 + return fftw.create_buffer(alloc_doubles) else: - # Use FFTW SIMD aligned allocation - alloc_doubles = buffer_size // 8 - return fftw.create_buffer(alloc_doubles) + return xp.zeros(buffer_size) @CachedAttribute def _dealias_buffer_size(self): @@ -516,12 +519,13 @@ def preset_scales(self, scales): def preset_layout(self, layout): """Interpret buffer as data in specified layout.""" + xp = self.array_namespace layout = self.dist.get_layout_object(layout) self.layout = layout tens_shape = [vs.dim for vs in self.tensorsig] local_shape = layout.local_shape(self.domain, self.scales) total_shape = tuple(tens_shape) + tuple(local_shape) - self.data = np.ndarray(shape=total_shape, + self.data = xp.ndarray(shape=total_shape, dtype=self.dtype, buffer=self.buffer) #self.global_start = layout.start(self.domain, self.scales) @@ -561,6 +565,7 @@ def __init__(self, dist, bases=None, name=None, tensorsig=None, dtype=None): dtype = dist.dtype from .domain import Domain self.dist = dist + self.array_namespace = dist.array_namespace self.name = name self.tensorsig = tensorsig self.dtype = dtype From 9d456697fa782d1bfd6534c2f39ac3d3644d9b7f Mon Sep 17 00:00:00 2001 From: "Keaton J. 
Burns" Date: Tue, 27 May 2025 14:49:15 +0100 Subject: [PATCH 02/29] Add array-api-compat to setup.py --- setup.py | 1 + 1 file changed, 1 insertion(+) diff --git a/setup.py b/setup.py index 1cf1d9dd..583009a6 100644 --- a/setup.py +++ b/setup.py @@ -181,6 +181,7 @@ def read(rel_path): # Runtime requirements install_requires = [ + "array-api-compat", "docopt", "h5py >= 3.0.0", "matplotlib", From 6db959316f96da24d9dff911bd44053fc031c031 Mon Sep 17 00:00:00 2001 From: "Keaton J. Burns" Date: Tue, 27 May 2025 15:15:45 +0100 Subject: [PATCH 03/29] Allow specifying array namespace by string --- dedalus/core/distributor.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/dedalus/core/distributor.py b/dedalus/core/distributor.py index 6b2aedbd..9b2a3ff7 100644 --- a/dedalus/core/distributor.py +++ b/dedalus/core/distributor.py @@ -117,8 +117,10 @@ def __init__(self, coordsystems, comm=None, mesh=None, dtype=None, array_namespa # Keep set of weak field references self.fields = WeakSet() # Array module - x = array_namespace.zeros(0) - self.array_namespace = array_api_compat.array_namespace(x) + if isinstance(array_namespace, str): + self.array_namespace = getattr(array_api_compat, array_namespace) + else: + self.array_namespace = array_api_compat.array_namespace(array_namespace.zeros(0)) @CachedAttribute def cs_by_axis(self): From a39b3458eee0d1ed67b07edeb084fe782272fbd8 Mon Sep 17 00:00:00 2001 From: "Keaton J. Burns" Date: Tue, 27 May 2025 15:16:18 +0100 Subject: [PATCH 04/29] Try fixing cupy allocation from buffer --- dedalus/core/field.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/dedalus/core/field.py b/dedalus/core/field.py index 0e7c0365..3d15f864 100644 --- a/dedalus/core/field.py +++ b/dedalus/core/field.py @@ -525,9 +525,11 @@ def preset_layout(self, layout): tens_shape = [vs.dim for vs in self.tensorsig] local_shape = layout.local_shape(self.domain, self.scales) total_shape = tuple(tens_shape) + tuple(local_shape) - self.data = xp.ndarray(shape=total_shape, - dtype=self.dtype, - buffer=self.buffer) + # Handle cupy allocation + if xp.__name__ == "cupy": + self.data = xp.ndarray(shape=total_shape, dtype=self.dtype, memptr=self.buffer.data) + else: + self.data = xp.ndarray(shape=total_shape, dtype=self.dtype, buffer=self.buffer) #self.global_start = layout.start(self.domain, self.scales) From dd1f0f78e8531561d92b71acae95d42a95815516 Mon Sep 17 00:00:00 2001 From: "Keaton J. 
Burns" Date: Tue, 27 May 2025 15:31:50 +0100 Subject: [PATCH 05/29] Fix cupy check --- dedalus/core/field.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/dedalus/core/field.py b/dedalus/core/field.py index 3d15f864..5f7cabfb 100644 --- a/dedalus/core/field.py +++ b/dedalus/core/field.py @@ -7,6 +7,7 @@ from functools import partial, reduce from collections import defaultdict import numpy as np +import array_api_compat from mpi4py import MPI from scipy import sparse from scipy.sparse import linalg as splinalg @@ -525,8 +526,8 @@ def preset_layout(self, layout): tens_shape = [vs.dim for vs in self.tensorsig] local_shape = layout.local_shape(self.domain, self.scales) total_shape = tuple(tens_shape) + tuple(local_shape) - # Handle cupy allocation - if xp.__name__ == "cupy": + # Create view into buffer + if array_api_compat.is_cupy_namespace(xp): self.data = xp.ndarray(shape=total_shape, dtype=self.dtype, memptr=self.buffer.data) else: self.data = xp.ndarray(shape=total_shape, dtype=self.dtype, buffer=self.buffer) From 62ee03bd5a80c7b2d7655e45608699bd4a248812 Mon Sep 17 00:00:00 2001 From: "Keaton J. Burns" Date: Tue, 27 May 2025 15:48:34 +0100 Subject: [PATCH 06/29] Add cupy-based complex fourier MMT --- dedalus/core/basis.py | 6 ++++-- dedalus/core/transforms.py | 36 +++++++++++++++++++++--------------- 2 files changed, 25 insertions(+), 17 deletions(-) diff --git a/dedalus/core/basis.py b/dedalus/core/basis.py index 3f6167c9..7689eee5 100644 --- a/dedalus/core/basis.py +++ b/dedalus/core/basis.py @@ -915,11 +915,13 @@ def _native_grid(self, scale): @CachedMethod def transform_plan(self, dist, grid_size): """Build transform plan.""" + xp = dist.array_namespace + xp_name = xp.__name__.split('.')[-1] # Shortcut trivial transforms if grid_size == 1 or self.size == 1: - return self.transforms['matrix'](grid_size, self.size) + return self.transforms[f"matrix-{xp_name}"](grid_size, self.size) else: - return self.transforms[self.library](grid_size, self.size) + return self.transforms[f"{self.library}-{xp_name}"](grid_size, self.size) def forward_transform(self, field, axis, gdata, cdata): # Transform diff --git a/dedalus/core/transforms.py b/dedalus/core/transforms.py index 00758fb2..db6f2595 100644 --- a/dedalus/core/transforms.py +++ b/dedalus/core/transforms.py @@ -191,50 +191,56 @@ class ComplexFourierTransform(SeparableTransform): If M is even, the ordering is [0, 1, 2, ..., KM, -KM, -KM+1, ..., -1]. 
""" - def __init__(self, grid_size, coeff_size): + def __init__(self, grid_size, coeff_size, array_namespace, dtype): self.N = grid_size self.M = coeff_size self.KN = (self.N - 1) // 2 self.KM = (self.M - 1) // 2 self.Kmax = min(self.KN, self.KM) + self.array_namespace = array_namespace + self.dtype = dtype @property def wavenumbers(self): """One-dimensional global wavenumber array.""" + xp = self.array_namespace M = self.M KM = self.KM - k = np.arange(M) + k = xp.arange(M) # Wrap around Nyquist mode return (k + KM) % M - KM -@register_transform(basis.ComplexFourier, 'matrix') +@register_transform(basis.ComplexFourier, 'matrix-numpy') +@register_transform(basis.ComplexFourier, 'matrix-cupy') class ComplexFourierMMT(ComplexFourierTransform, SeparableMatrixTransform): """Complex-to-complex Fourier MMT.""" @CachedAttribute def forward_matrix(self): """Build forward transform matrix.""" + xp = self.array_namespace K = self.wavenumbers[:, None] - X = np.arange(self.N)[None, :] + X = xp.arange(self.N)[None, :] dX = self.N / 2 / np.pi - quadrature = np.exp(-1j*K*X/dX) / self.N + quadrature = xp.exp(-1j*K*X/dX) / self.N # Zero Nyquist and higher modes for transforms with grid_size <= coeff_size - quadrature *= np.abs(K) <= self.Kmax - # Ensure C ordering for fast dot products - return np.asarray(quadrature, order='C') + quadrature *= xp.abs(K) <= self.Kmax + # Ensure C ordering for fast dot products, cast to specified dtype + return xp.asarray(quadrature, order='C', dtype=self.dtype) @CachedAttribute def backward_matrix(self): """Build backward transform matrix.""" + xp = self.array_namespace K = self.wavenumbers[None, :] - X = np.arange(self.N)[:, None] + X = xp.arange(self.N)[:, None] dX = self.N / 2 / np.pi - functions = np.exp(1j*K*X/dX) + functions = xp.exp(1j*K*X/dX) # Zero Nyquist and higher modes for transforms with grid_size <= coeff_size - functions *= np.abs(K) <= self.Kmax - # Ensure C ordering for fast dot products - return np.asarray(functions, order='C') + functions *= xp.abs(K) <= self.Kmax + # Ensure C ordering for fast dot products, cast to specified dtype + return xp.asarray(functions, order='C', dtype=self.dtype) class ComplexFFT(ComplexFourierTransform): @@ -267,7 +273,7 @@ def resize_coeffs(self, data_in, data_out, axis, rescale): np.multiply(data_in[negfreq], rescale, data_out[negfreq]) -@register_transform(basis.ComplexFourier, 'scipy') +@register_transform(basis.ComplexFourier, 'scipy-numpy') class ScipyComplexFFT(ComplexFFT): """Complex-to-complex FFT using scipy.fft.""" @@ -299,7 +305,7 @@ def __init__(self, *args, rigor=None, **kw): super().__init__(*args, **kw) -@register_transform(basis.ComplexFourier, 'fftw') +@register_transform(basis.ComplexFourier, 'fftw-numpy') class FFTWComplexFFT(FFTWBase, ComplexFFT): """Complex-to-complex FFT using FFTW.""" From 68bbd219bdd3b2fbbab371a20f4d6a4868449894 Mon Sep 17 00:00:00 2001 From: "Keaton J. 
Burns" Date: Tue, 27 May 2025 15:52:28 +0100 Subject: [PATCH 07/29] Fix transform lookup --- dedalus/core/basis.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/dedalus/core/basis.py b/dedalus/core/basis.py index 7689eee5..9ee586ec 100644 --- a/dedalus/core/basis.py +++ b/dedalus/core/basis.py @@ -919,9 +919,9 @@ def transform_plan(self, dist, grid_size): xp_name = xp.__name__.split('.')[-1] # Shortcut trivial transforms if grid_size == 1 or self.size == 1: - return self.transforms[f"matrix-{xp_name}"](grid_size, self.size) + return self.transforms[f"matrix-{xp_name}"](grid_size, self.size, dist.array_namespace, dist.dtype) else: - return self.transforms[f"{self.library}-{xp_name}"](grid_size, self.size) + return self.transforms[f"{self.library}-{xp_name}"](grid_size, self.size, dist.array_namespace, dist.dtype) def forward_transform(self, field, axis, gdata, cdata): # Transform From dce5d99edcfc292960efe17cb1a1e3ff17b7909e Mon Sep 17 00:00:00 2001 From: "Keaton J. Burns" Date: Tue, 27 May 2025 16:16:32 +0100 Subject: [PATCH 08/29] Make fill_random array and dtype compatible --- dedalus/core/field.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/dedalus/core/field.py b/dedalus/core/field.py index 5f7cabfb..46f9e70f 100644 --- a/dedalus/core/field.py +++ b/dedalus/core/field.py @@ -915,6 +915,7 @@ def fill_random(self, layout=None, scales=None, seed=None, chunk_size=2**20, dis **kw : dict Other keywords passed to the distribution method. """ + xp = self.dist.array_namespace init_layout = self.layout # Set scales if requested if scales is not None: @@ -934,11 +935,10 @@ def fill_random(self, layout=None, scales=None, seed=None, chunk_size=2**20, dis spatial_slices = self.layout.slices(self.domain, self.scales) local_slices = component_slices + spatial_slices local_data = global_data[local_slices] - if self.is_real: - self.data[:] = local_data - else: - self.data.real[:] = local_data[..., 0] - self.data.imag[:] = local_data[..., 1] + if self.is_complex: + local_data = local_data[..., 0] + 1j * local_data[..., 1] + # Copy to field data + self.data[:] = xp.asarray(local_data, dtype=self.dtype) def low_pass_filter(self, shape=None, scales=None): """ From e189f4137a8940aadc25394f255a87b5a1b3ca36 Mon Sep 17 00:00:00 2001 From: "Keaton J. Burns" Date: Tue, 27 May 2025 16:16:55 +0100 Subject: [PATCH 09/29] Work on cupy real fourier MMTs --- dedalus/core/transforms.py | 40 ++++++++++++++++++++++---------------- 1 file changed, 23 insertions(+), 17 deletions(-) diff --git a/dedalus/core/transforms.py b/dedalus/core/transforms.py index db6f2595..066b4bb4 100644 --- a/dedalus/core/transforms.py +++ b/dedalus/core/transforms.py @@ -374,7 +374,7 @@ class RealFourierTransform(SeparableTransform): where the k = 0 minus-sine mode is zeroed in both directions. 
""" - def __init__(self, grid_size, coeff_size): + def __init__(self, grid_size, coeff_size, array_namespace, dtype): if coeff_size % 2 != 0: pass#raise ValueError("coeff_size must be even.") self.N = grid_size @@ -382,55 +382,61 @@ def __init__(self, grid_size, coeff_size): self.KN = (self.N - 1) // 2 self.KM = (self.M - 1) // 2 self.Kmax = min(self.KN, self.KM) + self.array_namespace = array_namespace + self.dtype = dtype @property def wavenumbers(self): """One-dimensional global wavenumber array.""" + xp = self.array_namespace # Repeat k's for cos and msin parts - return np.repeat(np.arange(self.KM+1), 2) + return xp.repeat(xp.arange(self.KM+1), 2) -@register_transform(basis.RealFourier, 'matrix') +@register_transform(basis.RealFourier, 'matrix-numpy') +@register_transform(basis.RealFourier, 'matrix-cupy') class RealFourierMMT(RealFourierTransform, SeparableMatrixTransform): """Real-to-real Fourier MMT.""" @CachedAttribute def forward_matrix(self): """Build forward transform matrix.""" + xp = self.array_namespace N = self.N M = max(2, self.M) # Account for sin and cos parts of m=0 Kmax = self.Kmax K = self.wavenumbers[::2, None] - X = np.arange(N)[None, :] + X = xp.arange(N)[None, :] dX = N / 2 / np.pi - quadrature = np.zeros((M, N)) - quadrature[0::2] = (2 / N) * np.cos(K*X/dX) - quadrature[1::2] = -(2 / N) * np.sin(K*X/dX) + quadrature = xp.zeros((M, N)) + quadrature[0::2] = (2 / N) * xp.cos(K*X/dX) + quadrature[1::2] = -(2 / N) * xp.sin(K*X/dX) quadrature[0] = 1 / N # Zero Nyquist and higher modes for transforms with grid_size <= coeff_size quadrature *= self.wavenumbers[:,None] <= self.Kmax # Ensure C ordering for fast dot products - return np.asarray(quadrature, order='C') + return xp.asarray(quadrature, order='C', dtype=self.dtype) @CachedAttribute def backward_matrix(self): """Build backward transform matrix.""" + xp = self.array_namespace N = self.N M = max(2, self.M) # Account for sin and cos parts of m=0 Kmax = self.Kmax K = self.wavenumbers[None, ::2] - X = np.arange(N)[:, None] + X = xp.arange(N)[:, None] dX = N / 2 / np.pi - functions = np.zeros((N, M)) - functions[:, 0::2] = np.cos(K*X/dX) - functions[:, 1::2] = -np.sin(K*X/dX) + functions = xp.zeros((N, M)) + functions[:, 0::2] = xp.cos(K*X/dX) + functions[:, 1::2] = -xp.sin(K*X/dX) # Zero Nyquist and higher modes for transforms with grid_size <= coeff_size functions *= self.wavenumbers[None, :] <= self.Kmax # Ensure C ordering for fast dot products - return np.asarray(functions, order='C') + return xp.asarray(functions, order='C', dtype=self.dtype) -@register_transform(basis.RealFourier, 'fftpack') +@register_transform(basis.RealFourier, 'fftpack-numpy') class FFTPACKRealFFT(RealFourierTransform): """Real-to-real FFT using scipy.fftpack.""" @@ -515,7 +521,7 @@ def repack_rescale(self, cdata, temp, axis, rescale): temp[axslice(axis, Kmax+1, None)] = 0 -@register_transform(basis.RealFourier, 'scipy') +@register_transform(basis.RealFourier, 'scipy-numpy') class ScipyRealFFT(RealFFT): """Real-to-real FFT using scipy.fft.""" @@ -540,7 +546,7 @@ def backward(self, cdata, gdata, axis): np.copyto(gdata, temp) -@register_transform(basis.RealFourier, 'fftw') +@register_transform(basis.RealFourier, 'fftw-numpy') class FFTWRealFFT(FFTWBase, RealFFT): """Real-to-real FFT using FFTW.""" @@ -571,7 +577,7 @@ def backward(self, cdata, gdata, axis): plan.backward(temp, gdata) -@register_transform(basis.RealFourier, 'fftw_hc') +@register_transform(basis.RealFourier, 'fftw_hc-numpy') class FFTWHalfComplexFFT(FFTWBase, RealFourierTransform): 
"""Real-to-real FFT using FFTW half-complex DFT.""" From cf8644d61859550cf04a9b889f3322edae43f892 Mon Sep 17 00:00:00 2001 From: "Keaton J. Burns" Date: Tue, 27 May 2025 16:26:02 +0100 Subject: [PATCH 10/29] Generalize Fourier basis for more dtypes --- dedalus/core/basis.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/dedalus/core/basis.py b/dedalus/core/basis.py index 9ee586ec..9fd2c264 100644 --- a/dedalus/core/basis.py +++ b/dedalus/core/basis.py @@ -14,7 +14,7 @@ from ..tools import clenshaw from ..tools.array import reshape_vector, axindex, axslice, interleave_matrices from ..tools.dispatch import MultiClass, SkipDispatchException -from ..tools.general import unify, DeferredTuple +from ..tools.general import unify, DeferredTuple, is_real_dtype, is_complex_dtype from .coords import Coordinate, CartesianCoordinates, S2Coordinates, SphericalCoordinates, PolarCoordinates, AzimuthalCoordinate, DirectProduct from .domain import Domain from .field import Operand, LockedField @@ -942,9 +942,9 @@ def Fourier(*args, dtype=None, **kw): """Factory function dispatching to RealFourier and ComplexFourier based on provided dtype.""" if dtype is None: raise ValueError("dtype must be specified") - elif dtype == np.float64: + elif is_real_dtype(dtype): return RealFourier(*args, **kw) - elif dtype == np.complex128: + elif is_complex_dtype(dtype): return ComplexFourier(*args, **kw) else: raise ValueError(f"Unrecognized dtype: {dtype}") From 2fb0d32907990f50477f786e5bc17b862f2b86a0 Mon Sep 17 00:00:00 2001 From: "Keaton J. Burns" Date: Wed, 28 May 2025 11:13:42 +0100 Subject: [PATCH 11/29] Add cupy complex FFT --- dedalus/core/transforms.py | 40 ++++++++++++++++++++++++++++++++------ 1 file changed, 34 insertions(+), 6 deletions(-) diff --git a/dedalus/core/transforms.py b/dedalus/core/transforms.py index 066b4bb4..4e2aa839 100644 --- a/dedalus/core/transforms.py +++ b/dedalus/core/transforms.py @@ -248,29 +248,30 @@ class ComplexFFT(ComplexFourierTransform): def resize_coeffs(self, data_in, data_out, axis, rescale): """Resize and rescale coefficients in standard FFT format by intermediate padding/truncation.""" + xp = self.array_namespace M = self.M Kmax = self.Kmax if Kmax == 0: posfreq = axslice(axis, 0, 1) badfreq = axslice(axis, 1, None) if rescale is None: - np.copyto(data_out[posfreq], data_in[posfreq]) + xp.copyto(data_out[posfreq], data_in[posfreq]) data_out[badfreq] = 0 else: - np.multiply(data_in[posfreq], rescale, data_out[posfreq]) + xp.multiply(data_in[posfreq], rescale, data_out[posfreq]) data_out[badfreq] = 0 else: posfreq = axslice(axis, 0, Kmax+1) badfreq = axslice(axis, Kmax+1, -Kmax) negfreq = axslice(axis, -Kmax, None) if rescale is None: - np.copyto(data_out[posfreq], data_in[posfreq]) + xp.copyto(data_out[posfreq], data_in[posfreq]) data_out[badfreq] = 0 - np.copyto(data_out[negfreq], data_in[negfreq]) + xp.copyto(data_out[negfreq], data_in[negfreq]) else: - np.multiply(data_in[posfreq], rescale, data_out[posfreq]) + xp.multiply(data_in[posfreq], rescale, data_out[posfreq]) data_out[badfreq] = 0 - np.multiply(data_in[negfreq], rescale, data_out[negfreq]) + xp.multiply(data_in[negfreq], rescale, data_out[negfreq]) @register_transform(basis.ComplexFourier, 'scipy-numpy') @@ -295,6 +296,33 @@ def backward(self, cdata, gdata, axis): np.copyto(gdata, temp) +@register_transform(basis.ComplexFourier, 'scipy-cupy') +class CupyComplexFFT(ComplexFFT): + """Complex-to-complex FFT using scipy.fft.""" + + def __init__(self, *args, **kw): + import cupyx.scipy.fft as 
cufft + self.cufft = cufft + super().__init__(*args, **kw) + + def forward(self, gdata, cdata, axis): + """Apply forward transform along specified axis.""" + # Call FFT + temp = self.cufft.fft(gdata, axis=axis) # Creates temporary + # Resize and rescale for unit-amplitude normalization + self.resize_coeffs(temp, cdata, axis, rescale=1/self.N) + + def backward(self, cdata, gdata, axis): + """Apply backward transform along specified axis.""" + # Resize and rescale for unit-amplitude normalization + # Need temporary to avoid overwriting problems + temp = np.empty_like(gdata) # Creates temporary + self.resize_coeffs(cdata, temp, axis, rescale=self.N) + # Call FFT + temp = self.cufft.ifft(temp, axis=axis, overwrite_x=True) # Creates temporary + np.copyto(gdata, temp) + + class FFTWBase: """Abstract base class for FFTW transforms.""" From 4fbe35b7028aabdb88cf97deef9e36c7f54389a9 Mon Sep 17 00:00:00 2001 From: "Keaton J. Burns" Date: Wed, 28 May 2025 14:35:10 +0100 Subject: [PATCH 12/29] Add cupy real fft --- dedalus/core/transforms.py | 64 +++++++++++++++++++++++++++++++------- dedalus/tools/general.py | 12 +++++++ 2 files changed, 64 insertions(+), 12 deletions(-) diff --git a/dedalus/core/transforms.py b/dedalus/core/transforms.py index 4e2aa839..5616c1c7 100644 --- a/dedalus/core/transforms.py +++ b/dedalus/core/transforms.py @@ -15,6 +15,7 @@ from ..tools.array import apply_matrix, apply_dense, axslice, solve_upper_sparse, apply_sparse from ..tools.cache import CachedAttribute from ..tools.cache import CachedMethod +from ..tools.general import float_to_complex import logging logger = logging.getLogger(__name__.split('.')[-1]) @@ -314,13 +315,14 @@ def forward(self, gdata, cdata, axis): def backward(self, cdata, gdata, axis): """Apply backward transform along specified axis.""" + xp = self.array_namespace # Resize and rescale for unit-amplitude normalization # Need temporary to avoid overwriting problems - temp = np.empty_like(gdata) # Creates temporary + temp = xp.empty_like(gdata) # Creates temporary self.resize_coeffs(cdata, temp, axis, rescale=self.N) # Call FFT temp = self.cufft.ifft(temp, axis=axis, overwrite_x=True) # Creates temporary - np.copyto(gdata, temp) + xp.copyto(gdata, temp) class FFTWBase: @@ -511,40 +513,42 @@ class RealFFT(RealFourierTransform): def unpack_rescale(self, temp, cdata, axis, rescale): """Unpack complex coefficients and rescale for unit-amplitude normalization.""" + xp = self.array_namespace Kmax = self.Kmax # Scale k = 0 cos data meancos = axslice(axis, 0, 1) - np.multiply(temp[meancos].real, rescale, cdata[meancos]) + xp.multiply(temp[meancos].real, rescale, cdata[meancos]) # Zero k = 0 msin data cdata[axslice(axis, 1, 2)] = 0 # Unpack and scale 1 < k <= Kmax data temp_posfreq = temp[axslice(axis, 1, Kmax+1)] cdata_posfreq_cos = cdata[axslice(axis, 2, 2*(Kmax+1), 2)] cdata_posfreq_msin = cdata[axslice(axis, 3, 2*(Kmax+1), 2)] - np.multiply(temp_posfreq.real, 2*rescale, cdata_posfreq_cos) - np.multiply(temp_posfreq.imag, 2*rescale, cdata_posfreq_msin) + xp.multiply(temp_posfreq.real, 2*rescale, cdata_posfreq_cos) + xp.multiply(temp_posfreq.imag, 2*rescale, cdata_posfreq_msin) # Zero k > Kmax data cdata[axslice(axis, 2*(Kmax+1), None)] = 0 def repack_rescale(self, cdata, temp, axis, rescale): """Repack into complex coefficients and rescale for unit-amplitude normalization.""" + xp = self.array_namespace Kmax = self.Kmax # Scale k = 0 data meancos = axslice(axis, 0, 1) if rescale is None: - np.copyto(temp[meancos], cdata[meancos]) + xp.copyto(temp[meancos], 
cdata[meancos]) else: - np.multiply(cdata[meancos], rescale, temp[meancos]) + xp.multiply(cdata[meancos], rescale, temp[meancos]) # Repack and scale 1 < k <= Kmax data temp_posfreq = temp[axslice(axis, 1, Kmax+1)] cdata_posfreq_cos = cdata[axslice(axis, 2, 2*(Kmax+1), 2)] cdata_posfreq_msin = cdata[axslice(axis, 3, 2*(Kmax+1), 2)] if rescale is None: - np.multiply(cdata_posfreq_cos, (1 / 2), temp_posfreq.real) - np.multiply(cdata_posfreq_msin, (1 / 2), temp_posfreq.imag) + xp.multiply(cdata_posfreq_cos, (1 / 2), temp_posfreq.real) + xp.multiply(cdata_posfreq_msin, (1 / 2), temp_posfreq.imag) else: - np.multiply(cdata_posfreq_cos, (rescale / 2), temp_posfreq.real) - np.multiply(cdata_posfreq_msin, (rescale / 2), temp_posfreq.imag) + xp.multiply(cdata_posfreq_cos, (rescale / 2), temp_posfreq.real) + xp.multiply(cdata_posfreq_msin, (rescale / 2), temp_posfreq.imag) # Zero k > Kmax data temp[axslice(axis, Kmax+1, None)] = 0 @@ -553,6 +557,10 @@ def repack_rescale(self, cdata, temp, axis, rescale): class ScipyRealFFT(RealFFT): """Real-to-real FFT using scipy.fft.""" + def __init__(self, *args, **kw): + super().__init__(*args, **kw) + self.complex_dtype = float_to_complex(self.dtype) + def forward(self, gdata, cdata, axis): """Apply forward transform along specified axis.""" # Call RFFT @@ -566,7 +574,7 @@ def backward(self, cdata, gdata, axis): # Rescale all modes and combine into complex form shape = list(gdata.shape) shape[axis] = N // 2 + 1 - temp = np.empty(shape=shape, dtype=np.complex128) # Creates temporary + temp = np.empty(shape=shape, dtype=self.complex_dtype) # Creates temporary # Repack into complex form and rescale self.repack_rescale(cdata, temp, axis, rescale=N) # Call IRFFT @@ -574,6 +582,38 @@ def backward(self, cdata, gdata, axis): np.copyto(gdata, temp) +@register_transform(basis.RealFourier, 'scipy-cupy') +class CupyRealFFT(RealFFT): + """Real-to-real FFT using scipy.fft.""" + + def __init__(self, *args, **kw): + import cupyx.scipy.fft as cufft + self.cufft = cufft + super().__init__(*args, **kw) + self.complex_dtype = float_to_complex(self.dtype) + + def forward(self, gdata, cdata, axis): + """Apply forward transform along specified axis.""" + # Call RFFT + temp = self.cufft.rfft(gdata, axis=axis) # Creates temporary + # Unpack from complex form and rescale + self.unpack_rescale(temp, cdata, axis, rescale=1/self.N) + + def backward(self, cdata, gdata, axis): + """Apply backward transform along specified axis.""" + xp = self.array_namespace + N = self.N + # Rescale all modes and combine into complex form + shape = list(gdata.shape) + shape[axis] = N // 2 + 1 + temp = xp.empty(shape=shape, dtype=self.complex_dtype) # Creates temporary + # Repack into complex form and rescale + self.repack_rescale(cdata, temp, axis, rescale=N) + # Call IRFFT + temp = self.cufft.irfft(temp, axis=axis, n=N, overwrite_x=True) # Creates temporary + xp.copyto(gdata, temp) + + @register_transform(basis.RealFourier, 'fftw-numpy') class FFTWRealFFT(FFTWBase, RealFFT): """Real-to-real FFT using FFTW.""" diff --git a/dedalus/tools/general.py b/dedalus/tools/general.py index 18eb5ee4..5e22f9b2 100644 --- a/dedalus/tools/general.py +++ b/dedalus/tools/general.py @@ -124,3 +124,15 @@ def is_complex_dtype(dtype): dtype = dtype.type return np.iscomplexobj(dtype()) + +def float_to_complex(dtype): + itemsize = np.dtype(dtype).itemsize + complex_dtype = np.dtype(f'complex{itemsize*2}') + return complex_dtype.type + + +def complex_to_float(dtype): + itemsize = np.dtype(dtype).itemsize + float_dtype = 
np.dtype(f'float{itemsize//2}') + return float_dtype.type + From 8c7985d2ce2589e5e376a9be1f8b9cc77b0f4be2 Mon Sep 17 00:00:00 2001 From: "Keaton J. Burns" Date: Wed, 28 May 2025 15:07:20 +0100 Subject: [PATCH 13/29] Fix dtype conversion --- dedalus/tools/general.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/dedalus/tools/general.py b/dedalus/tools/general.py index 5e22f9b2..9b8b5746 100644 --- a/dedalus/tools/general.py +++ b/dedalus/tools/general.py @@ -127,12 +127,12 @@ def is_complex_dtype(dtype): def float_to_complex(dtype): itemsize = np.dtype(dtype).itemsize - complex_dtype = np.dtype(f'complex{itemsize*2}') + complex_dtype = np.dtype(f'complex{16*itemsize}') return complex_dtype.type def complex_to_float(dtype): itemsize = np.dtype(dtype).itemsize - float_dtype = np.dtype(f'float{itemsize//2}') + float_dtype = np.dtype(f'float{4*itemsize}') return float_dtype.type From d6a45252a01fab7c160dd678030e77c97d2f1bc0 Mon Sep 17 00:00:00 2001 From: "Keaton J. Burns" Date: Wed, 28 May 2025 15:20:25 +0100 Subject: [PATCH 14/29] Add array compat for basic arithmetic --- dedalus/core/arithmetic.py | 9 ++++++--- dedalus/core/future.py | 1 + 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/dedalus/core/arithmetic.py b/dedalus/core/arithmetic.py index 55c70e53..d417b906 100644 --- a/dedalus/core/arithmetic.py +++ b/dedalus/core/arithmetic.py @@ -245,10 +245,11 @@ def choose_layout(self): def operate(self, out): """Perform operation.""" + xp = self.array_namespace arg0, arg1 = self.args # Set output layout out.preset_layout(arg0.layout) - np.add(arg0.data, arg1.data, out=out.data) + xp.add(arg0.data, arg1.data, out=out.data) # used for einsum string manipulation @@ -854,6 +855,7 @@ def __init__(self, arg0, arg1, out=None, **kw): def operate(self, out): """Perform operation.""" + xp = self.array_namespace arg0, arg1 = self.args # Set output layout out.preset_layout(arg0.layout) @@ -863,7 +865,7 @@ def operate(self, out): # Reshape arg data to broadcast properly for output tensorsig arg0_exp_data = arg0_data.reshape(self.arg0_exp_tshape + arg0_data.shape[len(arg0.tensorsig):]) arg1_exp_data = arg1_data.reshape(self.arg1_exp_tshape + arg1_data.shape[len(arg1.tensorsig):]) - np.multiply(arg0_exp_data, arg1_exp_data, out=out.data) + xp.multiply(arg0_exp_data, arg1_exp_data, out=out.data) class GhostBroadcaster: @@ -939,11 +941,12 @@ def enforce_conditions(self): def operate(self, out): """Perform operation.""" + xp = self.array_namespace arg0, arg1 = self.args # Set output layout out.preset_layout(arg1.layout) # Multiply argument data - np.multiply(arg0, arg1.data, out=out.data) + xp.multiply(arg0, arg1.data, out=out.data) def matrix_dependence(self, *vars): return self.args[1].matrix_dependence(*vars) diff --git a/dedalus/core/future.py b/dedalus/core/future.py index 58f9cd9d..ab07e8ff 100644 --- a/dedalus/core/future.py +++ b/dedalus/core/future.py @@ -51,6 +51,7 @@ def __init__(self, *args, out=None): self.original_args = tuple(args) self.out = out self.dist = unify_attributes(args, 'dist', require=False) + self.array_namespace = self.dist.array_namespace #self.domain = Domain(self.dist, self.bases) self._grid_layout = self.dist.grid_layout self._coeff_layout = self.dist.coeff_layout From 6d05ff0241027422e804fdc6f4c82dec1508bae0 Mon Sep 17 00:00:00 2001 From: "Keaton J. 
Burns" Date: Wed, 28 May 2025 15:30:15 +0100 Subject: [PATCH 15/29] Beginning adding array_compat to operators --- dedalus/core/operators.py | 37 +++++++++++++++++++++++++------------ 1 file changed, 25 insertions(+), 12 deletions(-) diff --git a/dedalus/core/operators.py b/dedalus/core/operators.py index 9a9a993d..cc6f0dad 100644 --- a/dedalus/core/operators.py +++ b/dedalus/core/operators.py @@ -378,11 +378,12 @@ def enforce_conditions(self): arg0.require_grid_space() def operate(self, out): + xp = self.array_namespace arg0, arg1 = self.args # Multiply in grid layout out.preset_layout(arg0.layout) if out.data.size: - np.power(arg0.data, arg1, out.data) + xp.power(arg0.data, arg1, out.data) def new_operands(self, arg0, arg1, **kw): return Power(arg0, arg1) @@ -498,8 +499,9 @@ def enforce_conditions(self): self.args[i].change_layout(self.layout) def operate(self, out): + xp = self.array_namespace out.preset_layout(self.layout) - np.copyto(out.data, self.func(*self.args, **self.kw)) + xp.copyto(out.data, self.func(*self.args, **self.kw)) class UnaryGridFunction(NonlinearOperator, FutureField): @@ -812,10 +814,11 @@ def enforce_conditions(self): def operate(self, out): """Perform operation.""" + xp = self.array_namespace arg0 = self.args[0] out.preset_layout(arg0.layout) out.lock_to_layouts(self.layouts) - np.copyto(out.data, arg0.data) + xp.copyto(out.data, arg0.data) def new_operand(self, operand, **kw): return Lock(operand, *self.layouts, **kw) @@ -1522,9 +1525,10 @@ def subproblem_matrix(self, subproblem): def operate(self, out): """Perform operation.""" + xp = self.array_namespace arg = self.args[0] out.preset_layout(arg.layout) - np.copyto(out.data, arg.data) + xp.copyto(out.data, arg.data) class Convert(SpectralOperator, metaclass=MultiClass): @@ -1624,12 +1628,13 @@ def subspace_matrix(self, layout): def operate(self, out): """Perform operation.""" + xp = self.array_namespace arg = self.args[0] layout = arg.layout # Copy for grid space if layout.grid_space[self.last_axis]: out.preset_layout(layout) - np.copyto(out.data, arg.data) + xp.copyto(out.data, arg.data) # Revert to matrix application for coeff space else: super().operate(out) @@ -1772,9 +1777,10 @@ def base(self): def operate(self, out): """Perform operation.""" + xp = self.array_namespace arg = self.args[0] out.preset_layout(arg.layout) - np.einsum('ii...', arg.data, out=out.data) + xp.einsum('ii...', arg.data, out=out.data) class SphericalTrace(Trace): @@ -1971,6 +1977,7 @@ def subproblem_matrix(self, subproblem): def operate(self, out): """Perform operation.""" + xp = self.array_namespace operand = self.args[0] # Set output layout out.preset_layout(operand.layout) @@ -3485,10 +3492,11 @@ def subproblem_matrix(self, subproblem): def operate(self, out): """Perform operation.""" # OPTIMIZE: this has an extra copy + xp = self.array_namespace arg0 = self.args[0] # Set output layout out.preset_layout(arg0.layout) - np.copyto(out.data, arg0.data) + xp.copyto(out.data, arg0.data) class DirectProductDivergence(Divergence): @@ -3534,10 +3542,11 @@ def subproblem_matrix(self, subproblem): def operate(self, out): """Perform operation.""" # OPTIMIZE: this has an extra copy + xp = self.array_namespace arg0 = self.args[0] # Set output layout out.preset_layout(arg0.layout) - np.copyto(out.data, arg0.data) + xp.copyto(out.data, arg0.data) class SphericalDivergence(Divergence, SphericalEllOperator): @@ -3739,10 +3748,11 @@ def enforce_conditions(self): def operate(self, out): """Perform operation.""" # OPTIMIZE: this has an extra copy + xp 
= self.array_namespace arg0 = self.args[0] # Set output layout out.preset_layout(arg0.layout) - np.copyto(out.data, arg0.data) + xp.copyto(out.data, arg0.data) class DirectProductCurl(Curl): @@ -3826,10 +3836,11 @@ def enforce_conditions(self): def operate(self, out): """Perform operation.""" # OPTIMIZE: this has an extra copy + xp = self.array_namespace arg0 = self.args[0] # Set output layout out.preset_layout(arg0.layout) - np.copyto(out.data, arg0.data) + xp.copyto(out.data, arg0.data) class SphericalCurl(Curl, SphericalEllOperator): @@ -4052,10 +4063,11 @@ def enforce_conditions(self): def operate(self, out): """Perform operation.""" # OPTIMIZE: this has an extra copy + xp = self.array_namespace arg0 = self.args[0] # Set output layout out.preset_layout(arg0.layout) - np.copyto(out.data, arg0.data) + xp.copyto(out.data, arg0.data) class DirectProductLaplacian(Laplacian): @@ -4097,10 +4109,11 @@ def enforce_conditions(self): def operate(self, out): """Perform operation.""" # OPTIMIZE: this has an extra copy + xp = self.array_namespace arg0 = self.args[0] # Set output layout out.preset_layout(arg0.layout) - np.copyto(out.data, arg0.data) + xp.copyto(out.data, arg0.data) class SphericalLaplacian(Laplacian, SphericalEllOperator): From 79d789c5766a367474d7f247fa45751fb32aca09 Mon Sep 17 00:00:00 2001 From: "Keaton J. Burns" Date: Tue, 22 Jul 2025 11:58:36 -0400 Subject: [PATCH 16/29] Quick implementation of apply_sparse for cupy --- dedalus/tools/array.py | 40 +++++++++++------- dedalus/tools/linalg_gpu.py | 84 +++++++++++++++++++++++++++++++++++++ 2 files changed, 109 insertions(+), 15 deletions(-) create mode 100644 dedalus/tools/linalg_gpu.py diff --git a/dedalus/tools/array.py b/dedalus/tools/array.py index ab9caf88..9749f66c 100644 --- a/dedalus/tools/array.py +++ b/dedalus/tools/array.py @@ -6,6 +6,8 @@ from scipy.sparse import _sparsetools from scipy.sparse import linalg as spla from math import prod +from ..tools import linalg_gpu +import array_api_compat from .config import config from . import linalg as cython_linalg @@ -173,14 +175,12 @@ def apply_sparse(matrix, array, axis, out=None, check_shapes=False, num_threads= Apply sparse matrix along any axis of an array. Must be out of place if ouptut is specified. 
""" - # Check matrix - if not isinstance(matrix, sparse.csr_matrix): - raise ValueError("Matrix must be in CSR format.") + xp = array_api_compat.array_namespace(array) # Check output if out is None: out_shape = list(array.shape) out_shape[axis] = matrix.shape[0] - out = np.empty(out_shape, dtype=array.dtype) + out = xp.empty(out_shape, dtype=array.dtype) elif out is array: raise ValueError("Cannot apply in place") # Check shapes @@ -189,17 +189,27 @@ def apply_sparse(matrix, array, axis, out=None, check_shapes=False, num_threads= raise ValueError("Axis out of bounds.") if matrix.shape[1] != array.shape[axis] or matrix.shape[0] != out.shape[axis]: raise ValueError("Matrix shape mismatch.") - # Old way if requested - if OLD_CSR_MATVECS and array.ndim == 2 and axis == 0: - out.fill(0) - return csr_matvecs(matrix, array, out) - # Promote datatypes - # TODO: find way to optimize this with fused types - matrix_data = matrix.data - if matrix_data.dtype != out.dtype: - matrix_data = matrix_data.astype(out.dtype) - # Call cython routine - cython_linalg.apply_csr(matrix.indptr, matrix.indices, matrix_data, array, out, axis, num_threads) + # Dispatch on array type + if array_api_compat.is_numpy_namespace(xp): + # Check matrix + if not isinstance(matrix, sparse.csr_matrix): + raise ValueError("Matrix must be in CSR format.") + # Old way if requested + if OLD_CSR_MATVECS and array.ndim == 2 and axis == 0: + out.fill(0) + return csr_matvecs(matrix, array, out) + # Promote datatypes + # TODO: find way to optimize this with fused types + matrix_data = matrix.data + if matrix_data.dtype != out.dtype: + matrix_data = matrix_data.astype(out.dtype) + # Call cython routine + cython_linalg.apply_csr(matrix.indptr, matrix.indices, matrix_data, array, out, axis, num_threads) + elif array_api_compat.is_cupy_namespace(xp): + # TODO: check matrix format here without import cupy + linalg_gpu.cupy_apply_csr(matrix, array, axis, out) + else: + raise ValueError("Unsupported array type") return out diff --git a/dedalus/tools/linalg_gpu.py b/dedalus/tools/linalg_gpu.py new file mode 100644 index 00000000..095eabaf --- /dev/null +++ b/dedalus/tools/linalg_gpu.py @@ -0,0 +1,84 @@ + +import numpy as np +try: + import cupy as cp + import cupyx.scipy.sparse as csp + HAVE_CUPY = True +except ImportError: + HAVE_CUPY = False + + +def cupy_apply_csr(matrix, array, axis, out): + """Apply CSR matrix to arbitrary axis of array.""" + if not HAVE_CUPY: + raise ImportError("cupy must be installed to use GPU linear algebra") + # Check matrix format + if not isinstance(matrix, csp.csr_matrix): + raise ValueError("Matrix must be in CSR format.") + # Switch by dimension + ndim = array.ndim + if ndim == 1: + if axis == 0: + out[:] = cupy_apply_csr_vec(matrix, array) + else: + raise ValueError("axis must be 0 for 1D arrays") + elif ndim == 2: + if axis == 0: + if array.shape[1] == 1: + out[:,0] = cupy_apply_csr_vec(matrix, array[:,0]) + else: + out[:] = cupy_apply_csr_first(matrix, array) + elif axis == 1: + if array.shape[0] == 1: + out[0,:] = cupy_apply_csr_vec(matrix, array[0,:]) + else: + out[:] = cupy_apply_csr_last(matrix, array) + else: + raise ValueError("axis must be 0 or 1 for 2D arrays") + else: + # Treat as 3D array with specified axis in the middle + # Compute equivalent shape (N1, N2, N3) + if ndim == 3 and axis == 1: + N1 = array.shape[0] + N2 = array.shape[1] + N3 = array.shape[2] + else: + N1 = int(np.prod(array.shape[:axis])) + N2 = array.shape[axis] + N3 = int(np.prod(array.shape[axis+1:])) + # Dispatch to cupy routines + 
if N1 == 1: + if N3 == 1: + # (1, N2, 1) -> (N2,) + x1 = array.reshape((N2,)) + temp = cupy_apply_csr_vec(matrix, x1) + out[:] = temp.reshape(out.shape) + else: + # (1, N2, N3) -> (N2, N3) + x2 = array.reshape((N2, N3)) + temp = cupy_apply_csr_first(matrix, x2) + out[:] = temp.reshape(out.shape) + else: + if N3 == 1: + # (N1, N2, 1) -> (N1, N2) + x2 = array.reshape((N1, N2)) + temp = cupy_apply_csr_last(matrix, x2) + out[:] = temp.reshape(out.shape) + else: + # (N1, N2, N3) + x3 = array.reshape((N1, N2, N3)) + y3 = out.reshape(((N1, matrix.shape[0], N3))) + for n1 in range(N1): + y3[n1] = cupy_apply_csr_first(matrix, x3[n1]) + + +def cupy_apply_csr_vec(matrix, vec): + return matrix.dot(vec) + +def cupy_apply_csr_first(matrix, array): + return matrix.dot(array) + +def cupy_apply_csr_last(matrix, array): + return matrix.dot(array.T).T + + From fb9b3d60c236301bbc3a6523ec4ac6bd69c3262d Mon Sep 17 00:00:00 2001 From: "Keaton J. Burns" Date: Tue, 22 Jul 2025 15:09:41 -0400 Subject: [PATCH 17/29] Make einsum in dot compatible with cupy --- dedalus/core/arithmetic.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/dedalus/core/arithmetic.py b/dedalus/core/arithmetic.py index d417b906..647c47a6 100644 --- a/dedalus/core/arithmetic.py +++ b/dedalus/core/arithmetic.py @@ -13,6 +13,7 @@ import numexpr as ne from collections import defaultdict from math import prod +import array_api_compat from .domain import Domain from .field import Operand, Field @@ -665,6 +666,7 @@ def GammaCoord(self, A_tensorsig, B_tensorsig, C_tensorsig): return G def operate(self, out): + xp = self.array_namespace arg0, arg1 = self.args out.preset_layout(arg0.layout) # Broadcast @@ -672,7 +674,11 @@ def operate(self, out): arg1_data = self.arg1_ghost_broadcaster.cast(arg1) # Call einsum if out.data.size: - np.einsum(self.einsum_str, arg0_data, arg1_data, out=out.data, optimize=True) + if array_api_compat.is_cupy_namespace(xp): + # Cupy does not support output keyword + out.data[:] = xp.einsum(self.einsum_str, arg0_data, arg1_data, optimize=True) + else: + xp.einsum(self.einsum_str, arg0_data, arg1_data, out=out.data, optimize=True) @alias("cross") From 1e29a80a5d01755896923ab47abdc6e8ad2b3f67 Mon Sep 17 00:00:00 2001 From: "Keaton J. 
Burns" Date: Tue, 22 Jul 2025 15:10:06 -0400 Subject: [PATCH 18/29] Add custom kernel for cupy csr middle dot product --- dedalus/tools/linalg_gpu.py | 75 ++++++++++++++++++++++++++++--------- 1 file changed, 58 insertions(+), 17 deletions(-) diff --git a/dedalus/tools/linalg_gpu.py b/dedalus/tools/linalg_gpu.py index 095eabaf..ad0b5098 100644 --- a/dedalus/tools/linalg_gpu.py +++ b/dedalus/tools/linalg_gpu.py @@ -1,5 +1,7 @@ +"""Linear algebra routines using cupy.""" import numpy as np +import math try: import cupy as cp import cupyx.scipy.sparse as csp @@ -14,25 +16,27 @@ def cupy_apply_csr(matrix, array, axis, out): raise ImportError("cupy must be installed to use GPU linear algebra") # Check matrix format if not isinstance(matrix, csp.csr_matrix): - raise ValueError("Matrix must be in CSR format.") + # TODO: avoid this explicit conversion + matrix = csp.csr_matrix(matrix) + #raise ValueError("Matrix must be in CSR format.") # Switch by dimension ndim = array.ndim if ndim == 1: if axis == 0: - out[:] = cupy_apply_csr_vec(matrix, array) + out[:] = matrix.dot(array) else: raise ValueError("axis must be 0 for 1D arrays") elif ndim == 2: if axis == 0: if array.shape[1] == 1: - out[:,0] = cupy_apply_csr_vec(matrix, array[:,0]) + out[:,0] = matrix.dot(array[:,0]) else: - out[:] = cupy_apply_csr_first(matrix, array) + out[:] = matrix.dot(array) elif axis == 1: if array.shape[0] == 1: - out[0,:] = cupy_apply_csr_vec(matrix, array[0,:]) + out[0,:] = matrix.dot(array[0,:]) else: - out[:] = cupy_apply_csr_last(matrix, array) + out[:] = matrix.dot(array.T).T else: raise ValueError("axis must be 0 or 1 for 2D arrays") else: @@ -51,34 +55,71 @@ def cupy_apply_csr(matrix, array, axis, out): if N3 == 1: # (1, N2, 1) -> (N2,) x1 = array.reshape((N2,)) - temp = cupy_apply_csr_vec(matrix, x1) + temp = matrix.dot(x1) out[:] = temp.reshape(out.shape) else: # (1, N2, N3) -> (N2, N3) x2 = array.reshape((N2, N3)) - temp = cupy_apply_csr_first(matrix, x2) + temp = matrix.dot(x2) out[:] = temp.reshape(out.shape) else: if N3 == 1: # (N1, N2, 1) -> (N1, N2) x2 = array.reshape((N1, N2)) - temp = cupy_apply_csr_last(matrix, x2) + temp = matrix.dot(x2.T).T out[:] = temp.reshape(out.shape) else: # (N1, N2, N3) x3 = array.reshape((N1, N2, N3)) y3 = out.reshape(((N1, matrix.shape[0], N3))) - for n1 in range(N1): - y3[n1] = cupy_apply_csr_first(matrix, x3[n1]) + cupy_apply_csr_mid(matrix, x3, y3) -def cupy_apply_csr_vec(matrix, vec): - return matrix.dot(vec) +# Kernel for applying CSR matrix with parallelization over n1 and n3 +apply_csr_mid_kernel = cp.RawKernel( + r''' + extern "C" __global__ void apply_csr_mid_kernel( + const float* data, // CSR data of shape (nnz,) + const int* indices, // CSR column indices (nnz,) + const int* indptr, // CSR row pointers (N2o + 1,) + const float* input, // shape (N1, N2i, N3) + float* output, // shape (N1, N2o, N3) + int N1, int N2i, int N2o, int N3) + { + int n1 = blockIdx.x * blockDim.x + threadIdx.x ; // batch index + int n3 = blockIdx.y * blockDim.y + threadIdx.y; // output column index -def cupy_apply_csr_first(matrix, array): - return matrix.dot(array) + if (n1 >= N1 || n3 >= N3) return; -def cupy_apply_csr_last(matrix, array): - return matrix.dot(array.T).T + // Loop over output rows = CSR matrix rows + for (int i = 0; i < N2o; ++i) { + float acc = 0.0f; + int start = indptr[i]; + int end = indptr[i + 1]; + for (int k = start; k < end; ++k) { + int j = indices[k]; // input column + float val = data[k]; + acc += val * input[n1 * N2i * N3 + j * N3 + n3]; + } + + output[n1 * N2o 
* N3 + i * N3 + n3] = acc; + } + } + ''', + 'apply_csr_mid_kernel') + + +def cupy_apply_csr_mid(matrix, array, out): + N1, N2i, N3 = array.shape + N2o = matrix.shape[0] + # Choose thread/block config + threads_y = min(1024, N3) # maximize concurrency along n3 + threads_x = 1024 // threads_y # make block have 1024 threads + blockdim = (threads_x, threads_y) + blocks_x = (N1 + threads_x - 1) // threads_x + blocks_y = (N3 + threads_y - 1) // threads_y + griddim = (blocks_x, blocks_y) + # Launch kernel + apply_csr_mid_kernel(griddim, blockdim, (matrix.data, matrix.indices, matrix.indptr, array, out, N1, N2i, N2o, N3)) From d240656372b72067492786c20506807d36008ac9 Mon Sep 17 00:00:00 2001 From: "Keaton J. Burns" Date: Tue, 22 Jul 2025 15:47:23 -0400 Subject: [PATCH 19/29] Convert local grids/modes to device arrays --- dedalus/core/distributor.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/dedalus/core/distributor.py b/dedalus/core/distributor.py index 9b2a3ff7..99ada2e9 100644 --- a/dedalus/core/distributor.py +++ b/dedalus/core/distributor.py @@ -261,11 +261,12 @@ def IdentityTensor(self, coordsys_in, coordsys_out=None, bases=None, dtype=None) return I def local_grid(self, basis, scale=None): + xp = self.array_namespace # TODO: remove from bases and do it all here? if scale is None: scale = 1 if basis.dim == 1: - return basis.local_grid(self, scale=scale) + return xp.asarray(basis.local_grid(self, scale=scale)) else: raise ValueError("Use `local_grids` for multidimensional bases.") @@ -298,16 +299,18 @@ def local_grid(self, basis, scale=None): # return tuple(grids) def local_grids(self, *bases, scales=None): + xp = self.array_namespace scales = self.remedy_scales(scales) grids = [] for basis in bases: basis_scales = scales[self.first_axis(basis):self.last_axis(basis)+1] - grids.extend(basis.local_grids(self, scales=basis_scales)) + grids.extend(xp.asarray(basis.local_grids(self, scales=basis_scales))) return grids def local_modes(self, basis): # TODO: remove from bases and do it all here? - return basis.local_modes(self) + xp = self.array_namespace + return xp.asarray(basis.local_modes(self)) @CachedAttribute def default_nonconst_groups(self): From 644f3bf8c8ebb7dd2bab718397fe34107db8ea14 Mon Sep 17 00:00:00 2001 From: "Keaton J. Burns" Date: Tue, 22 Jul 2025 15:55:21 -0400 Subject: [PATCH 20/29] Explicitly cast data norms to float --- dedalus/core/field.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dedalus/core/field.py b/dedalus/core/field.py index 46f9e70f..09be49e4 100644 --- a/dedalus/core/field.py +++ b/dedalus/core/field.py @@ -834,7 +834,7 @@ def allreduce_data_norm(self, layout=None, order=2): if self.dist.comm.size > 1: norm = self.dist.comm.allreduce(norm, op=MPI.SUM) norm = norm ** (1 / order) - return norm + return float(norm) def allreduce_data_max(self, layout=None): return self.allreduce_data_norm(layout=layout, order=np.inf) From ef9091b98f71a2dcc7b3384c2c161c46e80ef337 Mon Sep 17 00:00:00 2001 From: "Keaton J. 
Burns" Date: Tue, 22 Jul 2025 15:56:10 -0400 Subject: [PATCH 21/29] Cast grid spacing to device array in cartesian cfl --- dedalus/core/basis.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/dedalus/core/basis.py b/dedalus/core/basis.py index 9fd2c264..8619b455 100644 --- a/dedalus/core/basis.py +++ b/dedalus/core/basis.py @@ -6083,6 +6083,7 @@ class CartesianAdvectiveCFL(operators.AdvectiveCFL): @CachedMethod def cfl_spacing(self): + xp = self.array_namespace velocity = self.operand coordsys = velocity.tensorsig[0] spacing = [] @@ -6104,7 +6105,7 @@ def cfl_spacing(self): axis_spacing[:] = dealias * native_spacing * basis.COV.stretch elif basis is None: axis_spacing = np.inf - spacing.append(axis_spacing) + spacing.append(xp.asarray(axis_spacing)) return spacing def compute_cfl_frequency(self, velocity, out): From 426cad746e9960b3e77703a784add6e61ff94b62 Mon Sep 17 00:00:00 2001 From: "Keaton J. Burns" Date: Tue, 22 Jul 2025 16:02:30 -0400 Subject: [PATCH 22/29] Convert field data gathers to numpy on gpu --- dedalus/core/field.py | 20 ++++++++++++++++---- 1 file changed, 16 insertions(+), 4 deletions(-) diff --git a/dedalus/core/field.py b/dedalus/core/field.py index 09be49e4..93330740 100644 --- a/dedalus/core/field.py +++ b/dedalus/core/field.py @@ -782,9 +782,15 @@ def allgather_data(self, layout=None): # Change layout if layout is not None: self.change_layout(layout) + # Convert to numpy if on GPU + xp = self.dist.array_namespace + if array_api_compat.is_cupy_namespace(xp): + data = xp.asnumpy(self.data) + else: + data = self.data.copy() # Shortcut for serial execution if self.dist.comm.size == 1: - return self.data.copy() + return data # Build global buffers tensor_shape = tuple(cs.dim for cs in self.tensorsig) global_shape = tensor_shape + self.layout.global_shape(self.domain, self.scales) @@ -793,7 +799,7 @@ def allgather_data(self, layout=None): recv_buff = np.empty_like(send_buff) # Combine data via allreduce -- easy but not communication-optimal # Should be optimized using Allgatherv if this is used past startup - send_buff[local_slices] = self.data + send_buff[local_slices] = data self.dist.comm.Allreduce(send_buff, recv_buff, op=MPI.SUM) return recv_buff @@ -801,13 +807,19 @@ def gather_data(self, root=0, layout=None): # Change layout if layout is not None: self.change_layout(layout) + # Convert to numpy if on GPU + xp = self.dist.array_namespace + if array_api_compat.is_cupy_namespace(xp): + data = xp.asnumpy(self.data) + else: + data = self.data.copy() # Shortcut for serial execution if self.dist.comm.size == 1: - return self.data.copy() + return data # TODO: Shortcut this for constant fields # Gather data # Should be optimized via Gatherv eventually - pieces = self.dist.comm.gather(self.data, root=root) + pieces = self.dist.comm.gather(data, root=root) # Assemble on root node if self.dist.comm.rank == root: ext_mesh = self.layout.ext_mesh From c9f5bda026f87c3bd5b2f56625e10f3bbec7a191 Mon Sep 17 00:00:00 2001 From: "Keaton J. 
Burns" Date: Tue, 22 Jul 2025 16:24:55 -0400 Subject: [PATCH 23/29] Fix subsystem gather/scatter to copy to/from gpu --- dedalus/core/subsystems.py | 10 +++++----- dedalus/tools/array.py | 14 ++++++++++++++ 2 files changed, 19 insertions(+), 5 deletions(-) diff --git a/dedalus/core/subsystems.py b/dedalus/core/subsystems.py index 191d63e0..29f4bf4b 100644 --- a/dedalus/core/subsystems.py +++ b/dedalus/core/subsystems.py @@ -13,7 +13,7 @@ from math import prod from .domain import Domain -from ..tools.array import zeros_with_pattern, expand_pattern, sparse_block_diag, copyto, perm_matrix, drop_empty_rows, apply_sparse, assert_sparse_pinv +from ..tools.array import zeros_with_pattern, expand_pattern, sparse_block_diag, copyto, perm_matrix, drop_empty_rows, apply_sparse, assert_sparse_pinv, copy_to_device, copy_from_device from ..tools.cache import CachedAttribute, CachedMethod from ..tools.general import replace, OrderedSet from ..tools.progress import log_progress @@ -342,7 +342,7 @@ def gather_inputs(self, fields, out=None): # Gather from fields views = self._input_field_views(tuple(fields)) for buffer_view, field_view in views: - np.copyto(buffer_view, field_view) + copy_from_device(buffer_view, field_view) # Apply right preconditioner inverse to compress inputs if out is None: out = self._compressed_buffer @@ -354,7 +354,7 @@ def gather_outputs(self, fields, out=None): # Gather from fields views = self._output_field_views(tuple(fields)) for buffer_view, field_view in views: - np.copyto(buffer_view, field_view) + copy_from_device(buffer_view, field_view) # Apply left preconditioner to compress outputs if out is None: out = self._compressed_buffer @@ -368,7 +368,7 @@ def scatter_inputs(self, data, fields): # Scatter to fields views = self._input_field_views(tuple(fields)) for buffer_view, field_view in views: - np.copyto(field_view, buffer_view) + copy_to_device(field_view, buffer_view) def scatter_outputs(self, data, fields): """Precondition and scatter subproblem data out to output-like field list.""" @@ -377,7 +377,7 @@ def scatter_outputs(self, data, fields): # Scatter to fields views = self._output_field_views(tuple(fields)) for buffer_view, field_view in views: - np.copyto(field_view, buffer_view) + copy_to_device(field_view, buffer_view) def inclusion_matrices(self, bases): """List of inclusion matrices.""" diff --git a/dedalus/tools/array.py b/dedalus/tools/array.py index 9749f66c..399b4950 100644 --- a/dedalus/tools/array.py +++ b/dedalus/tools/array.py @@ -363,6 +363,20 @@ def copyto(dest, src): dest[:] = src +def copy_to_device(dest, src): + if array_api_compat.is_cupy_array(dest): + dest.set(src) + else: + dest[:] = src + + +def copy_from_device(dest, src): + if array_api_compat.is_cupy_array(src): + src.get(out=dest) + else: + dest[:] = src + + def perm_matrix(perm, M=None, source_index=False, sparse=True): """ Build sparse permutation matrix from permutation vector. From 63f403380d0bd41eb04193e58e1ae5ad16c5ba26 Mon Sep 17 00:00:00 2001 From: "Keaton J. 
Burns" Date: Tue, 22 Jul 2025 17:17:45 -0400 Subject: [PATCH 24/29] Allow for non-contiguous device copy --- dedalus/tools/array.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/dedalus/tools/array.py b/dedalus/tools/array.py index 399b4950..bfef6a95 100644 --- a/dedalus/tools/array.py +++ b/dedalus/tools/array.py @@ -364,8 +364,10 @@ def copyto(dest, src): def copy_to_device(dest, src): - if array_api_compat.is_cupy_array(dest): - dest.set(src) + xp = array_api_compat.array_namespace(dest) + if array_api_compat.is_cupy_namespace(xp): + src = xp.asarray(src) + dest[:] = src else: dest[:] = src From 68e2cb2caa078eb68596bbf52a6c56fff6de7a37 Mon Sep 17 00:00:00 2001 From: "Keaton J. Burns" Date: Tue, 22 Jul 2025 17:18:20 -0400 Subject: [PATCH 25/29] Fix cupy csr kernel for double instead of float --- dedalus/tools/linalg_gpu.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/dedalus/tools/linalg_gpu.py b/dedalus/tools/linalg_gpu.py index ad0b5098..c4e1cde9 100644 --- a/dedalus/tools/linalg_gpu.py +++ b/dedalus/tools/linalg_gpu.py @@ -79,11 +79,11 @@ def cupy_apply_csr(matrix, array, axis, out): apply_csr_mid_kernel = cp.RawKernel( r''' extern "C" __global__ void apply_csr_mid_kernel( - const float* data, // CSR data of shape (nnz,) + const double* data, // CSR data of shape (nnz,) const int* indices, // CSR column indices (nnz,) const int* indptr, // CSR row pointers (N2o + 1,) - const float* input, // shape (N1, N2i, N3) - float* output, // shape (N1, N2o, N3) + const double* input, // shape (N1, N2i, N3) + double* output, // shape (N1, N2o, N3) int N1, int N2i, int N2o, int N3) { int n1 = blockIdx.x * blockDim.x + threadIdx.x ; // batch index @@ -93,13 +93,13 @@ def cupy_apply_csr(matrix, array, axis, out): // Loop over output rows = CSR matrix rows for (int i = 0; i < N2o; ++i) { - float acc = 0.0f; + double acc = 0; int start = indptr[i]; int end = indptr[i + 1]; for (int k = start; k < end; ++k) { int j = indices[k]; // input column - float val = data[k]; + double val = data[k]; acc += val * input[n1 * N2i * N3 + j * N3 + n3]; } From 9421231e9f3bcca95425c28ecd7d3f157bf21387 Mon Sep 17 00:00:00 2001 From: "Keaton J. 
Burns" Date: Fri, 25 Jul 2025 12:00:46 -0400 Subject: [PATCH 26/29] Move subsystems, coeff systems, and matrices to GPU --- dedalus/core/subsystems.py | 67 ++++++++++++++++++++++++--------- dedalus/core/system.py | 31 ++++++++------- dedalus/core/timesteppers.py | 36 +++++++++--------- dedalus/libraries/matsolvers.py | 21 ++++++++++- dedalus/tools/array.py | 10 +++++ dedalus/tools/linalg_gpu.py | 6 +-- 6 files changed, 117 insertions(+), 54 deletions(-) diff --git a/dedalus/core/subsystems.py b/dedalus/core/subsystems.py index 29f4bf4b..bd0a6c0b 100644 --- a/dedalus/core/subsystems.py +++ b/dedalus/core/subsystems.py @@ -11,6 +11,7 @@ from mpi4py import MPI import uuid from math import prod +import array_api_compat from .domain import Domain from ..tools.array import zeros_with_pattern, expand_pattern, sparse_block_diag, copyto, perm_matrix, drop_empty_rows, apply_sparse, assert_sparse_pinv, copy_to_device, copy_from_device @@ -18,6 +19,12 @@ from ..tools.general import replace, OrderedSet from ..tools.progress import log_progress +try: + import cupy as cp + import cupyx.scipy.sparse as csp +except ImportError: + pass + import logging logger = logging.getLogger(__name__.split('.')[-1]) @@ -118,6 +125,7 @@ def __init__(self, solver, group): self.solver = solver self.problem = problem = solver.problem self.dist = solver.dist + self.array_namespace = solver.dist.array_namespace self.dtype = problem.dtype self.group = group # Determine matrix group using solver matrix dependence @@ -191,11 +199,12 @@ def field_size(self, field): @CachedMethod def _gather_scatter_setup(self, fields): + xp = self.array_namespace # Allocate vector fsizes = tuple(self.field_size(f) for f in fields) fslices = tuple(self.field_slices(f) for f in fields) fshapes = tuple(self.field_shape(f) for f in fields) - data = np.empty(sum(fsizes), dtype=self.dtype) + data = xp.empty(sum(fsizes), dtype=self.dtype) # Make views into data fviews = [] i0 = 0 @@ -248,6 +257,7 @@ def __init__(self, solver, subsystems, group): self.subsystems = subsystems self.group = group self.dist = problem.dist + self.array_namespace = self.dist.array_namespace self.domain = problem.variables[0].domain # HACK self.dtype = problem.dtype # Cross reference from subsystems @@ -279,7 +289,8 @@ def size(self): @CachedAttribute def _compressed_buffer(self): - return np.zeros(self.shape, dtype=self.dtype) + xp = self.array_namespace + return xp.zeros(self.shape, dtype=self.dtype) def coeff_slices(self, domain): return self.subsystems[0].coeff_slices(domain) @@ -300,9 +311,10 @@ def field_size(self, field): return self.subsystems[0].field_size(field) def _build_buffer_views(self, fields): + xp = self.array_namespace # Allocate buffer fsizes = tuple(self.field_size(f) for f in fields) - buffer = np.zeros((sum(fsizes), len(self.subsystems)), dtype=self.dtype) + buffer = xp.zeros((sum(fsizes), len(self.subsystems)), dtype=self.dtype) # Make views into buffer views = [] i0 = 0 @@ -342,7 +354,7 @@ def gather_inputs(self, fields, out=None): # Gather from fields views = self._input_field_views(tuple(fields)) for buffer_view, field_view in views: - copy_from_device(buffer_view, field_view) + copyto(buffer_view, field_view) # Apply right preconditioner inverse to compress inputs if out is None: out = self._compressed_buffer @@ -354,7 +366,7 @@ def gather_outputs(self, fields, out=None): # Gather from fields views = self._output_field_views(tuple(fields)) for buffer_view, field_view in views: - copy_from_device(buffer_view, field_view) + copyto(buffer_view, 
field_view) # Apply left preconditioner to compress outputs if out is None: out = self._compressed_buffer @@ -368,7 +380,7 @@ def scatter_inputs(self, data, fields): # Scatter to fields views = self._input_field_views(tuple(fields)) for buffer_view, field_view in views: - copy_to_device(field_view, buffer_view) + copyto(field_view, buffer_view) def scatter_outputs(self, data, fields): """Precondition and scatter subproblem data out to output-like field list.""" @@ -377,7 +389,7 @@ def scatter_outputs(self, data, fields): # Scatter to fields views = self._output_field_views(tuple(fields)) for buffer_view, field_view in views: - copy_to_device(field_view, buffer_view) + copyto(field_view, buffer_view) def inclusion_matrices(self, bases): """List of inclusion matrices.""" @@ -555,24 +567,45 @@ def build_matrices(self, names): left_perm = left_permutation(self, eqns, bc_top=solver.bc_top, interleave_components=solver.interleave_components).tocsr() right_perm = right_permutation(self, vars, tau_left=solver.tau_left, interleave_components=solver.interleave_components).tocsr() - # Preconditioners + # Preconditioners on CPU # TODO: remove astype casting, requires dealing with used types in apply_sparse - self.pre_left = drop_empty_rows(left_perm @ valid_eqn).tocsr().astype(dtype) - self.pre_left_pinv = self.pre_left.T.tocsr().astype(dtype) - self.pre_right_pinv = drop_empty_rows(right_perm @ valid_var).tocsr().astype(dtype) - self.pre_right = self.pre_right_pinv.T.tocsr().astype(dtype) + pre_left = drop_empty_rows(left_perm @ valid_eqn).tocsr().astype(dtype) + pre_left_pinv = pre_left.T.tocsr().astype(dtype) + pre_right_pinv = drop_empty_rows(right_perm @ valid_var).tocsr().astype(dtype) + pre_right = pre_right_pinv.T.tocsr().astype(dtype) # Check preconditioner pseudoinverses - assert_sparse_pinv(self.pre_left, self.pre_left_pinv) - assert_sparse_pinv(self.pre_right, self.pre_right_pinv) + assert_sparse_pinv(pre_left, pre_left_pinv) + assert_sparse_pinv(pre_right, pre_right_pinv) # Precondition matrices for name in matrices: - matrices[name] = self.pre_left @ matrices[name] @ self.pre_right + matrices[name] = pre_left @ matrices[name] @ pre_right - # Store minimal CSR matrices for fast dot products + # Store minimal CSR matrices on CPU for name, matrix in matrices.items(): - setattr(self, '{:}_min'.format(name), matrix.tocsr()) + setattr(self, f'{name}_min', matrix.tocsr()) + + # Store device copies for fast dot products + xp = solver.dist.array_namespace + if array_api_compat.is_numpy_namespace(xp): + self.pre_left = pre_left + self.pre_left_pinv = pre_left_pinv + self.pre_right_pinv = pre_right_pinv + self.pre_right = pre_right + # Reference current CPU matrices + for name, matrix in matrices.items(): + setattr(self, f'{name}_min_device', getattr(self, f'{name}_min')) + elif array_api_compat.is_cupy_namespace(xp): + # Copy to device + self.pre_left = csp.csr_matrix(pre_left) + self.pre_left_pinv = csp.csr_matrix(pre_left_pinv) + self.pre_right_pinv = csp.csr_matrix(pre_right_pinv) + self.pre_right = csp.csr_matrix(pre_right) + for name, matrix in matrices.items(): + setattr(self, f'{name}_min_device', csp.csr_matrix(matrix)) + else: + raise ValueError("Unsupported array namespace: {}".format(xp)) # Store expanded CSR matrices for fast recombination if len(matrices) > 1: diff --git a/dedalus/core/system.py b/dedalus/core/system.py index 23cbb86b..f28cb206 100644 --- a/dedalus/core/system.py +++ b/dedalus/core/system.py @@ -12,45 +12,44 @@ class CoeffSystem: """ - Representation of a collection 
of fields that don't need to be transformed, - and are therefore stored as a contigous set of coefficient data for - efficient pencil and group manipulation. + Contiguous buffer for data from all subproblems. Parameters ---------- - nfields : int - Number of fields to represent - domain : domain object - Problem domain + subproblems : list of Subproblem objects + Subproblems to represent + dtype : dtype + Data type + array_namespace : array namespace + Array namespace Attributes ---------- data : ndarray - Contiguous buffer for field coefficients - - """ - - """ - var buffer - + Contiguous buffer for data from all subproblems + views : dict + Nested dictionary of views for each subproblem and subsystem """ - def __init__(self, subproblems, dtype): + def __init__(self, subproblems, dtype, array_namespace): + xp = array_namespace # Build buffer total_size = sum(sp.LHS.shape[1]*len(sp.subsystems) for sp in subproblems) - self.data = np.zeros(total_size, dtype=dtype) + self.data = xp.zeros(total_size, dtype=dtype) # Build views i0 = i1 = 0 self.views = views = {} for sp in subproblems: views[sp] = views_sp = {} + # View for each individual subsystem i00 = i0 for ss in sp.subsystems: i1 += sp.LHS.shape[1] views_sp[ss] = self.data[i0:i1] i0 = i1 i11 = i1 + # View combining all subsystems as rows in a matrix if i11 - i00 > 0: views_sp[None] = self.data[i00:i11].reshape((sp.LHS.shape[1], -1)) else: diff --git a/dedalus/core/timesteppers.py b/dedalus/core/timesteppers.py index 81da4c10..162a2d32 100644 --- a/dedalus/core/timesteppers.py +++ b/dedalus/core/timesteppers.py @@ -2,10 +2,9 @@ from collections import deque, OrderedDict import numpy as np -from scipy.linalg import blas from .system import CoeffSystem -from ..tools.array import apply_sparse +from ..tools.array import apply_sparse, get_axpy # Public interface @@ -71,7 +70,8 @@ class MultistepIMEX: def __init__(self, solver): self.solver = solver - self.RHS = CoeffSystem(solver.subproblems, dtype=solver.dtype) + xp = solver.dist.array_namespace + self.RHS = CoeffSystem(solver.subproblems, dtype=solver.dtype, array_namespace=xp) # Create deque for storing recent timesteps self.dt = deque([0.] 
* self.steps) @@ -81,16 +81,16 @@ def __init__(self, solver): self.LX = LX = deque() self.F = F = deque() for j in range(self.amax): - MX.append(CoeffSystem(solver.subproblems, dtype=solver.dtype)) + MX.append(CoeffSystem(solver.subproblems, dtype=solver.dtype, array_namespace=xp)) for j in range(self.bmax): - LX.append(CoeffSystem(solver.subproblems, dtype=solver.dtype)) + LX.append(CoeffSystem(solver.subproblems, dtype=solver.dtype, array_namespace=xp)) for j in range(self.cmax): - F.append(CoeffSystem(solver.subproblems, dtype=solver.dtype)) + F.append(CoeffSystem(solver.subproblems, dtype=solver.dtype, array_namespace=xp)) # Attributes self._iteration = 0 self._LHS_params = None - self.axpy = blas.get_blas_funcs('axpy', dtype=solver.dtype) + self.axpy = get_axpy(xp, solver.dtype) def step(self, dt, wall_time): """Advance solver by one timestep.""" @@ -143,8 +143,8 @@ def step(self, dt, wall_time): evaluator.require_coeff_space(state_fields) for sp in subproblems: spX = sp.gather_inputs(state_fields) - apply_sparse(sp.M_min, spX, axis=0, out=MX0.get_subdata(sp)) - apply_sparse(sp.L_min, spX, axis=0, out=LX0.get_subdata(sp)) + apply_sparse(sp.M_min_device, spX, axis=0, out=MX0.get_subdata(sp)) + apply_sparse(sp.L_min_device, spX, axis=0, out=LX0.get_subdata(sp)) # Evaluate F(X0) evaluator.evaluate_scheduled(iteration=iteration, wall_time=wall_time, sim_time=sim_time, timestep=dt) @@ -539,15 +539,16 @@ class RungeKuttaIMEX: def __init__(self, solver): self.solver = solver - self.RHS = CoeffSystem(solver.subproblems, dtype=solver.dtype) + xp = solver.dist.array_namespace + self.RHS = CoeffSystem(solver.subproblems, dtype=solver.dtype, array_namespace=xp) # Create coefficient systems for multistep history - self.MX0 = CoeffSystem(solver.subproblems, dtype=solver.dtype) - self.LX = [CoeffSystem(solver.subproblems, dtype=solver.dtype) for i in range(self.stages)] - self.F = [CoeffSystem(solver.subproblems, dtype=solver.dtype) for i in range(self.stages)] + self.MX0 = CoeffSystem(solver.subproblems, dtype=solver.dtype, array_namespace=xp) + self.LX = [CoeffSystem(solver.subproblems, dtype=solver.dtype, array_namespace=xp) for i in range(self.stages)] + self.F = [CoeffSystem(solver.subproblems, dtype=solver.dtype, array_namespace=xp) for i in range(self.stages)] self._LHS_params = None - self.axpy = blas.get_blas_funcs('axpy', dtype=solver.dtype) + self.axpy = get_axpy(xp, solver.dtype) def step(self, dt, wall_time): """Advance solver by one timestep.""" @@ -584,11 +585,12 @@ def step(self, dt, wall_time): # Compute M.X(n,0) and L.X(n,0) # Ensure coeff space before subsystem gathers + # TODO: add option to evaluate this matrix-free (e.g for high-bandwidth NCCs when using fast transforms) evaluator.require_coeff_space(state_fields) for sp in subproblems: spX = sp.gather_inputs(state_fields) - apply_sparse(sp.M_min, spX, axis=0, out=MX0.get_subdata(sp)) - apply_sparse(sp.L_min, spX, axis=0, out=LX0.get_subdata(sp)) + apply_sparse(sp.M_min_device, spX, axis=0, out=MX0.get_subdata(sp)) + apply_sparse(sp.L_min_device, spX, axis=0, out=LX0.get_subdata(sp)) # Compute stages # (M + k Hii L).X(n,i) = M.X(n,0) + k Aij F(n,j) - k Hij L.X(n,j) @@ -601,7 +603,7 @@ def step(self, dt, wall_time): evaluator.require_coeff_space(state_fields) for sp in subproblems: spX = sp.gather_inputs(state_fields) - apply_sparse(sp.L_min, spX, axis=0, out=LXi.get_subdata(sp)) + apply_sparse(sp.L_min_device, spX, axis=0, out=LXi.get_subdata(sp)) # Compute F(n,i-1), only doing output on first evaluation if i == 1: diff --git 
a/dedalus/libraries/matsolvers.py b/dedalus/libraries/matsolvers.py index f301d4f2..2544f8e9 100644 --- a/dedalus/libraries/matsolvers.py +++ b/dedalus/libraries/matsolvers.py @@ -5,7 +5,12 @@ import scipy.sparse as sp import scipy.sparse.linalg as spla from functools import partial - +import array_api_compat +try: + import cupyx.scipy.sparse.linalg as cupy_spla + cupy_available = True +except ImportError: + cupy_available = False matsolvers = {} def add_solver(solver): @@ -144,6 +149,17 @@ def __init__(self, matrix, solver=None): relax=self.relax, panel_size=self.panel_size, options=self.options) + # Cupy conversion + if array_api_compat.is_cupy_namespace(solver.dist.array_namespace): + # Avoid cupy splu which requires GPU matrices but transfers them to factorize on CPU + # Run same typecheck as cupy splu + if matrix.dtype.char not in 'fdFD': + raise TypeError('Invalid dtype (actual: {})'.format(self.LU.dtype)) + # Build cupy factorization from scipy factorization of CPU matrices + self.LU = cupy_spla.SuperLU(self.LU) + sp.save_npz("block1024.npz", matrix) + print(self.LU.shape) + print(self.LU.nnz) def solve(self, vector): return self.LU.solve(vector, trans=self.trans) @@ -225,6 +241,9 @@ class SparseInverse(SparseSolver): def __init__(self, matrix, solver=None): self.matrix_inverse = spla.inv(matrix.tocsc()) + # Cupy conversion + if array_api_compat.is_cupy_namespace(solver.dist.array_namespace): + self.matrix_inverse = cupy_spla.inv(matrix.tocsc()) def solve(self, vector): return self.matrix_inverse @ vector diff --git a/dedalus/tools/array.py b/dedalus/tools/array.py index bfef6a95..2d6a69e6 100644 --- a/dedalus/tools/array.py +++ b/dedalus/tools/array.py @@ -5,6 +5,7 @@ import scipy.sparse as sp from scipy.sparse import _sparsetools from scipy.sparse import linalg as spla +from scipy.linalg import blas from math import prod from ..tools import linalg_gpu import array_api_compat @@ -500,3 +501,12 @@ def assert_sparse_pinv(A, B): if not sparse_allclose((B @ A).conj().T, B @ A): raise AssertionError("Not a pseudoinverse") + +def get_axpy(array_namespace, dtype): + if array_api_compat.is_numpy_namespace(array_namespace): + return blas.get_blas_funcs('axpy', dtype=dtype) + elif array_api_compat.is_cupy_namespace(array_namespace): + from cupy.cublas import axpy as cublas_axpy + return cublas_axpy + else: + raise ValueError("Unsupported array namespace") diff --git a/dedalus/tools/linalg_gpu.py b/dedalus/tools/linalg_gpu.py index c4e1cde9..95ccfe52 100644 --- a/dedalus/tools/linalg_gpu.py +++ b/dedalus/tools/linalg_gpu.py @@ -5,14 +5,14 @@ try: import cupy as cp import cupyx.scipy.sparse as csp - HAVE_CUPY = True + cupy_available = True except ImportError: - HAVE_CUPY = False + cupy_available = False def cupy_apply_csr(matrix, array, axis, out): """Apply CSR matrix to arbitrary axis of array.""" - if not HAVE_CUPY: + if not cupy_available: raise ImportError("cupy must be installed to use GPU linear algebra") # Check matrix format if not isinstance(matrix, csp.csr_matrix): From 15a2d6e6bee87e09f92333a3d2e467ecbfd29fff Mon Sep 17 00:00:00 2001 From: "Keaton J. 
Burns" Date: Fri, 25 Jul 2025 13:49:34 -0400 Subject: [PATCH 27/29] Build custom cupy superlu wrapper to reuse spsm descriptors --- dedalus/libraries/matsolvers.py | 10 +- dedalus/tools/linalg_gpu.py | 244 ++++++++++++++++++++++++++++++++ 2 files changed, 251 insertions(+), 3 deletions(-) diff --git a/dedalus/libraries/matsolvers.py b/dedalus/libraries/matsolvers.py index 2544f8e9..ede93a1b 100644 --- a/dedalus/libraries/matsolvers.py +++ b/dedalus/libraries/matsolvers.py @@ -157,9 +157,13 @@ def __init__(self, matrix, solver=None): raise TypeError('Invalid dtype (actual: {})'.format(self.LU.dtype)) # Build cupy factorization from scipy factorization of CPU matrices self.LU = cupy_spla.SuperLU(self.LU) - sp.save_npz("block1024.npz", matrix) - print(self.LU.shape) - print(self.LU.nnz) + self.LU.spsm_L_descr = None + self.LU.spsm_U_descr = None + self.solve = self.cupy_solve + + def cupy_solve(self, vector): + from dedalus.tools.linalg_gpu import custom_SuperLU_solve + return custom_SuperLU_solve(self.LU, vector, trans=self.trans) def solve(self, vector): return self.LU.solve(vector, trans=self.trans) diff --git a/dedalus/tools/linalg_gpu.py b/dedalus/tools/linalg_gpu.py index 95ccfe52..220d3615 100644 --- a/dedalus/tools/linalg_gpu.py +++ b/dedalus/tools/linalg_gpu.py @@ -123,3 +123,247 @@ def cupy_apply_csr_mid(matrix, array, out): # Launch kernel apply_csr_mid_kernel(griddim, blockdim, (matrix.data, matrix.indices, matrix.indptr, array, out, N1, N2i, N2o, N3)) + +def custom_spsm(a, b, alpha=1.0, lower=True, unit_diag=False, transa=False, spsm_descr=None): + """Custom spsm wrapper to save spsm_descr, since spsm_analysis takes lots of time.""" + """Solves a sparse triangular linear system op(a) * x = alpha * op(b). + + Args: + a (cupyx.scipy.sparse.csr_matrix or cupyx.scipy.sparse.coo_matrix): + Sparse matrix with dimension ``(M, M)``. + b (cupy.ndarray): Dense matrix with dimension ``(M, K)``. + alpha (float or complex): Coefficient. + lower (bool): + True: ``a`` is lower triangle matrix. + False: ``a`` is upper triangle matrix. + unit_diag (bool): + True: diagonal part of ``a`` has unit elements. + False: diagonal part of ``a`` has non-unit elements. + transa (bool or str): True, False, 'N', 'T' or 'H'. + 'N' or False: op(a) == ``a``. + 'T' or True: op(a) == ``a.T``. + 'H': op(a) == ``a.conj().T``. 
+ """ + import cupyx + from cupyx import cusparse + import cupy as _cupy + import numpy as _numpy + from cupy._core import _dtype + from cupy_backends.cuda.libs import cusparse as _cusparse + from cupy.cuda import device as _device + from cupyx.cusparse import SpMatDescriptor, DnMatDescriptor + if not cusparse.check_availability('spsm'): + raise RuntimeError('spsm is not available.') + + # Canonicalise transa + if transa is False: + transa = 'N' + elif transa is True: + transa = 'T' + elif transa not in 'NTH': + raise ValueError(f'Unknown transa (actual: {transa})') + + # Check A's type and sparse format + if cupyx.scipy.sparse.isspmatrix_csr(a): + pass + elif cupyx.scipy.sparse.isspmatrix_csc(a): + if transa == 'N': + a = a.T + transa = 'T' + elif transa == 'T': + a = a.T + transa = 'N' + elif transa == 'H': + a = a.conj().T + transa = 'N' + lower = not lower + elif cupyx.scipy.sparse.isspmatrix_coo(a): + pass + else: + raise ValueError('a must be CSR, CSC or COO sparse matrix') + assert a.has_canonical_format + + # Check B's ndim + if b.ndim == 1: + is_b_vector = True + b = b.reshape(-1, 1) + elif b.ndim == 2: + is_b_vector = False + else: + raise ValueError('b.ndim must be 1 or 2') + + # Check shapes + if not (a.shape[0] == a.shape[1] == b.shape[0]): + raise ValueError('mismatched shape') + + # Check dtypes + dtype = a.dtype + if dtype.char not in 'fdFD': + raise TypeError('Invalid dtype (actual: {})'.format(dtype)) + if dtype != b.dtype: + raise TypeError('dtype mismatch') + + # Prepare fill mode + if lower is True: + fill_mode = _cusparse.CUSPARSE_FILL_MODE_LOWER + elif lower is False: + fill_mode = _cusparse.CUSPARSE_FILL_MODE_UPPER + else: + raise ValueError('Unknown lower (actual: {})'.format(lower)) + + # Prepare diag type + if unit_diag is False: + diag_type = _cusparse.CUSPARSE_DIAG_TYPE_NON_UNIT + elif unit_diag is True: + diag_type = _cusparse.CUSPARSE_DIAG_TYPE_UNIT + else: + raise ValueError('Unknown unit_diag (actual: {})'.format(unit_diag)) + + # Prepare op_a + if transa == 'N': + op_a = _cusparse.CUSPARSE_OPERATION_NON_TRANSPOSE + elif transa == 'T': + op_a = _cusparse.CUSPARSE_OPERATION_TRANSPOSE + else: # transa == 'H' + if dtype.char in 'fd': + op_a = _cusparse.CUSPARSE_OPERATION_TRANSPOSE + else: + op_a = _cusparse.CUSPARSE_OPERATION_CONJUGATE_TRANSPOSE + + # Prepare op_b + if b._f_contiguous: + op_b = _cusparse.CUSPARSE_OPERATION_NON_TRANSPOSE + elif b._c_contiguous: + if _cusparse.get_build_version() < 11701: # earlier than CUDA 11.6 + raise ValueError('b must be F-contiguous.') + b = b.T + op_b = _cusparse.CUSPARSE_OPERATION_TRANSPOSE + else: + raise ValueError('b must be F-contiguous or C-contiguous.') + + # Allocate space for matrix C. Note that it is known cusparseSpSM requires + # the output matrix zero initialized. 
+ m, _ = a.shape + if op_b == _cusparse.CUSPARSE_OPERATION_NON_TRANSPOSE: + _, n = b.shape + else: + n, _ = b.shape + c_shape = m, n + c = _cupy.zeros(c_shape, dtype=a.dtype, order='f') + + # Prepare descriptors and other parameters + handle = _device.get_cusparse_handle() + mat_a = SpMatDescriptor.create(a) + mat_b = DnMatDescriptor.create(b) + mat_c = DnMatDescriptor.create(c) + if spsm_descr is None: + spsm_descr = _cusparse.spSM_createDescr() + new_spsm_descr = True + else: + spsm_descr, buff = spsm_descr + new_spsm_descr = False + alpha = _numpy.array(alpha, dtype=c.dtype).ctypes + cuda_dtype = _dtype.to_cuda_dtype(c.dtype) + algo = _cusparse.CUSPARSE_SPSM_ALG_DEFAULT + + try: + # Specify Lower|Upper fill mode + mat_a.set_attribute(_cusparse.CUSPARSE_SPMAT_FILL_MODE, fill_mode) + + # Specify Unit|Non-Unit diagonal type + mat_a.set_attribute(_cusparse.CUSPARSE_SPMAT_DIAG_TYPE, diag_type) + + # Allocate the workspace needed by the succeeding phases + if new_spsm_descr: + buff_size = _cusparse.spSM_bufferSize( + handle, op_a, op_b, alpha.data, mat_a.desc, mat_b.desc, + mat_c.desc, cuda_dtype, algo, spsm_descr) + buff = _cupy.empty(buff_size, dtype=_cupy.int8) + + # Perform the analysis phase + if new_spsm_descr: + _cusparse.spSM_analysis( + handle, op_a, op_b, alpha.data, mat_a.desc, mat_b.desc, + mat_c.desc, cuda_dtype, algo, spsm_descr, buff.data.ptr) + + # Executes the solve phase + _cusparse.spSM_solve( + handle, op_a, op_b, alpha.data, mat_a.desc, mat_b.desc, + mat_c.desc, cuda_dtype, algo, spsm_descr, buff.data.ptr) + + # Reshape back if B was a vector + if is_b_vector: + c = c.reshape(-1) + + return c, (spsm_descr, buff) + + finally: + # Destroy matrix/vector descriptors + #_cusparse.spSM_destroyDescr(spsm_descr) + pass + + +def custom_SuperLU_solve(self, rhs, trans='N', spsm_descr=None): + """Custom SuperLU solve wrapper to save spsm_descr, since spsm_analysis takes lots of time.""" + """Solves linear system of equations with one or several right-hand sides. + + Args: + rhs (cupy.ndarray): Right-hand side(s) of equation with dimension + ``(M)`` or ``(M, K)``. + trans (str): 'N', 'T' or 'H'. + 'N': Solves ``A * x = rhs``. + 'T': Solves ``A.T * x = rhs``. + 'H': Solves ``A.conj().T * x = rhs``. + + Returns: + cupy.ndarray: + Solution vector(s) + """ # NOQA + from cupyx import cusparse + import cupy + from cupyx.scipy.sparse.linalg._solve import _should_use_spsm + + if not isinstance(rhs, cupy.ndarray): + raise TypeError('ojb must be cupy.ndarray') + if rhs.ndim not in (1, 2): + raise ValueError('rhs.ndim must be 1 or 2 (actual: {})'. 
+ format(rhs.ndim)) + if rhs.shape[0] != self.shape[0]: + raise ValueError('shape mismatch (self.shape: {}, rhs.shape: {})' + .format(self.shape, rhs.shape)) + if trans not in ('N', 'T', 'H'): + raise ValueError('trans must be \'N\', \'T\', or \'H\'') + + if cusparse.check_availability('spsm') and _should_use_spsm(rhs): + def spsm(A, B, lower, transa, spsm_descr): + return custom_spsm(A, B, lower=lower, transa=transa, spsm_descr=spsm_descr) + sm = spsm + else: + raise NotImplementedError + + x = rhs.astype(self.L.dtype) + if trans == 'N': + if self.perm_r is not None: + if x.ndim == 2 and x._f_contiguous: + x = x.T[:, self._perm_r_rev].T # want to keep f-order + else: + x = x[self._perm_r_rev] + x, self.spsm_L_descr = sm(self.L, x, lower=True, transa=trans, spsm_descr=self.spsm_L_descr) + x, self.spsm_U_descr = sm(self.U, x, lower=False, transa=trans, spsm_descr=self.spsm_U_descr) + if self.perm_c is not None: + x = x[self.perm_c] + else: + if self.perm_c is not None: + if x.ndim == 2 and x._f_contiguous: + x = x.T[:, self._perm_c_rev].T # want to keep f-order + else: + x = x[self._perm_c_rev] + x, self.spsm_U_descr = sm(self.U, x, lower=False, transa=trans, spsm_descr=self.spsm_U_descr) + x, self.spsm_L_descr = sm(self.L, x, lower=True, transa=trans, spsm_descr=self.spsm_L_descr) + if self.perm_r is not None: + x = x[self.perm_r] + + if not x._f_contiguous: + # For compatibility with SciPy + x = x.copy(order='F') + return x From 2e674b7d514d4a72aede405c3c33f627faced096 Mon Sep 17 00:00:00 2001 From: "Keaton J. Burns" Date: Fri, 25 Jul 2025 16:46:49 -0400 Subject: [PATCH 28/29] Move all operator matrices to device. Add Chebyshev transforms --- dedalus/core/basis.py | 8 +- dedalus/core/operators.py | 17 +++- dedalus/core/transforms.py | 74 +++++++++++++-- dedalus/tools/array.py | 72 ++++++++++----- dedalus/tools/linalg_gpu.py | 176 ++++++++++++++++++++++++++++++++++++ 5 files changed, 311 insertions(+), 36 deletions(-) diff --git a/dedalus/core/basis.py b/dedalus/core/basis.py index 8619b455..fe44931c 100644 --- a/dedalus/core/basis.py +++ b/dedalus/core/basis.py @@ -506,7 +506,13 @@ def _native_grid(self, scale): @CachedMethod def transform_plan(self, dist, grid_size): """Build transform plan.""" - return self.transforms[self.library](grid_size, self.size, self.a, self.b, self.a0, self.b0) + xp = dist.array_namespace + xp_name = xp.__name__.split('.')[-1] + # Shortcut trivial transforms + if grid_size == 1 or self.size == 1: + return self.transforms[f"matrix-{xp_name}"](grid_size, self.size, self.a, self.b, self.a0, self.b0, dist.array_namespace, dist.dtype) + else: + return self.transforms[f"{self.library}-{xp_name}"](grid_size, self.size, self.a, self.b, self.a0, self.b0, dist.array_namespace, dist.dtype) # def weights(self, scales): # """Gauss-Jacobi weights.""" diff --git a/dedalus/core/operators.py b/dedalus/core/operators.py index cc6f0dad..3655fe72 100644 --- a/dedalus/core/operators.py +++ b/dedalus/core/operators.py @@ -15,6 +15,7 @@ from math import prod from ..libraries import dedalus_sphere import logging +import array_api_compat logger = logging.getLogger(__name__.split('.')[-1]) from .domain import Domain @@ -950,6 +951,20 @@ def subspace_matrix(self, layout): # Caching layer to allow insertion of other arguments return self._subspace_matrix(layout, self.input_basis, self.output_basis, self.first_axis) + @CachedMethod + def subspace_matrix_device(self, layout): + """Build matrix operating on local subspace data on device.""" + # Caching layer to allow insertion of other 
arguments + matrix = self._subspace_matrix(layout, self.input_basis, self.output_basis, self.first_axis) + if array_api_compat.is_cupy_namespace(self.array_namespace): + import cupy as cp + import cupyx.scipy.sparse as csp + if sparse.issparse(matrix): + matrix = csp.csr_matrix(matrix) + else: + matrix = cp.array(matrix) + return matrix + def group_matrix(self, group): return self._group_matrix(group, self.input_basis, self.output_basis) @@ -990,7 +1005,7 @@ def operate(self, out): # Apply matrix if arg.data.size and out.data.size: data_axis = self.last_axis + len(arg.tensorsig) - apply_matrix(self.subspace_matrix(layout), arg.data, data_axis, out=out.data) + apply_matrix(self.subspace_matrix_device(layout), arg.data, data_axis, out=out.data) else: out.data.fill(0) diff --git a/dedalus/core/transforms.py b/dedalus/core/transforms.py index 5616c1c7..0433d802 100644 --- a/dedalus/core/transforms.py +++ b/dedalus/core/transforms.py @@ -8,14 +8,16 @@ import scipy.fftpack from ..libraries import dedalus_sphere from math import prod +import array_api_compat from . import basis from ..libraries.fftw import fftw_wrappers as fftw from ..tools import jacobi -from ..tools.array import apply_matrix, apply_dense, axslice, solve_upper_sparse, apply_sparse +from ..tools.array import apply_matrix, apply_dense, axslice, solve_upper_sparse, apply_sparse, copyto from ..tools.cache import CachedAttribute from ..tools.cache import CachedMethod from ..tools.general import float_to_complex +from ..tools.linalg_gpu import cupy_solve_upper_csr, CustomCupyUpperTriangularSolver import logging logger = logging.getLogger(__name__.split('.')[-1]) @@ -94,31 +96,39 @@ class JacobiTransform(SeparableTransform): Jacobi "a" parameter for the quadrature grid. b0 : int Jacobi "b" parameter for the quadrature grid. + array_namespace : array namespace + Array namespace for the transform. + dtype : dtype + Data type for the transform. Notes ----- TODO: We need to define the normalization we use here. 
""" - def __init__(self, grid_size, coeff_size, a, b, a0, b0, dealias_before_converting=None): + def __init__(self, grid_size, coeff_size, a, b, a0, b0, array_namespace, dtype, dealias_before_converting=None): self.N = grid_size self.M = coeff_size self.a = a self.b = b self.a0 = a0 self.b0 = b0 + self.array_namespace = array_namespace + self.dtype = dtype if dealias_before_converting is None: dealias_before_converting = GET_DEALIAS_BEFORE_CONVERTING() self.dealias_before_converting = dealias_before_converting -@register_transform(basis.Jacobi, 'matrix') +@register_transform(basis.Jacobi, 'matrix-numpy') +@register_transform(basis.Jacobi, 'matrix-cupy') class JacobiMMT(JacobiTransform, SeparableMatrixTransform): """Jacobi polynomial MMTs.""" @CachedAttribute def forward_matrix(self): """Build forward transform matrix.""" + xp = self.array_namespace N, M = self.N, self.M a, a0 = self.a, self.a0 b, b0 = self.b, self.b0 @@ -142,11 +152,12 @@ def forward_matrix(self): # Truncate to specified coeff_size forward_matrix = forward_matrix[:M, :] # Ensure C ordering for fast dot products - return np.asarray(forward_matrix, order='C') + return xp.asarray(forward_matrix, order='C', dtype=self.dtype) @CachedAttribute def backward_matrix(self): """Build backward transform matrix.""" + xp = self.array_namespace N, M = self.N, self.M a, a0 = self.a, self.a0 b, b0 = self.b, self.b0 @@ -156,7 +167,7 @@ def backward_matrix(self): # Zero higher polynomials for transforms with grid_size < coeff_size polynomials[N:, :] = 0 # Transpose and ensure C ordering for fast dot products - return np.asarray(polynomials.T, order='C') + return xp.asarray(polynomials.T, order='C', dtype=self.dtype) class ComplexFourierTransform(SeparableTransform): @@ -848,6 +859,33 @@ def backward(self, cdata, gdata, axis): np.copyto(gdata, temp) +class CupyDCT(FastCosineTransform): + """Fast cosine transform using cupy fft.""" + + def __init__(self, *args, **kw): + import cupyx.scipy.fft as cufft + self.cufft = cufft + super().__init__(*args, **kw) + + def forward(self, gdata, cdata, axis): + """Apply forward transform along specified axis.""" + # Call DCT + temp = self.cufft.dct(gdata, type=2, axis=axis) # Creates temporary + # Resize and rescale for unit-ampltidue normalization + self.resize_rescale_forward(temp, cdata, axis, self.Kmax) + + def backward(self, cdata, gdata, axis): + """Apply backward transform along specified axis.""" + xp = self.array_namespace + # Resize and rescale for unit-amplitude normalization + # Need temporary to avoid overwriting problems + temp = xp.empty_like(gdata) # Creates temporary + self.resize_rescale_backward(cdata, temp, axis, self.Kmax) + # Call IDCT + temp = self.cufft.dct(temp, type=3, axis=axis, overwrite_x=True) # Creates temporary + copyto(gdata, temp) + + #@register_transform(basis.Cosine, 'fftw') class FFTWDCT(FFTWBase, FastCosineTransform): """Fast cosine transform using FFTW.""" @@ -884,11 +922,11 @@ class FastChebyshevTransform(JacobiTransform): Subclasses should inherit from this class, then a FastCosineTransform subclass. 
""" - def __init__(self, grid_size, coeff_size, a, b, a0, b0, **kw): + def __init__(self, grid_size, coeff_size, a, b, a0, b0, array_namespace, dtype, **kw): if not a0 == b0 == -1/2: raise ValueError("Fast Chebshev transform requires a0 == b0 == -1/2.") # Jacobi initialization - super().__init__(grid_size, coeff_size, a, b, a0, b0, **kw) + super().__init__(grid_size, coeff_size, a, b, a0, b0, array_namespace, dtype, **kw) # DCT initialization to set scaling factors if a != a0 or b != b0: # Modify coeff_size to avoid truncation before conversion @@ -920,6 +958,13 @@ def __init__(self, grid_size, coeff_size, a, b, a0, b0, **kw): self.backward_conversion.sum_duplicates() # for faster solve_upper self.resize_rescale_forward = self._resize_rescale_forward_convert self.resize_rescale_backward = self._resize_rescale_backward_convert + if array_api_compat.is_cupy_namespace(self.array_namespace): + import cupyx.scipy.sparse as csp + self.forward_conversion = csp.csr_matrix(self.forward_conversion) + self.backward_conversion = csp.csr_matrix(self.backward_conversion) + self.forward_conversion.sum_duplicates() + self.backward_conversion.sum_duplicates() + self.backward_conversion_LU = CustomCupyUpperTriangularSolver(self.backward_conversion) def _resize_rescale_forward(self, data_in, data_out, axis, Kmax): """Resize by padding/trunction and rescale to unit amplitude.""" @@ -961,7 +1006,10 @@ def _resize_rescale_backward_convert(self, data_in, data_out, axis, Kmax_DCT): # Truncate input before conversion data_in[badfreq] = 0 # Ultraspherical conversion - solve_upper_sparse(self.backward_conversion, data_in, axis, out=data_in) + if array_api_compat.is_cupy_namespace(self.array_namespace): + cupy_solve_upper_csr(self.backward_conversion_LU, data_in, axis, out=data_in) + else: + solve_upper_sparse(self.backward_conversion, data_in, axis, out=data_in) # Change sign of odd modes if Kmax_orig > 0: posfreq_odd = axslice(axis, 1, Kmax_orig+1, 2) @@ -970,18 +1018,24 @@ def _resize_rescale_backward_convert(self, data_in, data_out, axis, Kmax_DCT): super().resize_rescale_backward(data_in, data_out, axis, Kmax_orig) -@register_transform(basis.Jacobi, 'scipy_dct') +@register_transform(basis.Jacobi, 'scipy_dct-numpy') class ScipyFastChebyshevTransform(FastChebyshevTransform, ScipyDCT): """Fast ultraspherical transform using scipy.fft and spectral conversion.""" pass # Implementation is complete via inheritance -@register_transform(basis.Jacobi, 'fftw_dct') +@register_transform(basis.Jacobi, 'fftw_dct-numpy') class FFTWFastChebyshevTransform(FastChebyshevTransform, FFTWDCT): """Fast ultraspherical transform using scipy.fft and spectral conversion.""" pass # Implementation is complete via inheritance +@register_transform(basis.Jacobi, 'scipy_dct-cupy') +class CupyFastChebyshevTransform(FastChebyshevTransform, CupyDCT): + """Fast ultraspherical transform using cupy fft and spectral conversion.""" + pass # Implementation is complete via inheritance + + # class ScipyDST(PolynomialTransform): # def forward_reduced(self): diff --git a/dedalus/tools/array.py b/dedalus/tools/array.py index 2d6a69e6..e137f75d 100644 --- a/dedalus/tools/array.py +++ b/dedalus/tools/array.py @@ -79,10 +79,20 @@ def expand_pattern(input, pattern): def apply_matrix(matrix, array, axis, **kw): """Apply matrix along any axis of an array.""" - if sparse.isspmatrix(matrix): - return apply_sparse(matrix, array, axis, **kw) + xp = array_api_compat.array_namespace(array) + if array_api_compat.is_numpy_namespace(xp): + if sparse.issparse(matrix): + 
return apply_sparse(matrix, array, axis, **kw) + else: + return apply_dense(matrix, array, axis, **kw) + elif array_api_compat.is_cupy_namespace(xp): + import cupyx.scipy.sparse as csp + if csp.issparse(matrix): + return apply_sparse(matrix, array, axis, **kw) + else: + return apply_dense(matrix, array, axis, **kw) else: - return apply_dense(matrix, array, axis, **kw) + raise ValueError("Unsupported array type") def apply_dense_einsum(matrix, array, axis, optimize=True, **kw): @@ -177,6 +187,8 @@ def apply_sparse(matrix, array, axis, out=None, check_shapes=False, num_threads= Must be out of place if ouptut is specified. """ xp = array_api_compat.array_namespace(array) + matrix.sum_duplicates() + matrix.has_canonical_format = True # Check output if out is None: out_shape = list(array.shape) @@ -219,28 +231,40 @@ def solve_upper_sparse(matrix, rhs, axis, out=None, check_shapes=False, num_thre Solve upper triangular sparse matrix along any axis of an array. Matrix assumed to be nonzero on the diagonals. """ - # Check matrix - if not isinstance(matrix, sparse.csr_matrix): - raise ValueError("Matrix must be in CSR format.") - if not matrix._has_canonical_format: # avoid property hook (without underscore) - matrix.sum_duplicates() - # Setup output = rhs + xp = array_api_compat.array_namespace(rhs) + matrix.sum_duplicates() + matrix.has_canonical_format = True + # Check output if out is None: - out = np.copy(rhs) - elif out is not rhs: - np.copyto(out, rhs) - # Promote datatypes - matrix_data = matrix.data - if matrix_data.dtype != rhs.dtype: - matrix_data = matrix_data.astype(rhs.dtype) - # Check shapes - if check_shapes: - if not (0 <= axis < rhs.ndim): - raise ValueError("Axis out of bounds.") - if not (matrix.shape[0] == matrix.shape[1] == rhs.shape[axis]): - raise ValueError("Matrix shape mismatch.") - # Call cython routine - cython_linalg.solve_upper_csr(matrix.indptr, matrix.indices, matrix_data, out, axis, num_threads) + out = xp.empty_like(rhs) + # Dispatch on array type + if array_api_compat.is_numpy_namespace(xp): + # Check matrix + if not isinstance(matrix, sparse.csr_matrix): + raise ValueError("Matrix must be in CSR format.") + if not matrix._has_canonical_format: # avoid property hook (without underscore) + matrix.sum_duplicates() + # Setup output = rhs + copyto(out, rhs) + # Promote datatypes + matrix_data = matrix.data + if matrix_data.dtype != rhs.dtype: + matrix_data = matrix_data.astype(rhs.dtype) + # Check shapes + if check_shapes: + if not (0 <= axis < rhs.ndim): + raise ValueError("Axis out of bounds.") + if not (matrix.shape[0] == matrix.shape[1] == rhs.shape[axis]): + raise ValueError("Matrix shape mismatch.") + # Call cython routine + cython_linalg.solve_upper_csr(matrix.indptr, matrix.indices, matrix_data, out, axis, num_threads) + elif array_api_compat.is_cupy_namespace(xp): + if not matrix._has_canonical_format: # avoid property hook (without underscore) + matrix.sum_duplicates() + linalg_gpu.cupy_solve_upper_csr(matrix, rhs, axis, out) + else: + raise ValueError("Unsupported array type") + return out def csr_matvec(A_csr, x_vec, out_vec): diff --git a/dedalus/tools/linalg_gpu.py b/dedalus/tools/linalg_gpu.py index 220d3615..a64f51f1 100644 --- a/dedalus/tools/linalg_gpu.py +++ b/dedalus/tools/linalg_gpu.py @@ -5,6 +5,7 @@ try: import cupy as cp import cupyx.scipy.sparse as csp + import cupyx.scipy.sparse.linalg as cupy_spla cupy_available = True except ImportError: cupy_available = False @@ -17,6 +18,7 @@ def cupy_apply_csr(matrix, array, axis, out): # Check matrix 
format if not isinstance(matrix, csp.csr_matrix): # TODO: avoid this explicit conversion + print('WARNING: converting matrix to CSR format') matrix = csp.csr_matrix(matrix) #raise ValueError("Matrix must be in CSR format.") # Switch by dimension @@ -367,3 +369,177 @@ def spsm(A, B, lower, transa, spsm_descr): # For compatibility with SciPy x = x.copy(order='F') return x + + +class CustomCupyUpperTriangularSolver: + """Hacky class to save spsm_descr for reuse in spsm for triangular solves.""" + + def __init__(self, matrix): + # Check matrix format + if not isinstance(matrix, csp.csr_matrix): + # TODO: avoid this explicit conversion + matrix = csp.csr_matrix(matrix) + print('WARNING: converting matrix to CSR format') + #raise ValueError("Matrix must be in CSR format.") + self.matrix = matrix + self.spsm_descr = None + + def solve(self, b, lower=True, overwrite_A=False, overwrite_b=False, + unit_diagonal=False): + """Solves a sparse triangular system ``A x = b``. + + Args: + A (cupyx.scipy.sparse.spmatrix): + Sparse matrix with dimension ``(M, M)``. + b (cupy.ndarray): + Dense vector or matrix with dimension ``(M)`` or ``(M, K)``. + lower (bool): + Whether ``A`` is a lower or upper triangular matrix. + If True, it is lower triangular, otherwise, upper triangular. + overwrite_A (bool): + (not supported) + overwrite_b (bool): + Allows overwriting data in ``b``. + unit_diagonal (bool): + If True, diagonal elements of ``A`` are assumed to be 1 and will + not be referenced. + + Returns: + cupy.ndarray: + Solution to the system ``A x = b``. The shape is the same as ``b``. + """ + from cupyx import cusparse + from cupyx.scipy import sparse + import cupy + from cupyx.scipy.sparse.linalg._solve import _should_use_spsm + + A = self.matrix + + if not (cusparse.check_availability('spsm') or + cusparse.check_availability('csrsm2')): + raise NotImplementedError + + if not sparse.isspmatrix(A): + raise TypeError('A must be cupyx.scipy.sparse.spmatrix') + if not isinstance(b, cupy.ndarray): + raise TypeError('b must be cupy.ndarray') + if A.shape[0] != A.shape[1]: + raise ValueError(f'A must be a square matrix (A.shape: {A.shape})') + if b.ndim not in [1, 2]: + raise ValueError(f'b must be 1D or 2D array (b.shape: {b.shape})') + if A.shape[0] != b.shape[0]: + raise ValueError('The size of dimensions of A must be equal to the ' + 'size of the first dimension of b ' + f'(A.shape: {A.shape}, b.shape: {b.shape})') + if A.dtype.char not in 'fdFD': + raise TypeError(f'unsupported dtype (actual: {A.dtype})') + + if cusparse.check_availability('spsm') and _should_use_spsm(b): + if not (sparse.isspmatrix_csr(A) or + sparse.isspmatrix_csc(A) or + sparse.isspmatrix_coo(A)): + warnings.warn('CSR, CSC or COO format is required. Converting to ' + 'CSR format.', sparse.SparseEfficiencyWarning) + A = A.tocsr() + A.sum_duplicates() + x, self.spsm_descr = custom_spsm(A, b, lower=lower, unit_diag=unit_diagonal, spsm_descr=self.spsm_descr) + elif cusparse.check_availability('csrsm2'): + if not (sparse.isspmatrix_csr(A) or sparse.isspmatrix_csc(A)): + warnings.warn('CSR or CSC format is required. Converting to CSR ' + 'format.', sparse.SparseEfficiencyWarning) + A = A.tocsr() + A.sum_duplicates() + + if (overwrite_b and A.dtype == b.dtype and + (b._c_contiguous or b._f_contiguous)): + x = b + else: + x = b.astype(A.dtype, copy=True) + + cusparse.csrsm2(A, x, lower=lower, unit_diag=unit_diagonal) + else: + assert False + + if x.dtype.char in 'fF': + # Note: This is for compatibility with SciPy. 
+ dtype = numpy.promote_types(x.dtype, 'float64') + x = x.astype(dtype) + return x + + +def cupy_solve_upper_csr(matrix, array, axis, out): + """Solve upper triangular CSR matrix along specified axis of an array.""" + # Switch by dimension + ndim = array.ndim + if ndim == 1: + if axis == 0: + cupy_solve_upper_csr_vec(matrix, array, out) + else: + raise ValueError("axis must be 0 for 1D arrays") + elif ndim == 2: + if axis == 0: + if array.shape[1] == 1: + cupy_solve_upper_csr_vec(matrix, array[:,0], out[:,0]) + else: + cupy_solve_upper_csr_first(matrix, array, out) + elif axis == 1: + if array.shape[0] == 1: + cupy_solve_upper_csr_vec(matrix, array[0,:], out[0,:]) + else: + cupy_solve_upper_csr_last(matrix, array, out) + else: + raise ValueError("axis must be 0 or 1 for 2D arrays") + else: + # Treat as 3D array with specified axis in the middle + # Compute equivalent shape (N1, N2, N3) + if ndim == 3 and axis == 1: + N1 = shape[0] + N2 = shape[1] + N3 = shape[2] + else: + N1 = int(np.prod(array.shape[:axis])) + N2 = array.shape[axis] + N3 = int(np.prod(array.shape[axis+1:])) + # Dispatch to cupy routines + if N1 == 1: + if N3 == 1: + # (1, N2, 1) -> (N2,) + x1 = array.reshape((N2,)) + y1 = out.reshape((N2,)) + cupy_solve_upper_csr_vec(matrix, x1, y1) + else: + # (1, N2, N3) -> (N2, N3) + x2 = array.reshape((N2, N3)) + y2 = out.reshape((N2, N3)) + cupy_solve_upper_csr_first(matrix, x2, y2) + else: + if N3 == 1: + # (N1, N2, 1) -> (N1, N2) + x2 = array.reshape((N1, N2)) + y2 = out.reshape((N1, N2)) + cupy_solve_upper_csr_last(matrix, x2, y2) + else: + # (N1, N2, N3) + x3 = array.reshape((N1, N2, N3)) + y3 = out.reshape((N1, N2, N3)) + cupy_solve_upper_csr_mid(matrix, x3, y3) + + +def cupy_solve_upper_csr_vec(matrix, vec, out): + """Solve upper triangular CSR matrix along a vector.""" + out[:] = matrix.solve(vec, lower=False) + + +def cupy_solve_upper_csr_first(matrix, array, out): + """Solve upper triangular CSR matrix along first axis of 2D array.""" + out[:] = matrix.solve(array, lower=False) + + +def cupy_solve_upper_csr_last(matrix, array, out): + """Solve upper triangular CSR matrix along last axis of 2D array.""" + out.T[:] = matrix.solve(array.T, lower=False) + + +def cupy_solve_upper_csr_mid(matrix, array, out): + """Solve upper triangular CSR matrix along middle axis of 3D array.""" + raise NotImplementedError From 6e32312d9db7d8d5b8eaf019504120b2cd5bff89 Mon Sep 17 00:00:00 2001 From: "Calum S. Skene" Date: Tue, 16 Dec 2025 11:51:26 +0000 Subject: [PATCH 29/29] Make einsum in trace compatible with cupy --- dedalus/core/operators.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/dedalus/core/operators.py b/dedalus/core/operators.py index 3655fe72..334e272c 100644 --- a/dedalus/core/operators.py +++ b/dedalus/core/operators.py @@ -1795,8 +1795,10 @@ def operate(self, out): xp = self.array_namespace arg = self.args[0] out.preset_layout(arg.layout) - xp.einsum('ii...', arg.data, out=out.data) - + if array_api_compat.is_cupy_namespace(xp): + out.data[:] = xp.einsum('ii...', arg.data) + else: + xp.einsum('ii...', arg.data, out=out.data) class SphericalTrace(Trace):
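
A minimal CPU-only sketch of the dispatch idiom the helpers in this series rely on (copy_from_device / copy_to_device in dedalus/tools/array.py and the xp.zeros buffer allocations): resolve an array namespace once via array-api-compat, allocate through it, and branch only at host/device copy boundaries. The helper name to_host below is illustrative rather than part of the patches, and with only NumPy installed the CuPy branch is simply never taken.

import numpy as np
import array_api_compat

def to_host(dest, src):
    # Mirrors copy_from_device: pull data off the GPU when src is a CuPy array,
    # otherwise fall back to an ordinary in-place NumPy copy.
    if array_api_compat.is_cupy_array(src):
        src.get(out=dest)
    else:
        dest[:] = src

xp = array_api_compat.array_namespace(np.zeros(0))   # NumPy namespace here; CuPy if seeded with a CuPy array
buf = xp.zeros(8, dtype=np.float64)                   # backend-agnostic allocation
host = np.empty(8)
to_host(host, buf)                                    # plain in-place copy on the CPU path

Keeping the branch inside small helpers like this is what lets the subsystem gather/scatter and CoeffSystem code above stay backend-agnostic.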
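
The axis handling in cupy_apply_csr and cupy_solve_upper_csr collapses any input to an (N1, N2, N3) view with the contracted axis in the middle before dispatching to a kernel. Below is a NumPy/SciPy reference of the same reduction, useful only for validating the GPU kernels on small inputs; apply_csr_along_axis is an illustrative name, not a function from the patches.

import numpy as np
import scipy.sparse as sp

def apply_csr_along_axis(matrix, array, axis):
    # Collapse leading/trailing dimensions so the target axis sits in the middle,
    # matching the (N1, N2, N3) layout assumed by the kernels in linalg_gpu.py.
    N1 = int(np.prod(array.shape[:axis], dtype=int))
    N2 = array.shape[axis]
    N3 = int(np.prod(array.shape[axis+1:], dtype=int))
    x = array.reshape((N1, N2, N3))
    # out[a, i, b] = sum_j matrix[i, j] * x[a, j, b]
    out = np.einsum("ij,ajb->aib", matrix.toarray(), x)
    return out.reshape(array.shape[:axis] + (matrix.shape[0],) + array.shape[axis+1:])

M = sp.random(4, 5, density=0.5, format="csr")
x = np.random.rand(2, 5, 3)
y = apply_csr_along_axis(M, x, axis=1)   # shape (2, 4, 3)

This reference densifies the matrix with toarray(), so it is only meant for correctness checks; the GPU path keeps the matrix in CSR form and accumulates row-by-row inside the kernel.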