Support gradient masking in acoustic wave propagation #114
Thank you, Bliss. This is great. It's an elegant implementation, and you did a great job of keeping the style consistent. A few comments:

1. Testing
I suggest making your tests a bit more sensitive to possible problems. For example, the following code:

from typing import Optional, Tuple

import torch
from deepwave import acoustic


def _compute_gradient(
    mask: Optional[torch.Tensor], nx: Tuple[int, int] = (5, 5), nt: int = 4
) -> Tuple[torch.Tensor, torch.Tensor]:
    v = torch.ones(nx, requires_grad=True)
    rho = torch.ones_like(v, requires_grad=True)
    # Unit impulse at the first time sample
    source_amplitudes = torch.zeros(1, 1, nt)
    source_amplitudes[0, 0, 0] = 1
    # Source in one corner, receiver in the opposite corner
    source_locations = torch.tensor([[[0, 0]]], dtype=torch.long)
    receiver_locations = torch.tensor([[[nx[0] - 1, nx[1] - 1]]], dtype=torch.long)
    outputs = acoustic(
        v,
        rho,
        1.0,
        0.001,
        source_amplitudes_p=source_amplitudes,
        source_locations_p=source_locations,
        receiver_locations_p=receiver_locations,
        nt=nt,
        pml_width=2,
        gradient_mask=mask,
    )
    receivers = outputs[-3:]
    loss = sum(o.sum() for o in receivers)
    loss.backward()
    assert v.grad is not None
    assert rho.grad is not None
    return v.grad.detach(), rho.grad.detach()


nx = (10, 10)
mask = torch.zeros(nx, dtype=torch.bool)
mask.fill_(True)
# mask[1:-1, 1:-1] = True
grad_no_mask_v, grad_no_mask_rho = _compute_gradient(None, nx=nx, nt=500)
grad_masked_v, grad_masked_rho = _compute_gradient(mask, nx=nx, nt=500)
# Zero the unmasked gradient outside the mask so the two can be compared
grad_no_mask_v[~mask] = 0
grad_no_mask_rho[~mask] = 0

reveals that the gradients at the edges of the model do not exactly match the unmasked gradient. I also suggest testing the case when the mask does not cover the whole model. Using

2. Vectorization
The compiler is having greater difficulty vectorizing the C code because of the conditionals in the inner loop. On my system, this causes the CPU runtime of the propagator to double, even when the gradient masking feature is not used. One possible way to overcome this would be to use the normal full gradients within the inner loop, but then change the storing and loading code so that the masking occurs there, outside the main spatial computation loops. That would, however, interfere with vectorization of the loading/storing code, so I'm not sure how much of an improvement it would be.

3. Consider improving existing alternatives (callbacks, compression)
Two alternatives to masked gradients in the C/CUDA code are: (a) callbacks and (b) compressed storage.
In my tests comparing the runtime (I only tested 2D), the runtime of (a) was about 1.5-2 times that of the C/CUDA masked gradients, and the runtime of (b) was only slightly higher on the CPU but was about four times that of the CUDA masked gradients code on the GPU. These approaches both have the advantage of not adding complexity to the existing code or any performance penalty to those who do not use masked gradients. It might therefore be worth considering whether the performance of either of these options can be improved, as an alternative to adding the masked gradients feature.
At the moment, control returns to the Python code to call the callbacks. Although this is done within an autograd.Function, avoiding the overhead of PyTorch's automatic gradient calculation, there is still some overhead. I suspect that this overhead is the main reason why the callback version is slower than the C/CUDA masked gradient implementation. I will do some tests over the next few days to try to measure this overhead.
If this overhead is confirmed to be the source of the slower runtime, there are some possible solutions. One is to pass a function pointer to the callback function into the C/CUDA code (using ctypes), enabling the C/CUDA code to directly call the callback function without having to go back to the propagator's Python wrapper. The other is to reduce the overhead of calling the C/CUDA code from Python. Deepwave currently uses ctypes to do this (because of its portability), but there are other means of calling C/CUDA from Python that apparently have lower overhead.
I added double buffering and CUDA streams to Deepwave a few days ago to improve the performance of some of the new storage options (compression and storage to CPU memory). Further profiling might be useful to identify why it is still so slow. The current compression method is quite simplistic, so I suspect that improvements are possible. It does reduce the memory requirement by about three quarters, however, and storage in CPU memory further reduces the constraint imposed by limited GPU memory, so in some cases this might be a sufficient alternative to masked gradients.

4. Consider different method (local solver)
Another option to consider, although it is a larger undertaking, is to investigate something like this local solver approach, which might also reduce runtime.

Summary
Thank you again for all of your work on this. To summarise, the main issues to be resolved at the moment are the vectorization of the C code and the differences between the masked gradient and the regular gradient in the masked areas. I will investigate whether the runtime of the two existing alternatives (callbacks and compressed storage) can be improved, and I suggest you consider whether a local solver-type approach might be a good fit for your application.
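As a rough illustration of the function-pointer option mentioned in section 3, the sketch below wraps a Python function in a C-callable pointer with `ctypes.CFUNCTYPE`. The callback signature (`step`, a float buffer, and its length) and the names `StepCallback` and `scale_field` are hypothetical and are not Deepwave's API. Here the pointer is invoked from Python only to show that it works; in the proposed design the integer address would be handed to the compiled code instead.

```python
import ctypes

# Hypothetical C-side signature: void callback(int step, float *field, size_t n)
StepCallback = ctypes.CFUNCTYPE(
    None, ctypes.c_int, ctypes.POINTER(ctypes.c_float), ctypes.c_size_t
)

seen_steps = []


def scale_field(step, field, n):
    """Example callback: record the step and halve the buffer in place."""
    seen_steps.append(step)
    for i in range(n):
        field[i] *= 0.5


# Keep a reference to the wrapper for as long as C code might call it,
# otherwise the function pointer becomes invalid.
c_callback = StepCallback(scale_field)

# Integer address that could be passed to compiled code as a function pointer.
address = ctypes.cast(c_callback, ctypes.c_void_p).value

# Call through the C function pointer, as compiled code would.
buf = (ctypes.c_float * 4)(1.0, 2.0, 3.0, 4.0)
c_callback(3, buf, len(buf))
result = list(buf)
print(result, seen_steps)
```

Note that the callback body is still Python, so the interpreter is still entered on every call; what this would avoid is returning through the propagator's Python wrapper between steps.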
Looking again at the performance measurements, I realised that I used a setup that was somewhat unrealistic: a large grid spacing and a low maximum velocity meant that there were only two internal steps (to obey the CFL condition) per gradient calculation. As a result, the overhead of returning to Python to call the callback was large compared to the time spent propagating. It also meant that there was less computation time to overlap with the transfer to the CPU when using the compression and CPU memory storage approach, so the transfer time was not hidden.
Using a more realistic setup, the callback approach has approximately the same runtime in 3D as the C/CUDA masked gradient code in my test.
The measurement code is below, including the callback that I used to calculate the masked gradient for the acoustic propagator. The callback code fixes the error in the density gradient at the edges of the masked region. However, for simplicity, I didn't change the way the PML region is handled, so the gradient won't be right on cells at the edge of the simulation domain.

import deepwave
import matplotlib.pyplot as plt
import torch
from deepwave.staggered_grid import diff1, diff1h
import time
import math


class MaskedAcousticGradientCalculator:
    """Handles gradient accumulation within a masked region for acoustic propagator."""

    def __init__(self, mask, grid_spacing, dt, accuracy):
        self.ndim = mask.ndim
        self.grid_spacing = grid_spacing  # [dim0, dim1, ...]
        self.dt = dt
        self.accuracy = accuracy
        self.snapshots_div_v = []
        self.snapshots_grad_p = []  # List of lists [grad_p_dim0, grad_p_dim1, ...]
        # Expand mask by one cell in all directions (dilation)
        if self.ndim == 2:
            max_pool = torch.nn.functional.max_pool2d
        elif self.ndim == 3:
            max_pool = torch.nn.functional.max_pool3d
        else:
            raise ValueError(f"Unsupported number of dimensions: {self.ndim}")
        # max_pool requires batch and channel dimensions
        self.mask = (
            max_pool(
                mask.float().unsqueeze(0).unsqueeze(0),
                kernel_size=3,
                stride=1,
                padding=1,
            )
            .squeeze(0)
            .squeeze(0)
            .bool()
        )
        self.dims = ["z", "y", "x"][-self.ndim :]
        # Accumulators for gradient contributions in the masked region
        num_masked = self.mask.sum()
        self.grad_K_accum = torch.zeros(num_masked, device=self.mask.device)
        self.grad_B_accum = [
            torch.zeros(num_masked, device=self.mask.device) for _ in range(self.ndim)
        ]

    def forward_callback(self, state: deepwave.common.CallbackState):
        """Callback to store masked wavefield derivatives during the forward pass."""
        p = state.get_wavefield("pressure_0", view="inner").detach()
        vs = [
            state.get_wavefield(f"v{dim}_0", view="inner").detach() for dim in self.dims
        ]
        ndim = p.ndim - 1
        inv_spacing = [1.0 / gs for gs in self.grid_spacing]
        # Calculate divergence of velocity
        div_v = 0
        for i, v in enumerate(vs):
            div_v += diff1(v, i, self.accuracy, inv_spacing[i], ndim)
        # Calculate gradient of pressure
        grad_p = []
        for i in range(self.ndim):
            grad_p.append(diff1h(p, i, self.accuracy, inv_spacing[i], ndim))
        # Store only the values at the masked indices
        self.snapshots_div_v.append(div_v[:, self.mask])
        self.snapshots_grad_p.append([gp[:, self.mask] for gp in grad_p])

    def backward_callback(self, state: deepwave.common.CallbackState):
        """Callback to compute gradients during the backward pass."""
        if not self.snapshots_div_v:
            raise AssertionError(
                "No forward snapshot was found during backpropagation."
            )
        fwd_div_v = self.snapshots_div_v.pop()
        fwd_grad_p = self.snapshots_grad_p.pop()
        # Get the adjoint wavefields
        q = state.get_wavefield("pressure_0", view="inner").detach()
        v_stars = [
            state.get_wavefield(f"v{dim}_0", view="inner").detach() for dim in self.dims
        ]
        # Calculate gradient contributions
        # grad_K = sum -dt * q * div_v
        # grad_B = sum -dt * v * grad_p
        self.grad_K_accum += -(q[:, self.mask] * fwd_div_v).sum(dim=0)
        for i in range(self.ndim):
            self.grad_B_accum[i] += -(v_stars[i][:, self.mask] * fwd_grad_p[i]).sum(
                dim=0
            )
        # At the end of the backward pass (step 0), inject the calculated gradients
        if state.step == 0:
            state.get_gradient("K", view="inner")[:, self.mask] = (
                self.dt * self.grad_K_accum
            )
            for i, dim in enumerate(self.dims):
                state.get_gradient(f"B{dim}", view="inner")[:, self.mask] = (
                    self.dt * self.grad_B_accum[i]
                )


def run(device=None, mode="standard", shape=(100, 200, 200)):
    if device is None:
        device = torch.device("cpu")
    dtype = torch.float32
    grid_spacing_val = 4.0
    grid_spacing = [grid_spacing_val] * len(shape)
    dt = 0.004
    freq = 25
    peak_time = 1.5 / freq
    accuracy = 4
    v_val = 3500.0
    # Calculate nt based on max dimension size
    max_dim = math.sqrt(sum(s**2 for s in shape))
    nt = int((2 * peak_time + max_dim * grid_spacing_val / v_val) / dt)
    pml_width = [10] * (2 * len(shape))
    # Create models
    v = torch.full(shape, v_val, device=device, dtype=dtype)
    rho = torch.full(shape, 2000.0, device=device, dtype=dtype)
    # Create mask
    mask = torch.zeros(shape, dtype=torch.bool, device=device)
    # Define a simple mask region in the middle
    slices = tuple(slice(4, -4) for _ in range(len(shape)))
    mask[slices] = True
    slices_hole = tuple(slice(6, -6) for _ in range(len(shape)))
    mask[slices_hole] = False
    v.requires_grad = True
    rho.requires_grad = True
    # Setup source and receivers
    source_amplitudes = (
        deepwave.wavelets.ricker(freq, nt, dt, peak_time).reshape(1, 1, -1).to(device)
    )
    # Source at origin, Receiver at far corner
    source_locs = [0] * len(shape)
    receiver_locs = [s - 1 for s in shape]
    source_locations = torch.tensor([[source_locs]], dtype=torch.long, device=device)
    receiver_locations = torch.tensor(
        [[receiver_locs]], dtype=torch.long, device=device
    )

    def gradient_mask_hook(x):
        return x * mask

    if mode == "standard":
        kwargs = {}
        v.register_hook(gradient_mask_hook)
        rho.register_hook(gradient_mask_hook)
    elif mode == "compiled":
        kwargs = {"gradient_mask": mask}
    elif mode == "callbacks":
        mask_calc = MaskedAcousticGradientCalculator(
            mask,
            grid_spacing,
            dt,
            accuracy,
        )
        kwargs = {
            "forward_callback": mask_calc.forward_callback,
            "backward_callback": mask_calc.backward_callback,
            "storage_mode": "none",
        }
        v.register_hook(gradient_mask_hook)
        rho.register_hook(gradient_mask_hook)
    else:
        raise RuntimeError("Unknown mode")
    start_time = time.time()
    # Forward pass
    out = deepwave.acoustic(
        v,
        rho,
        grid_spacing,
        dt,
        source_amplitudes_p=source_amplitudes,
        source_locations_p=source_locations,
        receiver_locations_p=receiver_locations,
        accuracy=accuracy,
        pml_width=pml_width,
        pml_freq=freq,
        **kwargs,
    )
    receiver_amplitudes = out[-len(shape) - 1]
    # Backward pass
    loss = (receiver_amplitudes**2).sum()
    loss.backward()
    end_time = time.time()
    duration = end_time - start_time
    return duration, v.grad.detach().cpu(), rho.grad.detach().cpu()


def run_validation(device):
    """Validate that the different methods produce similar gradients."""
    shapes = [(20, 20), (20, 20, 20)]
    for shape in shapes:
        ndim = len(shape)
        print(f"\nValidating {ndim}D model with shape {shape}...")
        _, v_standard, rho_standard = run(device=device, mode="standard", shape=shape)
        _, v_compiled, rho_compiled = run(device=device, mode="compiled", shape=shape)
        _, v_callbacks, rho_callbacks = run(
            device=device, mode="callbacks", shape=shape
        )
        # Plotting
        num_cols = 5
        fig, ax = plt.subplots(2, num_cols, figsize=(15, 6))
        # Select a slice for display
        if ndim == 3:
            s = (slice(None), slice(None), shape[-1] // 2)
        else:
            s = (slice(None), slice(None))
        # Row 0: Velocity
        titles = [
            "V Standard",
            "V Compiled",
            "V Comp Err",
            "V Callbacks",
            "V Callbacks Err",
        ]
        data = [
            v_standard,
            v_compiled,
            v_standard - v_compiled,
            v_callbacks,
            v_standard - v_callbacks,
        ]
        for i in range(num_cols):
            im = ax[0, i].imshow(data[i][s])
            plt.colorbar(im, ax=ax[0, i])
            ax[0, i].set_title(titles[i])
        # Row 1: Density
        titles = [
            "Rho Standard",
            "Rho Compiled",
            "Rho Comp Err",
            "Rho Callbacks",
            "Rho Callbacks Err",
        ]
        data = [
            rho_standard,
            rho_compiled,
            rho_standard - rho_compiled,
            rho_callbacks,
            rho_standard - rho_callbacks,
        ]
        for i in range(num_cols):
            im = ax[1, i].imshow(data[i][s])
            plt.colorbar(im, ax=ax[1, i])
            ax[1, i].set_title(titles[i])
        plt.tight_layout()
        filename = f"validation_{ndim}d.png"
        plt.savefig(filename)
        print(f"Validation plot saved to {filename}")
        print(f"{ndim}D V Compiled Max Error:", (v_standard - v_compiled).abs().max())
        print(
            f"{ndim}D Rho Compiled Max Error:",
            (rho_standard - rho_compiled).abs().max(),
        )
        print(f"{ndim}D V Callbacks Max Error:", (v_standard - v_callbacks).abs().max())
        print(
            f"{ndim}D Rho Callbacks Max Error:",
            (rho_standard - rho_callbacks).abs().max(),
        )


def run_benchmarking(device):
    """Benchmarks the performance of different gradient calculation methods."""
    print("\n--- Starting Benchmarking Run ---")
    # Define shapes for benchmarking
    shapes = [
        (1000, 1000),  # 2D
        (100, 200, 200),  # 3D
    ]
    for shape in shapes:
        ndim = len(shape)
        print(f"\nBenchmarking {ndim}D model with shape {shape}:")

        def benchmark(device, mode, shape):
            # Report the fastest of three runs
            runtimes = []
            for _ in range(3):
                runtime, _, _ = run(device=device, mode=mode, shape=shape)
                runtimes.append(runtime)
            return min(runtimes)

        print("Compiled C/CUDA approach")
        t_compiled = benchmark(device, mode="compiled", shape=shape)
        print("Callbacks approach")
        t_callbacks = benchmark(device, mode="callbacks", shape=shape)
        print(f"Summary for {ndim}D:")
        print(f"Compiled approach: {t_compiled:.4f} s")
        print(f"Callbacks approach: {t_callbacks:.4f} s")


def main():
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    run_validation(device)
    run_benchmarking(device)


if __name__ == "__main__":
    main()

Does it also have similar performance to the compiled version when running your application on your system?
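To make the point about internal steps concrete, here is a minimal sketch of how the number of internal steps per user time step could be estimated from a CFL-type limit. The helper `cfl_step_ratio` and the safety factor of 0.6 are assumptions for illustration and are not taken from this thread; Deepwave's own calculation may differ. The first two calls use the grid spacing, velocity, and time step from the measurement script above. The third uses hypothetical values chosen only to give a ratio of 2; they are not the setup from the earlier measurements.

```python
import math


def cfl_step_ratio(grid_spacing, dt, max_vel, courant_limit=0.6):
    """Return (step_ratio, inner_dt) for a Courant-type stability limit.

    courant_limit=0.6 is an assumed safety factor, not a documented value.
    """
    max_dt = (
        courant_limit
        / math.sqrt(sum(1.0 / dx**2 for dx in grid_spacing))
        / max_vel
    )
    step_ratio = math.ceil(abs(dt) / max_dt)
    return step_ratio, dt / step_ratio


# Values from the measurement script: 4 m spacing, 3500 m/s, dt = 0.004 s
ratio_3d, inner_dt_3d = cfl_step_ratio([4.0] * 3, 0.004, 3500.0)
ratio_2d, inner_dt_2d = cfl_step_ratio([4.0] * 2, 0.004, 3500.0)

# Hypothetical coarse, slow model (illustrative values only)
ratio_coarse, _ = cfl_step_ratio([10.0] * 2, 0.004, 2000.0)

print(ratio_3d, ratio_2d, ratio_coarse)  # 11 9 2
```

Under these assumptions, each gradient calculation in the benchmark setup is separated by roughly ten internal propagation steps, so a fixed per-callback cost is spread over much more computation than when there are only two.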


Motivation
For some inversion problems, a subset of the true density/velocity model is known. In these use cases, we only need to solve for the subset of the model parameters that are unknown. Large memory savings are possible from only storing the gradient where necessary.
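As a rough illustration of the size of the possible saving, the sketch below counts the cells in the shell-shaped mask used by the benchmark script in this thread (`slice(4, -4)` with `slice(6, -6)` removed) and compares that with the full grid. The helper `shell_mask_count` is hypothetical, and the assumption that one float32 value is stored per cell is made only to give a byte count; the actual quantities stored per cell depend on the implementation.

```python
from math import prod


def shell_mask_count(shape, outer=4, inner=6):
    """Count cells inside slice(outer, -outer) but outside slice(inner, -inner).

    Assumes every dimension has at least 2 * inner cells.
    """
    outer_cells = prod(n - 2 * outer for n in shape)
    inner_cells = prod(n - 2 * inner for n in shape)
    return outer_cells - inner_cells


shape = (100, 200, 200)  # 3D benchmark shape
masked = shell_mask_count(shape)
total = prod(shape)
bytes_per_value = 4  # float32, assumed one value per cell

print(f"masked cells: {masked} of {total} ({masked / total:.1%})")
print(f"per stored field: {masked * bytes_per_value} bytes "
      f"instead of {total * bytes_per_value} bytes")
```

For this particular mask, about 7% of the cells are inside the mask, so any per-cell quantity that only needs to be kept in the masked region shrinks by roughly the same factor.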
This pull request introduces a gradient mask parameter to the acoustic wave propagator that allows the CUDA kernel to compute gradients only for a subset of grid points using a sparse gradient mask index:
Test Plan
Added new unit tests:
In addition, a similar version of this code has been used internally at my work for full waveform inversion with a computational domain that cannot fit in H200 GPU memory without gradient masking.