@BlissChapman BlissChapman commented Dec 29, 2025

Motivation

For some inversion problems, a subset of the true density/velocity model is known. In these use cases, we only need to solve for the subset of the model parameters that are unknown. Storing the gradient only where it is needed can save a large amount of memory.

This pull request introduces a gradient mask parameter to the acoustic wave propagator that allows the CUDA kernel to compute gradients only for a subset of grid points using a sparse gradient mask index:

        gradient_mask: A boolean Tensor with the same spatial shape as the
            model, specifying which cells should have gradients computed.
            Optional. If not provided, gradients will be computed everywhere in
            the model. If the model is padded internally, an unpadded mask will
            be padded with False values. True values indicate cells where
            gradients should be computed, while False values indicate cells
            where gradients should be set to 0.
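
As a rough illustration of the sparse index idea (a minimal PyTorch sketch, not the CUDA implementation in this pull request), the boolean mask can be turned into a list of flat cell indices so that the gradient buffer only needs one entry per masked cell. The helper names `build_mask_index` and `scatter_to_full` are hypothetical.

```python
import torch


def build_mask_index(mask: torch.Tensor) -> torch.Tensor:
    """Return the flat (row-major) indices of the cells where mask is True."""
    return torch.nonzero(mask.flatten(), as_tuple=False).squeeze(1)


def scatter_to_full(
    compact_grad: torch.Tensor, index: torch.Tensor, shape: torch.Size
) -> torch.Tensor:
    """Expand a compact gradient (one value per masked cell) to the model shape."""
    full = torch.zeros(shape, dtype=compact_grad.dtype, device=compact_grad.device)
    full.view(-1)[index] = compact_grad  # cells outside the mask stay at 0
    return full


mask = torch.zeros(4, 5, dtype=torch.bool)
mask[1:3, 2:4] = True

index = build_mask_index(mask)
# Stand-in for a gradient accumulated only at the masked cells
compact = torch.arange(1.0, index.numel() + 1)
full = scatter_to_full(compact, index, mask.shape)
```

Here the compact buffer holds 4 values instead of 20, and the saving grows with the fraction of the model that is masked out.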

Test Plan

Added new unit tests:

  1. Test that gradient masking reduces memory requirements.
  2. Test that the gradient mask zeros gradients outside the mask.
  3. Test that providing no gradient mask computes the gradient everywhere.
  4. Test that gradient masking is not supported when using the Python backend.

Ran the test suite with:

    pip install . && PYTHONPATH= python -m pytest

In addition, a similar version of this code has been used internally at my work for full waveform inversion with a computational domain that cannot fit in H200 GPU memory without gradient masking.

Owner

ar4 commented Dec 30, 2025

Thank you, Bliss. This is great. It's an elegant implementation, and you did a great job of keeping the style consistent. A few comments:

1. Testing

I suggest making your tests a bit more sensitive to possible problems. For example, the following code:

from typing import Optional, Tuple

import torch

from deepwave import acoustic


def _compute_gradient(
    mask: Optional[torch.Tensor], nx: Tuple[int, int] = (5, 5), nt: int = 4
) -> Tuple[torch.Tensor, torch.Tensor]:
    v = torch.ones(nx, requires_grad=True)
    rho = torch.ones_like(v, requires_grad=True)
    source_amplitudes = torch.zeros(1, 1, nt)
    source_amplitudes[0, 0, 0] = 1
    source_locations = torch.tensor([[[0, 0]]], dtype=torch.long)
    receiver_locations = torch.tensor([[[nx[0] - 1, nx[1] - 1]]], dtype=torch.long)
    outputs = acoustic(
        v,
        rho,
        1.0,
        0.001,
        source_amplitudes_p=source_amplitudes,
        source_locations_p=source_locations,
        receiver_locations_p=receiver_locations,
        nt=nt,
        pml_width=2,
        gradient_mask=mask,
    )
    receivers = outputs[-3:]
    loss = sum(o.sum() for o in receivers)
    loss.backward()
    assert v.grad is not None
    assert rho.grad is not None
    return v.grad.detach(), rho.grad.detach()


nx = (10, 10)
mask = torch.zeros(nx, dtype=torch.bool)
mask.fill_(True)
# mask[1:-1, 1:-1] = True
grad_no_mask_v, grad_no_mask_rho = _compute_gradient(None, nx=nx, nt=500)
grad_masked_v, grad_masked_rho = _compute_gradient(mask, nx=nx, nt=500)
grad_no_mask_v[~mask] = 0
grad_no_mask_rho[~mask] = 0

reveals that the gradients at the edges of the model do not exactly match the unmasked gradient.
[Image: mask_all]
I think this can be resolved by changing the padding of the mask to be replication rather than zero-padding, so that the gradients will be calculated in the PML regions when the mask touches the edge of the model.
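
The difference between the two padding choices can be sketched in PyTorch as follows. This is only an illustration for a 2D mask, not the padding code used inside Deepwave; `pad_mask` is a hypothetical helper, and the pad width of 2 is arbitrary.

```python
import torch
import torch.nn.functional as F


def pad_mask(mask: torch.Tensor, pad: int, mode: str) -> torch.Tensor:
    """Pad a 2D boolean mask by `pad` cells on every side."""
    # F.pad's "replicate" mode expects batch and channel dimensions and a
    # floating-point dtype, so convert, pad, and convert back.
    padded = F.pad(mask.float()[None, None], (pad, pad, pad, pad), mode=mode)
    return padded[0, 0].bool()


mask = torch.zeros(4, 4, dtype=torch.bool)
mask[:2, :] = True  # the mask touches the top edge of the model

zero_padded = pad_mask(mask, 2, "constant")  # padded cells are all False
replicated = pad_mask(mask, 2, "replicate")  # edge values extend into the padding
```

With replication, the padded cells above the top edge inherit True from the edge row, so gradient contributions in that part of the PML region would still be computed. With zero-padding they are dropped.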

I also suggest testing the case when the mask does not cover the whole model. Using mask[1:-1, 1:-1] = True instead of mask.fill_(True) shows that the density gradients at the edges of the mask region do not match exactly.
[Image: mask_part]
I think this could be resolved by extending the gradient mask by one cell in each dimension (where possible), so that all of the buoyancy gradient components that contribute to the density gradient are calculated.

2. Vectorization

The compiler is having greater difficulty vectorizing the C code because of the conditionals in the inner loop. On my system, this causes the CPU runtime of the propagator to double, even when the gradient masking feature is not used. One possible way to overcome this would be to use the normal full gradients within the inner loop, but then change the storing and loading code so that the masking occurs there, outside the main spatial computation loops. That would, however, interfere with vectorization of the loading/storing code, so I'm not sure how much of an improvement it would be.
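
The two loop structures can be contrasted with a small Python stand-in for the C code. This is only a sketch of the idea (the function names are hypothetical, and it says nothing about the actual speed of either variant in C): the first version tests the mask for every cell inside the arithmetic loop, while the second computes the product everywhere without branching and applies the mask once, at the point where the result is stored.

```python
import torch


def grad_with_branch_in_loop(q, div_v, mask):
    """Mimics a conditional inside the inner spatial loop."""
    q, div_v, mask = q.flatten(), div_v.flatten(), mask.flatten()
    out = []
    for i in range(q.numel()):
        if mask[i]:  # per-cell branch, which hinders vectorization in C
            out.append(q[i] * div_v[i])
    return torch.stack(out)


def grad_with_mask_at_store(q, div_v, mask):
    """Branch-free arithmetic, with masking only when the result is stored."""
    full = q * div_v  # computed for every cell
    return full[mask]  # gather of the masked cells happens once, afterwards


torch.manual_seed(0)
q = torch.randn(6, 8)
div_v = torch.randn(6, 8)
mask = torch.zeros(6, 8, dtype=torch.bool)
mask[2:4, 3:6] = True

branch_result = grad_with_branch_in_loop(q, div_v, mask)
store_result = grad_with_mask_at_store(q, div_v, mask)
```

Both versions produce the same compact gradient. The trade-off described above is that the second one does arithmetic on cells whose results are then discarded.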

3. Consider improving existing alternatives (callbacks, compression)

Two alternatives to masked gradients in the C/CUDA code are:
(a) using callback functions to achieve the same effect
(b) storing the entire gradient but using the new storage features to ease the memory constraints

In my tests comparing the runtime (I only tested 2D), the runtime of (a) was about 1.5-2 times that of the C/CUDA masked gradients, and the runtime of (b) was only slightly higher on the CPU but was about four times that of the CUDA masked gradients code on the GPU. These approaches both have the advantage of not adding complexity to the existing code or any performance penalty to those who do not use masked gradients. It therefore might be worth considering whether improving the performance of either of these options is possible as an alternative to adding the masked gradients feature.

At the moment, control returns to the Python code to call the callbacks. Although this is done within an autograd.Function, avoiding the overhead of PyTorch's automatic gradient calculation, there is still some overhead. I suspect that this overhead is the main reason why the callback version is slower than the C/CUDA masked gradient implementation. I will do some tests over the next few days to try to measure this overhead. If this overhead is confirmed to be the source of the slower runtime, there are some possible solutions. One is to pass a function pointer to the callback function into the C/CUDA code (using ctypes), enabling the C/CUDA code to directly call the callback function without having to go back to the propagator's Python wrapper. The other is to reduce the overhead of calling the C/CUDA code from Python. Deepwave currently uses ctypes to do this (because of its portability), but there are other means of calling C/CUDA from Python that apparently have lower overhead.
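
For reference, the general ctypes mechanism for handing a function pointer to C code looks like the sketch below. It uses the C standard library's `qsort` as a stand-in for the propagator (an assumption made so that the example is self-contained on a POSIX system); it is not Deepwave code, and Deepwave's C/CUDA interface does not currently accept callbacks in this way.

```python
import ctypes
import ctypes.util

# Load the C standard library. On some systems find_library returns None,
# in which case CDLL(None) exposes the symbols already loaded in the process.
libc = ctypes.CDLL(ctypes.util.find_library("c"))

# C signature of the callback: int (*compar)(const int *, const int *)
CompareFunc = ctypes.CFUNCTYPE(
    ctypes.c_int, ctypes.POINTER(ctypes.c_int), ctypes.POINTER(ctypes.c_int)
)


def py_compare(a, b):
    """Called directly from C for each comparison that qsort makes."""
    return a[0] - b[0]


libc.qsort.restype = None
libc.qsort.argtypes = [
    ctypes.c_void_p,
    ctypes.c_size_t,
    ctypes.c_size_t,
    CompareFunc,
]

# Keep a reference to the wrapped callback for as long as C code may call it.
compare_ptr = CompareFunc(py_compare)

values = (ctypes.c_int * 5)(5, 1, 7, 33, 99)
libc.qsort(values, len(values), ctypes.sizeof(ctypes.c_int), compare_ptr)
sorted_values = list(values)
```

Note that a callback written in Python still runs in the interpreter each time the C code calls it, so how much overhead this saves would depend on where the time is actually being spent.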

I added double buffering and CUDA streams to Deepwave a few days ago to improve the performance of some of the new storage options (compression and storage to CPU memory). Further profiling might be useful to identify why it is still so slow. The current compression method is quite simplistic, so I suspect that improvements are possible. It does reduce the memory requirement by about three quarters, however, and storage in CPU memory further reduces the constraint imposed by limited GPU memory, so in some cases this might be a sufficient alternative to masked gradients.
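
To give a sense of how a simple scheme can reach a reduction of about three quarters, here is a sketch of 8-bit linear quantization of a float32 snapshot (4 bytes per value down to 1). This is an assumed, illustrative scheme with hypothetical function names, not necessarily the compression method that Deepwave uses.

```python
import torch


def compress(snapshot: torch.Tensor):
    """Quantize a float32 tensor to uint8 codes plus an offset and a scale."""
    lo = snapshot.min()
    # Guard against a zero range when the snapshot is constant
    scale = (snapshot.max() - lo).clamp_min(1e-30) / 255.0
    codes = torch.round((snapshot - lo) / scale).clamp(0, 255).to(torch.uint8)
    return codes, lo, scale


def decompress(codes: torch.Tensor, lo: torch.Tensor, scale: torch.Tensor):
    """Recover an approximation of the original float32 tensor."""
    return codes.to(torch.float32) * scale + lo


torch.manual_seed(0)
snapshot = torch.randn(64, 64)
codes, lo, scale = compress(snapshot)
restored = decompress(codes, lo, scale)

bytes_before = snapshot.numel() * snapshot.element_size()
bytes_after = codes.numel() * codes.element_size()
max_error = (snapshot - restored).abs().max()
```

The reconstruction error is bounded by roughly half a quantization step, which is the price paid for the smaller storage.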

4. Consider different method (local solver)

Another option to consider, although it is a larger undertaking, is to investigate something like this local solver approach, which might also reduce runtime.

Summary

Thank you again for all of your work on this. To summarise, the main issues to be resolved at the moment are the vectorization of the C code and the differences between the masked gradient and the regular gradient in the masked areas. I will investigate whether the runtime of the two existing alternatives (callbacks and compressed storage) can be improved, and I suggest you consider whether a local solver-type approach might be a good fit for your application.

Owner

ar4 commented Jan 2, 2026

Looking again at the performance measurements, I realised that my setup was somewhat unrealistic: a large grid spacing and a low maximum velocity meant that there were only two internal steps (to obey the CFL condition) per gradient calculation. As a result, the overhead of returning to Python to call the callback was large compared to the time spent propagating. There was also less computation time to overlap with the transfer to CPU when using the compression and CPU memory storage approach, so the transfer time was not hidden.
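
The number of internal steps can be estimated with a sketch like the one below. It assumes a CFL limit of the form 0.6 / (sqrt(sum of 1/dx^2) * v_max); the factor 0.6 and the helper name `internal_steps` are assumptions made for illustration, so the exact counts Deepwave uses may differ.

```python
import math


def internal_steps(dt, grid_spacing, max_vel, cfl_factor=0.6):
    """Estimate how many internal time steps are needed per user time step."""
    # Largest stable internal time step under the assumed CFL limit
    max_dt = cfl_factor / (
        math.sqrt(sum(1.0 / dx**2 for dx in grid_spacing)) * max_vel
    )
    return math.ceil(abs(dt) / max_dt)


# Coarse grid and low velocity (illustrative values): few internal steps
coarse = internal_steps(0.004, [10.0, 10.0], 1500.0)

# Finer grid and higher velocity in 2D and 3D (illustrative values)
fine_2d = internal_steps(0.004, [4.0, 4.0], 3500.0)
fine_3d = internal_steps(0.004, [4.0, 4.0, 4.0], 3500.0)
```

Under these assumptions, the coarse case needs 2 internal steps per user time step, while the finer cases need 9 (2D) and 11 (3D), so any fixed per-step cost such as a callback is spread over several times as much propagation work.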

Using a more realistic setup, the callback approach has approximately the same runtime in 3D as the C/CUDA masked gradient code in my test.

The measurement code is below, including the callback that I used to calculate the masked gradient for the acoustic propagator. The callback code fixes the error in the density gradient at the edges of the masked region. However, for simplicity, I didn't change the way the PML region is handled, so the gradient won't be right on cells at the edge of the simulation domain.

import deepwave
import matplotlib.pyplot as plt
import torch
from deepwave.staggered_grid import diff1, diff1h
import time
import math


class MaskedAcousticGradientCalculator:
    """Handles gradient accumulation within a masked region for acoustic propagator."""

    def __init__(self, mask, grid_spacing, dt, accuracy):
        self.ndim = mask.ndim
        self.grid_spacing = grid_spacing  # [dim0, dim1, ...]
        self.dt = dt
        self.accuracy = accuracy
        self.snapshots_div_v = []
        self.snapshots_grad_p = []  # List of lists [grad_p_dim0, grad_p_dim1, ...]

        # Expand mask by one cell in all directions (dilation)
        if self.ndim == 2:
            max_pool = torch.nn.functional.max_pool2d
        elif self.ndim == 3:
            max_pool = torch.nn.functional.max_pool3d
        else:
            raise ValueError(f"Unsupported number of dimensions: {self.ndim}")
        # max_pool requires batch and channel dimensions
        self.mask = (
            max_pool(
                mask.float().unsqueeze(0).unsqueeze(0),
                kernel_size=3,
                stride=1,
                padding=1,
            )
            .squeeze(0)
            .squeeze(0)
            .bool()
        )

        self.dims = ["z", "y", "x"][-self.ndim :]

        # Accumulators for gradient contributions in the masked region
        # Convert to a Python int so it can be used as a tensor size
        num_masked = int(self.mask.sum())
        self.grad_K_accum = torch.zeros(num_masked, device=self.mask.device)
        self.grad_B_accum = [
            torch.zeros(num_masked, device=self.mask.device) for _ in range(self.ndim)
        ]

    def forward_callback(self, state: deepwave.common.CallbackState):
        """Callback to store masked wavefield derivatives during the forward pass."""
        p = state.get_wavefield("pressure_0", view="inner").detach()
        vs = [
            state.get_wavefield(f"v{dim}_0", view="inner").detach() for dim in self.dims
        ]

        ndim = p.ndim - 1
        inv_spacing = [1.0 / gs for gs in self.grid_spacing]

        # Calculate divergence of velocity
        div_v = 0
        for i, v in enumerate(vs):
            div_v += diff1(v, i, self.accuracy, inv_spacing[i], ndim)

        # Calculate gradient of pressure
        grad_p = []
        for i in range(self.ndim):
            grad_p.append(diff1h(p, i, self.accuracy, inv_spacing[i], ndim))

        # Store only the values at the masked indices
        self.snapshots_div_v.append(div_v[:, self.mask])
        self.snapshots_grad_p.append([gp[:, self.mask] for gp in grad_p])

    def backward_callback(self, state: deepwave.common.CallbackState):
        """Callback to compute gradients during the backward pass."""
        if not self.snapshots_div_v:
            raise AssertionError(
                "No forward snapshot was found during backpropagation."
            )

        fwd_div_v = self.snapshots_div_v.pop()
        fwd_grad_p = self.snapshots_grad_p.pop()

        # Get the adjoint wavefields
        q = state.get_wavefield("pressure_0", view="inner").detach()
        v_stars = [
            state.get_wavefield(f"v{dim}_0", view="inner").detach() for dim in self.dims
        ]

        # Calculate gradient contributions
        # grad_K = sum -dt * q * div_v
        # grad_B = sum -dt * v * grad_p
        self.grad_K_accum += -(q[:, self.mask] * fwd_div_v).sum(dim=0)

        for i in range(self.ndim):
            self.grad_B_accum[i] += -(v_stars[i][:, self.mask] * fwd_grad_p[i]).sum(
                dim=0
            )

        # At the end of the backward pass (step 0), inject the calculated gradients
        if state.step == 0:
            state.get_gradient("K", view="inner")[:, self.mask] = (
                self.dt * self.grad_K_accum
            )
            for i, dim in enumerate(self.dims):
                state.get_gradient(f"B{dim}", view="inner")[:, self.mask] = (
                    self.dt * self.grad_B_accum[i]
                )


def run(device=None, mode="standard", shape=(100, 200, 200)):
    if device is None:
        device = torch.device("cpu")
    dtype = torch.float32
    grid_spacing_val = 4.0
    grid_spacing = [grid_spacing_val] * len(shape)
    dt = 0.004
    freq = 25
    peak_time = 1.5 / freq
    accuracy = 4
    v_val = 3500.0
    # Calculate nt based on max dimension size
    max_dim = math.sqrt(sum(s**2 for s in shape))
    nt = int((2 * peak_time + max_dim * grid_spacing_val / v_val) / dt)
    pml_width = [10] * (2 * len(shape))

    # Create models
    v = torch.full(shape, v_val, device=device, dtype=dtype)
    rho = torch.full(shape, 2000.0, device=device, dtype=dtype)

    # Create mask
    mask = torch.zeros(shape, dtype=torch.bool, device=device)
    # Define a simple mask region in the middle
    slices = tuple(slice(4, -4) for _ in range(len(shape)))
    mask[slices] = True
    slices_hole = tuple(slice(6, -6) for _ in range(len(shape)))
    mask[slices_hole] = False

    v.requires_grad = True
    rho.requires_grad = True

    # Setup source and receivers
    source_amplitudes = (
        deepwave.wavelets.ricker(freq, nt, dt, peak_time).reshape(1, 1, -1).to(device)
    )

    # Source at origin, Receiver at far corner
    source_locs = [0] * len(shape)
    receiver_locs = [s - 1 for s in shape]

    source_locations = torch.tensor([[source_locs]], dtype=torch.long, device=device)
    receiver_locations = torch.tensor(
        [[receiver_locs]], dtype=torch.long, device=device
    )

    def gradient_mask_hook(x):
        return x * mask

    if mode == "standard":
        kwargs = {}
        v.register_hook(gradient_mask_hook)
        rho.register_hook(gradient_mask_hook)
    elif mode == "compiled":
        kwargs = {"gradient_mask": mask}
    elif mode == "callbacks":
        mask_calc = MaskedAcousticGradientCalculator(
            mask,
            grid_spacing,
            dt,
            accuracy,
        )
        kwargs = {
            "forward_callback": mask_calc.forward_callback,
            "backward_callback": mask_calc.backward_callback,
            "storage_mode": "none",
        }
        v.register_hook(gradient_mask_hook)
        rho.register_hook(gradient_mask_hook)
    else:
        raise RuntimeError("Unknown mode")

    start_time = time.time()

    # Forward pass
    out = deepwave.acoustic(
        v,
        rho,
        grid_spacing,
        dt,
        source_amplitudes_p=source_amplitudes,
        source_locations_p=source_locations,
        receiver_locations_p=receiver_locations,
        accuracy=accuracy,
        pml_width=pml_width,
        pml_freq=freq,
        **kwargs,
    )

    receiver_amplitudes = out[-len(shape) - 1]

    # Backward pass
    loss = (receiver_amplitudes**2).sum()
    loss.backward()

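    # CUDA work may be asynchronous, so wait for it to finish before stopping the clock
    if torch.device(device).type == "cuda":
        torch.cuda.synchronize()
    end_time = time.time()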
    duration = end_time - start_time

    return duration, v.grad.detach().cpu(), rho.grad.detach().cpu()


def run_validation(device):
    """Validate that the different methods produce similar gradients."""
    shapes = [(20, 20), (20, 20, 20)]

    for shape in shapes:
        ndim = len(shape)
        print(f"\nValidating {ndim}D model with shape {shape}...")

        _, v_standard, rho_standard = run(device=device, mode="standard", shape=shape)
        _, v_compiled, rho_compiled = run(device=device, mode="compiled", shape=shape)
        _, v_callbacks, rho_callbacks = run(
            device=device, mode="callbacks", shape=shape
        )

        # Plotting
        num_cols = 5
        fig, ax = plt.subplots(2, num_cols, figsize=(15, 6))

        # Select a slice for display
        if ndim == 3:
            s = (slice(None), slice(None), shape[-1] // 2)
        else:
            s = (slice(None), slice(None))

        # Row 0: Velocity
        titles = [
            "V Standard",
            "V Compiled",
            "V Comp Err",
            "V Callbacks",
            "V Callbacks Err",
        ]
        data = [
            v_standard,
            v_compiled,
            v_standard - v_compiled,
            v_callbacks,
            v_standard - v_callbacks,
        ]

        for i in range(num_cols):
            im = ax[0, i].imshow(data[i][s])
            plt.colorbar(im, ax=ax[0, i])
            ax[0, i].set_title(titles[i])

        # Row 1: Density
        titles = [
            "Rho Standard",
            "Rho Compiled",
            "Rho Comp Err",
            "Rho Callbacks",
            "Rho Callbacks Err",
        ]
        data = [
            rho_standard,
            rho_compiled,
            rho_standard - rho_compiled,
            rho_callbacks,
            rho_standard - rho_callbacks,
        ]

        for i in range(num_cols):
            im = ax[1, i].imshow(data[i][s])
            plt.colorbar(im, ax=ax[1, i])
            ax[1, i].set_title(titles[i])

        plt.tight_layout()
        filename = f"validation_{ndim}d.png"
        plt.savefig(filename)
        print(f"Validation plot saved to {filename}")

        print(f"{ndim}D V Compiled Max Error:", (v_standard - v_compiled).abs().max())
        print(
            f"{ndim}D Rho Compiled Max Error:",
            (rho_standard - rho_compiled).abs().max(),
        )
        print(f"{ndim}D V Callbacks Max Error:", (v_standard - v_callbacks).abs().max())
        print(
            f"{ndim}D Rho Callbacks Max Error:",
            (rho_standard - rho_callbacks).abs().max(),
        )


def run_benchmarking(device):
    """Benchmarks the performance of different gradient calculation methods."""
    print("\n--- Starting Benchmarking Run ---")

    # Define shapes for benchmarking
    shapes = [
        (1000, 1000),  # 2D
        (100, 200, 200),  # 3D
    ]

    for shape in shapes:
        ndim = len(shape)
        print(f"\nBenchmarking {ndim}D model with shape {shape}:")

        def benchmark(device, mode, shape):
            runtimes = []
            for _ in range(3):
                runtime, _, _ = run(device=device, mode=mode, shape=shape)
                runtimes.append(runtime)
            return min(runtimes)

        print("Compiled C/CUDA approach")
        t_compiled = benchmark(device, mode="compiled", shape=shape)

        print("Callbacks approach")
        t_callbacks = benchmark(device, mode="callbacks", shape=shape)

        print(f"Summary for {ndim}D:")
        print(f"Compiled approach: {t_compiled:.4f} s")
        print(f"Callbacks approach:  {t_callbacks:.4f} s")


def main():
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    run_validation(device)
    run_benchmarking(device)


if __name__ == "__main__":
    main()

Does it also have similar performance to the compiled version when running your application on your system?
