diff --git a/src/deepwave/acoustic.c b/src/deepwave/acoustic.c
index 64bc557..015286d 100644
--- a/src/deepwave/acoustic.c
+++ b/src/deepwave/acoustic.c
@@ -108,6 +108,14 @@
 #define SET_I int64_t const i = x
 #endif
 
+static inline int64_t get_grad_idx(
+    int64_t const *__restrict const gradient_mask_indices, int64_t i) {
+  if (gradient_mask_indices == NULL) {
+    return i;
+  }
+  return gradient_mask_indices[i];
+}
+
 #ifdef _WIN32
 __declspec(dllexport)
 #endif
@@ -173,6 +181,8 @@ __declspec(dllexport)
         void *__restrict const bx_grad_store_2,
         void *__restrict const bx_grad_store_3,
         char const *__restrict const *__restrict const bx_grad_filenames_ptr,
+        int64_t const *__restrict const gradient_mask_indices,
+        int64_t const n_grad_grid_points,
         DW_DTYPE *__restrict const r_p,
 #if DW_NDIM >= 3
         DW_DTYPE *__restrict const r_vz,
@@ -298,6 +308,7 @@ __declspec(dllexport)
 #endif
 
       int64_t const si = shot * n_grid_points;
+      int64_t const grad_si = shot * n_grad_grid_points;
       int64_t const src_i_p = shot * n_sources_p_per_shot;
 #if DW_NDIM >= 3
       int64_t const src_i_vz = shot * n_sources_vz_per_shot;
@@ -372,9 +383,9 @@ __declspec(dllexport)
     for (t = start_t; t < start_t + nt; ++t) {
 #define SETUP_STORE_SAVE(name, grad_cond)                                    \
   DW_DTYPE *__restrict const name##_store_1_t =                              \
-      name##_store_1 + si +                                                  \
+      name##_store_1 + grad_si +                                             \
       ((storage_mode == STORAGE_DEVICE && !storage_compression)              \
-           ? (t / step_ratio) * n_shots * n_grid_points                      \
+           ? (t / step_ratio) * n_shots * n_grad_grid_points                 \
            : 0);                                                             \
   void *__restrict const name##_store_2_t =                                  \
       (uint8_t *)name##_store_2 +                                            \
@@ -469,7 +480,11 @@ __declspec(dllexport)
           }
 
           if (b_requires_grad_t) {
-            bz_grad_store_1_t[i] = term_z;
+            int64_t const grad_idx =
+                get_grad_idx(gradient_mask_indices, i);
+            if (grad_idx >= 0) {
+              bz_grad_store_1_t[grad_idx] = term_z;
+            }
           }
           vz_shot[i] -= dt * buoyancy_z_shot[i] * term_z;
         }
@@ -492,7 +507,11 @@ __declspec(dllexport)
           }
 
           if (b_requires_grad_t) {
-            by_grad_store_1_t[i] = term_y;
+            int64_t const grad_idx =
+                get_grad_idx(gradient_mask_indices, i);
+            if (grad_idx >= 0) {
+              by_grad_store_1_t[grad_idx] = term_y;
+            }
           }
           vy_shot[i] -= dt * buoyancy_y_shot[i] * term_y;
         }
@@ -516,7 +535,11 @@ __declspec(dllexport)
           }
 
           if (b_requires_grad_t) {
-            bx_grad_store_1_t[i] = term_x;
+            int64_t const grad_idx =
+                get_grad_idx(gradient_mask_indices, i);
+            if (grad_idx >= 0) {
+              bx_grad_store_1_t[grad_idx] = term_x;
+            }
           }
           vx_shot[i] -= dt * buoyancy_x_shot[i] * term_x;
         }
@@ -599,7 +622,11 @@ __declspec(dllexport)
           div_v += d_x;
 
           if (k_requires_grad_t) {
-            k_grad_store_1_t[i] = div_v;
+            int64_t const grad_idx =
+                get_grad_idx(gradient_mask_indices, i);
+            if (grad_idx >= 0) {
+              k_grad_store_1_t[grad_idx] = div_v;
+            }
           }
           p_shot[i] -= dt * k_shot[i] * div_v;
         }
@@ -619,7 +646,7 @@ __declspec(dllexport)
   storage_save_snapshot_cpu(                                                 \
       name##_store_1_t, name##_store_2_t, fp_##name, storage_mode,           \
       storage_compression, step_idx, shot_bytes_uncomp, shot_bytes_comp,     \
-      n_grid_points, sizeof(DW_DTYPE) == sizeof(double));                    \
+      n_grad_grid_points, sizeof(DW_DTYPE) == sizeof(double));               \
   }
 
       SAVE_SNAPSHOT(k_grad, k_requires_grad_t)
@@ -721,6 +748,8 @@ __declspec(dllexport)
         void *__restrict const bx_grad_store_2,
         void *__restrict const bx_grad_store_3,
         char const *__restrict const *__restrict const bx_grad_filenames_ptr,
+        int64_t const *__restrict const gradient_mask_indices,
+        int64_t const n_grad_grid_points,
        DW_DTYPE *__restrict const grad_f_p,
 #if DW_NDIM >= 3
         DW_DTYPE *__restrict const grad_f_vz,
@@ -859,6 +888,7 @@ __declspec(dllexport)
 #endif /* _OPENMP */
 
       int64_t const si = shot * n_grid_points;
+      int64_t const grad_si = shot * n_grad_grid_points;
       int64_t const src_i_p = shot * n_sources_p_per_shot;
 #if DW_NDIM >= 3
       int64_t const src_i_vz = shot * n_sources_vz_per_shot;
@@ -941,9 +971,9 @@ __declspec(dllexport)
     for (t = start_t - 1; t >= start_t - nt; --t) {
 #define SETUP_STORE_LOAD(name, grad_cond)                                    \
   DW_DTYPE *__restrict const name##_store_1_t =                              \
-      name##_store_1 + si +                                                  \
+      name##_store_1 + grad_si +                                             \
       ((storage_mode == STORAGE_DEVICE && !storage_compression)              \
-           ? (t / step_ratio) * n_shots * n_grid_points                      \
+           ? (t / step_ratio) * n_shots * n_grad_grid_points                 \
            : 0);                                                             \
   void *__restrict const name##_store_2_t =                                  \
       (uint8_t *)name##_store_2 +                                            \
@@ -956,7 +986,7 @@ __declspec(dllexport)
   storage_load_snapshot_cpu(                                                 \
       (void *)name##_store_1_t, name##_store_2_t, fp_##name, storage_mode,   \
       storage_compression, step_idx, shot_bytes_uncomp, shot_bytes_comp,     \
-      n_grid_points, sizeof(DW_DTYPE) == sizeof(double));                    \
+      n_grad_grid_points, sizeof(DW_DTYPE) == sizeof(double));               \
   }
 
       SETUP_STORE_LOAD(k_grad, k_requires_grad)
@@ -1039,8 +1069,13 @@ __declspec(dllexport)
                 dt * buoyancy_z_shot[i] * azh[z] * vz_shot[i];
 
           if (b_requires_grad_t) {
-            grad_bz_shot[i] -= dt * vz_shot[i] * bz_grad_store_1_t[i] *
-                               (DW_DTYPE)step_ratio;
+            int64_t const grad_idx =
+                get_grad_idx(gradient_mask_indices, i);
+            if (grad_idx >= 0) {
+              grad_bz_shot[i] -=
+                  dt * vz_shot[i] *
+                  bz_grad_store_1_t[grad_idx] * (DW_DTYPE)step_ratio;
+            }
           }
         }
 #endif
@@ -1061,8 +1096,13 @@ __declspec(dllexport)
                 dt * buoyancy_y_shot[i] * ayh[y] * vy_shot[i];
 
           if (b_requires_grad_t) {
-            grad_by_shot[i] -= dt * vy_shot[i] * by_grad_store_1_t[i] *
-                               (DW_DTYPE)step_ratio;
+            int64_t const grad_idx =
+                get_grad_idx(gradient_mask_indices, i);
+            if (grad_idx >= 0) {
+              grad_by_shot[i] -=
+                  dt * vy_shot[i] *
+                  by_grad_store_1_t[grad_idx] * (DW_DTYPE)step_ratio;
+            }
           }
         }
 #endif
@@ -1084,8 +1124,13 @@ __declspec(dllexport)
                 dt * buoyancy_x_shot[i] * axh[x] * vx_shot[i];
 
           if (b_requires_grad_t) {
-            grad_bx_shot[i] -= dt * vx_shot[i] * bx_grad_store_1_t[i] *
-                               (DW_DTYPE)step_ratio;
+            int64_t const grad_idx =
+                get_grad_idx(gradient_mask_indices, i);
+            if (grad_idx >= 0) {
+              grad_bx_shot[i] -=
+                  dt * vx_shot[i] *
+                  bx_grad_store_1_t[grad_idx] * (DW_DTYPE)step_ratio;
+            }
           }
         }
       }
@@ -1162,8 +1207,13 @@ __declspec(dllexport)
                 dt * k_shot[i] * ax[x] * p_shot[i];
 
           if (k_requires_grad_t) {
-            grad_k_shot[i] -= dt * p_shot[i] * k_grad_store_1_t[i] *
-                              (DW_DTYPE)step_ratio;
+            int64_t const grad_idx =
+                get_grad_idx(gradient_mask_indices, i);
+            if (grad_idx >= 0) {
+              grad_k_shot[i] -=
+                  dt * p_shot[i] *
+                  k_grad_store_1_t[grad_idx] * (DW_DTYPE)step_ratio;
+            }
           }
 
           // Update P
diff --git a/src/deepwave/acoustic.cu b/src/deepwave/acoustic.cu
index ed5dd8a..b18ff82 100644
--- a/src/deepwave/acoustic.cu
+++ b/src/deepwave/acoustic.cu
@@ -112,6 +112,7 @@ __constant__ int64_t pml_y1;
 __constant__ DW_DTYPE rdx;
 __constant__ int64_t nx;
 __constant__ int64_t shot_numel;
+__constant__ int64_t grad_shot_numel;
 __constant__ int64_t n_shots;
 __constant__ int64_t step_ratio;
 __constant__ int64_t pml_x0;
@@ -170,6 +171,12 @@ __launch_bounds__(128) __global__
   }
 }
 
+__device__ __forceinline__ int64_t get_grad_idx(
+    int64_t const *__restrict__ const gradient_mask_indices, int64_t j) {
+  if (!gradient_mask_indices) return j;
+  return gradient_mask_indices[j];
+}
+
 __launch_bounds__(128) __global__
     void forward_kernel_v(DW_DTYPE const *__restrict__ const p,
 #if DW_NDIM >= 3
@@ -193,6 +200,8 @@ __launch_bounds__(128) __global__
                           DW_DTYPE const *__restrict__ const buoyancy_y,
 #endif
                           DW_DTYPE const *__restrict__ const buoyancy_x,
+                          int64_t const *__restrict__ const
+                              gradient_mask_indices,
 #if DW_NDIM >= 3
                           DW_DTYPE *__restrict__ const bz_grad_store_1,
 #endif
@@ -232,6 +241,7 @@ __launch_bounds__(128) __global__
   int64_t const j = x;
 #endif
   int64_t const i = shot_idx * shot_numel + j;
+  int64_t const grad_idx = get_grad_idx(gradient_mask_indices, j);
 
 #if DW_NDIM >= 3
   int64_t const pml_z0h = pml_z0;
@@ -257,8 +267,8 @@ __launch_bounds__(128) __global__
       psi_z[i] = bzh[z] * term_z + azh[z] * psi_z[i];
       term_z += psi_z[i];
     }
-    if (b_requires_grad) {
-      bz_grad_store_1[i] = term_z;
+    if (b_requires_grad && grad_idx >= 0) {
+      bz_grad_store_1[shot_idx * grad_shot_numel + grad_idx] = term_z;
     }
     vz[i] -= dt * buoyancy_z_shot[j] * term_z;
   }
@@ -274,8 +284,8 @@ __launch_bounds__(128) __global__
       psi_y[i] = byh[y] * term_y + ayh[y] * psi_y[i];
       term_y += psi_y[i];
     }
-    if (b_requires_grad) {
-      by_grad_store_1[i] = term_y;
+    if (b_requires_grad && grad_idx >= 0) {
+      by_grad_store_1[shot_idx * grad_shot_numel + grad_idx] = term_y;
     }
     vy[i] -= dt * buoyancy_y_shot[j] * term_y;
   }
@@ -288,8 +298,8 @@ __launch_bounds__(128) __global__
       psi_x[i] = bxh[x] * term_x + axh[x] * psi_x[i];
       term_x += psi_x[i];
     }
-    if (b_requires_grad) {
-      bx_grad_store_1[i] = term_x;
+    if (b_requires_grad && grad_idx >= 0) {
+      bx_grad_store_1[shot_idx * grad_shot_numel + grad_idx] = term_x;
     }
     vx[i] -= dt * buoyancy_x_shot[j] * term_x;
   }
@@ -317,6 +327,7 @@ __launch_bounds__(128) __global__
 #endif
                           DW_DTYPE *__restrict__ const phi_x,
                          DW_DTYPE const *__restrict__ const k,
+                          int64_t const *__restrict__ const gradient_mask_indices,
                           DW_DTYPE *__restrict__ const k_grad_store_1,
 #if DW_NDIM >= 3
                           DW_DTYPE const *__restrict__ const az,
@@ -350,6 +361,7 @@ __launch_bounds__(128) __global__
   int64_t const j = x;
 #endif
   int64_t const i = shot_idx * shot_numel + j;
+  int64_t const grad_idx = get_grad_idx(gradient_mask_indices, j);
 
   DW_DTYPE const *__restrict__ const k_shot =
       k_batched ? k + shot_idx * shot_numel : k;
@@ -381,8 +393,8 @@ __launch_bounds__(128) __global__
     }
     div_v += d_x;
 
-    if (k_requires_grad) {
-      k_grad_store_1[i] = div_v;
+    if (k_requires_grad && grad_idx >= 0) {
+      k_grad_store_1[shot_idx * grad_shot_numel + grad_idx] = div_v;
     }
 
     p[i] -= dt * k_shot[j] * div_v;
@@ -430,6 +442,8 @@ __launch_bounds__(128) __global__
                            DW_DTYPE const *__restrict__ const buoyancy_y,
 #endif
                            DW_DTYPE const *__restrict__ const buoyancy_x,
+                           int64_t const *__restrict__ const
+                               gradient_mask_indices,
 #if DW_NDIM >= 3
                            DW_DTYPE *__restrict__ const grad_bz_shot,
                            DW_DTYPE const *__restrict__ const bz_grad_store_1,
@@ -471,6 +485,7 @@ __launch_bounds__(128) __global__
   int64_t const j = x;
 #endif
   int64_t const i = shot_idx * shot_numel + j;
+  int64_t const grad_idx = get_grad_idx(gradient_mask_indices, j);
 
   DW_DTYPE const *__restrict__ const k_shot =
       k_batched ? k + shot_idx * shot_numel : k;
@@ -495,9 +510,11 @@ __launch_bounds__(128) __global__
         psi_zn[i] =
             azh[z] * psi_z[i] - dt * buoyancy_z_shot[j] * azh[z] * vz[i];
 
-      if (b_requires_grad) {
+      if (b_requires_grad && grad_idx >= 0) {
         grad_bz_shot[i] -=
-            dt * vz[i] * bz_grad_store_1[i] * (DW_DTYPE)step_ratio;
+            dt * vz[i] *
+            bz_grad_store_1[shot_idx * grad_shot_numel + grad_idx] *
+            (DW_DTYPE)step_ratio;
       }
     }
 #endif
@@ -511,9 +528,11 @@ __launch_bounds__(128) __global__
         psi_yn[i] =
             ayh[y] * psi_y[i] - dt * buoyancy_y_shot[j] * ayh[y] * vy[i];
 
-      if (b_requires_grad) {
+      if (b_requires_grad && grad_idx >= 0) {
         grad_by_shot[i] -=
-            dt * vy[i] * by_grad_store_1[i] * (DW_DTYPE)step_ratio;
+            dt * vy[i] *
+            by_grad_store_1[shot_idx * grad_shot_numel + grad_idx] *
+            (DW_DTYPE)step_ratio;
       }
     }
 #endif
@@ -526,9 +545,11 @@ __launch_bounds__(128) __global__
       psi_xn[i] =
           axh[x] * psi_x[i] - dt * buoyancy_x_shot[j] * axh[x] * vx[i];
 
-      if (b_requires_grad) {
+      if (b_requires_grad && grad_idx >= 0) {
         grad_bx_shot[i] -=
-            dt * vx[i] * bx_grad_store_1[i] * (DW_DTYPE)step_ratio;
+            dt * vx[i] *
+            bx_grad_store_1[shot_idx * grad_shot_numel + grad_idx] *
+            (DW_DTYPE)step_ratio;
       }
     }
 
@@ -569,6 +590,8 @@ __launch_bounds__(128) __global__
                            DW_DTYPE const *__restrict__ const buoyancy_y,
 #endif
                            DW_DTYPE const *__restrict__ const buoyancy_x,
+                           int64_t const *__restrict__ const
+                               gradient_mask_indices,
                            DW_DTYPE *__restrict__ const grad_k_shot,
                            DW_DTYPE const *__restrict__ const k_grad_store_1,
 #if DW_NDIM >= 3
@@ -606,6 +629,7 @@ __launch_bounds__(128) __global__
   int64_t const j = x;
 #endif
   int64_t const i = shot_idx * shot_numel + j;
+  int64_t const grad_idx = get_grad_idx(gradient_mask_indices, j);
 
   DW_DTYPE const *__restrict__ const k_shot =
       k_batched ? k + shot_idx * shot_numel : k;
@@ -622,8 +646,11 @@ __launch_bounds__(128) __global__
   bool const pml_x = x < pml_x0 || x >= pml_x1;
   if (pml_x) phi_x[i] = ax[x] * phi_x[i] - dt * k_shot[j] * ax[x] * p[i];
 
-  if (k_requires_grad) {
-    grad_k_shot[i] -= dt * p[i] * k_grad_store_1[i] * (DW_DTYPE)step_ratio;
+  if (k_requires_grad && grad_idx >= 0) {
+    grad_k_shot[i] -=
+        dt * p[i] *
+        k_grad_store_1[shot_idx * grad_shot_numel + grad_idx] *
+        (DW_DTYPE)step_ratio;
   }
 
   // Update P
@@ -663,7 +690,8 @@ int set_config(
 #endif
     /* x-dimension */
     DW_DTYPE const rdx_h, int64_t const nx_h, int64_t const shot_numel_h,
-    int64_t const pml_x0_h, int64_t const pml_x1_h,
+    int64_t const grad_shot_numel_h, int64_t const pml_x0_h,
+    int64_t const pml_x1_h,
     /* other */
     DW_DTYPE const dt_h, int64_t const n_shots_h, int64_t const step_ratio_h,
     bool const k_batched_h, bool const b_batched_h) {
@@ -683,6 +711,8 @@ int set_config(
   gpuErrchk(cudaMemcpyToSymbol(rdx, &rdx_h, sizeof(DW_DTYPE)));
   gpuErrchk(cudaMemcpyToSymbol(nx, &nx_h, sizeof(int64_t)));
   gpuErrchk(cudaMemcpyToSymbol(shot_numel, &shot_numel_h, sizeof(int64_t)));
+  gpuErrchk(
+      cudaMemcpyToSymbol(grad_shot_numel, &grad_shot_numel_h, sizeof(int64_t)));
   gpuErrchk(cudaMemcpyToSymbol(n_shots, &n_shots_h, sizeof(int64_t)));
   gpuErrchk(cudaMemcpyToSymbol(step_ratio, &step_ratio_h, sizeof(int64_t)));
   gpuErrchk(cudaMemcpyToSymbol(pml_x0, &pml_x0_h, sizeof(int64_t)));
@@ -765,6 +795,8 @@ extern "C"
         void *__restrict__ const bx_grad_store_3,
         char const *__restrict__ const *__restrict__ const
             bx_grad_filenames_ptr,
+        int64_t const *__restrict__ const gradient_mask_indices,
+        int64_t const grad_shot_numel_h,
         DW_DTYPE *__restrict__ const r_p,
 #if DW_NDIM >= 3
         DW_DTYPE *__restrict__ const r_vz,
@@ -893,8 +925,8 @@ extern "C"
 #if DW_NDIM >= 2
         rdy_h, ny_h, pml_y0_h, pml_y1_h,
 #endif
-        rdx_h, nx_h, shot_numel_h, pml_x0_h, pml_x1_h, dt_h, n_shots_h,
-        step_ratio_h, k_batched_h, b_batched_h);
+        rdx_h, nx_h, shot_numel_h, grad_shot_numel_h, pml_x0_h, pml_x1_h, dt_h,
+        n_shots_h, step_ratio_h, k_batched_h, b_batched_h);
     if (err != 0) return err;
   }
 
@@ -940,7 +972,7 @@ extern "C"
   DW_DTYPE *__restrict name##_store_1_t =                                    \
       name##_store_1a +                                                      \
       ((storage_mode == STORAGE_DEVICE && !storage_compression)              \
-           ? (t / step_ratio_h) * n_shots_h * shot_numel_h                   \
+           ? (t / step_ratio_h) * n_shots_h * grad_shot_numel_h              \
            : 0);                                                             \
   void *__restrict const name##_store_2_t =                                  \
       (uint8_t *)name##_store_2 +                                            \
@@ -1055,6 +1087,7 @@ extern "C"
             buoyancy_y,
 #endif
             buoyancy_x,
+            gradient_mask_indices,
 #if DW_NDIM >= 3
             bz_grad_store_1_t,
 #endif
@@ -1124,7 +1157,7 @@ extern "C"
 #if DW_NDIM >= 2
             phi_y,
 #endif
-            phi_x, k, k_grad_store_1_t,
+            phi_x, k, gradient_mask_indices, k_grad_store_1_t,
 #if DW_NDIM >= 3
             az, bz,
 #endif
@@ -1144,7 +1177,7 @@ extern "C"
   if (storage_save_snapshot_gpu(                                             \
           name##_store_1_t, name##_store_2_t, name##_store_3_t, fp_##name,   \
           storage_mode, storage_compression, step_idx, shot_bytes_uncomp,    \
-          shot_bytes_comp, n_shots_h, shot_numel_h,                          \
+          shot_bytes_comp, n_shots_h, grad_shot_numel_h,                     \
          sizeof(DW_DTYPE) == sizeof(double),                                \
           use_double_buffering ? stream_storage : stream_compute) != 0)      \
     return 1;                                                                \
@@ -1260,6 +1293,8 @@ extern "C"
         void *__restrict__ const bx_grad_store_3,
         char const *__restrict__ const *__restrict__ const
             bx_grad_filenames_ptr,
+        int64_t const *__restrict__ const gradient_mask_indices,
+        int64_t const grad_shot_numel_h,
         DW_DTYPE *__restrict__ const grad_f_p,
 #if DW_NDIM >= 3
         DW_DTYPE *__restrict__ const grad_f_vz,
@@ -1404,8 +1439,8 @@ extern "C"
 #if DW_NDIM >= 2
         rdy_h, ny_h, pml_y0_h, pml_y1_h,
 #endif
-        rdx_h, nx_h, shot_numel_h, pml_x0_h, pml_x1_h, dt_h, n_shots_h,
-        step_ratio_h, k_batched_h, b_batched_h);
+        rdx_h, nx_h, shot_numel_h, grad_shot_numel_h, pml_x0_h, pml_x1_h, dt_h,
+        n_shots_h, step_ratio_h, k_batched_h, b_batched_h);
     if (err != 0) return err;
   }
 
@@ -1449,7 +1484,7 @@ extern "C"
   DW_DTYPE *__restrict__ name##_store_1_t =                                  \
       name##_store_1a +                                                      \
       ((storage_mode == STORAGE_DEVICE && !storage_compression)              \
-           ? (t / step_ratio_h) * n_shots_h * shot_numel_h                   \
+           ? (t / step_ratio_h) * n_shots_h * grad_shot_numel_h              \
            : 0);                                                             \
   void *__restrict__ const name##_store_2_t =                                \
       (uint8_t *)name##_store_2 +                                            \
@@ -1470,7 +1505,7 @@ extern "C"
   if (storage_load_snapshot_gpu(                                             \
           (void *)name##_store_1_t, name##_store_2_t, name##_store_3_t,      \
           fp_##name, storage_mode, storage_compression, step_idx,            \
-          shot_bytes_uncomp, shot_bytes_comp, n_shots_h, shot_numel_h,       \
+          shot_bytes_uncomp, shot_bytes_comp, n_shots_h, grad_shot_numel_h,  \
           sizeof(DW_DTYPE) == sizeof(double),                                \
           use_double_buffering ? stream_storage : stream_compute) != 0)      \
     return 1;                                                                \
@@ -1549,6 +1584,7 @@ extern "C"
             buoyancy_y,
 #endif
             buoyancy_x,
+            gradient_mask_indices,
 #if DW_NDIM >= 3
             grad_bz_shot, bz_grad_store_1_t, azh, bz, bzh,
 #endif
@@ -1626,7 +1662,7 @@ extern "C"
 #if DW_NDIM >= 2
             buoyancy_y,
 #endif
-            buoyancy_x, grad_k_shot, k_grad_store_1_t,
+            buoyancy_x, gradient_mask_indices, grad_k_shot, k_grad_store_1_t,
 #if DW_NDIM >= 3
             az, bzh, bz,
 #endif
diff --git a/src/deepwave/acoustic.py b/src/deepwave/acoustic.py
index c7118d8..90073c4 100644
--- a/src/deepwave/acoustic.py
+++ b/src/deepwave/acoustic.py
@@ -180,6 +180,7 @@ def forward(
     origin: Optional[Sequence[int]] = None,
     nt: Optional[int] = None,
     model_gradient_sampling_interval: int = 1,
+    gradient_mask: Optional[torch.Tensor] = None,
     freq_taper_frac: float = 0.0,
     time_pad_frac: float = 0.0,
     time_taper: bool = False,
@@ -227,6 +228,7 @@ def forward(
         origin=origin,
         nt=nt,
         model_gradient_sampling_interval=model_gradient_sampling_interval,
+        gradient_mask=gradient_mask,
         freq_taper_frac=freq_taper_frac,
         time_pad_frac=time_pad_frac,
         time_taper=time_taper,
@@ -275,6 +277,7 @@ def acoustic(
     origin: Optional[Sequence[int]] = None,
     nt: Optional[int] = None,
     model_gradient_sampling_interval: int = 1,
+    gradient_mask: Optional[torch.Tensor] = None,
     freq_taper_frac: float = 0.0,
     time_pad_frac: float = 0.0,
     time_taper: bool = False,
@@ -326,6 +329,13 @@ def acoustic(
         origin: Origin of initial wavefields.
         nt: Number of time steps.
         model_gradient_sampling_interval: Interval for gradient sampling.
+        gradient_mask: A boolean Tensor with the same spatial shape as the
+            model, specifying the cells in which gradients should be computed.
+            Optional. If not provided, gradients will be computed everywhere
+            in the model. If the model is padded internally, the provided
+            (unpadded) mask will be padded with False values. True indicates
+            cells whose gradients should be computed; False indicates cells
+            whose gradients will be set to 0.
         freq_taper_frac: Frequency taper fraction.
         time_pad_frac: Time padding fraction.
         time_taper: Time taper flag.
@@ -378,6 +388,31 @@ def acoustic(
             "y dimension should not be provided."
         )
 
+    if gradient_mask is not None:
+        if not isinstance(gradient_mask, torch.Tensor):
+            raise TypeError("gradient_mask must be a torch.Tensor if provided.")
+        if gradient_mask.dtype != torch.bool:
+            raise TypeError("gradient_mask must be a boolean Tensor.")
+        if gradient_mask.shape != v.shape[-ndim:]:
+            raise RuntimeError(
+                "gradient_mask must match the spatial shape of v and have no batch "
+                "dimension."
+            )
+        if python_backend:
+            raise RuntimeError("gradient_mask is not supported in the Python backend.")
+
+        if v.requires_grad or rho.requires_grad:
+            # The backend applies the gradient mask to the K/B gradient
+            # storage and accumulation, but not to v and rho directly. Since
+            # buoyancy B depends on neighbor averages of rho, autograd would
+            # otherwise leak gradients to rho entries just outside the mask,
+            # violating the API contract, so we detach the masked-out cells.
+            gradient_mask = gradient_mask.to(
+                device=v.device, dtype=torch.bool, copy=False
+            )
+            v = torch.where(gradient_mask, v, v.detach())
+            rho = torch.where(gradient_mask, rho, rho.detach())
+
     # Prepare initial wavefields list
     initial_wavefields: List[Optional[torch.Tensor]] = []
     initial_wavefields.append(pressure_0)
@@ -461,6 +496,16 @@ def acoustic(
     source_locations_list.insert(0, source_locations_p)
     receiver_locations_list.insert(0, receiver_locations_p)
 
+    models_list = [v, rho]
+    model_pad_modes = ["replicate", "replicate"]
+    gradient_mask_for_setup: Optional[torch.Tensor] = None
+    if gradient_mask is not None:
+        gradient_mask_for_setup = gradient_mask.to(
+            device=v.device, dtype=v.dtype, copy=False
+        )
+        models_list.append(gradient_mask_for_setup)
+        model_pad_modes.append("constant")
+
     (
         models,
         source_amplitudes_out,
@@ -483,8 +528,8 @@ def acoustic(
         device,
         dtype,
     ) = deepwave.common.setup_propagator(
-        [v, rho],
-        ["replicate", "replicate"],
+        models_list,
+        model_pad_modes,
         grid_spacing,
         dt,
         source_amplitudes_list,
@@ -524,6 +569,11 @@ def acoustic(
         receiver_locations_x,
     )
 
+    gradient_mask_prepared: Optional[torch.Tensor] = None
+    if gradient_mask is not None:
+        gradient_mask_prepared = models.pop()
+        del gradient_mask_for_setup
+
     models = prepare_models(models[0], models[1])
 
     # Scale source amplitudes
@@ -533,6 +583,15 @@ def acoustic(
     model_shape = models[0].shape[-ndim:]
     flat_model_shape = int(torch.prod(torch.tensor(model_shape)).item())
 
+    # Prepare compact gradient mask indices and number of elements
+    gradient_mask_indices: Optional[torch.Tensor] = None
+    gradient_mask_numel = flat_model_shape
+    if gradient_mask_prepared is not None:
+        gradient_mask_indices, gradient_mask_numel = _prepare_gradient_indices(
+            gradient_mask_prepared,
+        )
+        del gradient_mask_prepared
+
     for i, (src_amp, src_loc, model) in enumerate(
         zip(source_amplitudes_out, sources_i, models)
     ):
@@ -586,6 +645,8 @@ def acoustic(
         storage_mode,
         storage_path,
         storage_compression,
+        gradient_mask_indices,
+        gradient_mask_numel,
         *models,
         *source_amplitudes_out,
         *sources_i,
@@ -606,6 +667,32 @@ def acoustic(
     return outputs
 
 
+def _prepare_gradient_indices(
+    gradient_mask_prepared: torch.Tensor,
+) -> Tuple[torch.Tensor, int]:
+    """
+    Builds compact indices from a prepared (padded/sliced) gradient mask.
+
+    Produces a spatial-only index tensor aligned with the internally padded
+    model: -1 for masked-out cells and 0..N-1 for masked-in cells (N is the
+    number of True entries), along with that count. A single shared mask
+    (no batch dimension) is enforced across all shots.
+ """ + gradient_mask_prepared = gradient_mask_prepared.to( + dtype=torch.bool, copy=False + ) + if gradient_mask_prepared.shape[0] != 1: + raise RuntimeError("gradient_mask must not have a batch dimension.") + gradient_mask_prepared = gradient_mask_prepared[0].contiguous() + mask_flat = gradient_mask_prepared.flatten() + indices_flat = torch.full_like(mask_flat, -1, dtype=torch.int64) + cumsum = mask_flat.cumsum(dim=0).to(torch.int64) - 1 + indices_flat = torch.where(mask_flat, cumsum, indices_flat) + gradient_mask_indices = indices_flat.reshape_as(gradient_mask_prepared).contiguous() + gradient_mask_numel = int(mask_flat.sum().item()) + return gradient_mask_indices, gradient_mask_numel + + def zero_edge(tensor: torch.Tensor, fd_pad: int, dim: int) -> torch.Tensor: """Sets values at the end of a dimension of a tensor to zero.""" tensor[(slice(None),) + (slice(None),) * dim + (-fd_pad,)].fill_(0) @@ -714,6 +801,8 @@ def forward( storage_mode_str: str, storage_path: str, storage_compression: bool, + gradient_mask_indices: Optional[torch.Tensor], + gradient_mask_numel: int, *args: torch.Tensor, ) -> Tuple[torch.Tensor, ...]: """Performs the forward propagation of the acoustic wave equation.""" @@ -763,6 +852,19 @@ def forward( device = models[0].device dtype = models[0].dtype + is_cuda = models[0].is_cuda + model_shape = models[0].shape[-ndim:] + + gradient_mask_indices_tensor: Optional[torch.Tensor] = None + gradient_mask_indices_ptr: int = 0 + if gradient_mask_indices is not None: + if gradient_mask_indices.shape != model_shape: + raise RuntimeError( + "gradient_mask must match the padded model spatial shape and have " + "no batch dimension." + ) + gradient_mask_indices_tensor = gradient_mask_indices.contiguous() + gradient_mask_indices_ptr = gradient_mask_indices_tensor.data_ptr() # Setup storage if str(device) == "cpu" and storage_mode_str == "cpu": @@ -779,13 +881,14 @@ def forward( else: raise ValueError(f"Invalid storage_mode {storage_mode_str}") - is_cuda = models[0].is_cuda - model_shape = models[0].shape[-ndim:] + storage_shape: Tuple[int, ...] = tuple(model_shape) + if gradient_mask_indices_tensor is not None: + storage_shape = (gradient_mask_numel,) n_sources_per_shot_list = [locs.numel() // n_shots for locs in sources_i] n_receivers_per_shot_list = [locs.numel() // n_shots for locs in receivers_i] storage_manager = deepwave.common.StorageManager( - model_shape, + storage_shape, dtype, n_shots, nt, @@ -821,6 +924,7 @@ def forward( ctx.backward_callback = backward_callback ctx.callback_frequency = callback_frequency ctx.storage_manager = storage_manager + ctx.gradient_mask_indices = gradient_mask_indices_tensor fd_pad = accuracy // 2 fd_pad_list = [fd_pad, fd_pad - 1] * ndim @@ -931,6 +1035,8 @@ def forward( *[amp.data_ptr() for amp in source_amplitudes], *[w.data_ptr() for w in wavefields], *storage_manager.storage_ptrs, + gradient_mask_indices_ptr, + storage_manager.num_elements_per_shot, *[amp.data_ptr() for amp in receiver_amplitudes], *[p.data_ptr() for p in pml_profiles], *[loc.data_ptr() for loc in sources_i], @@ -975,6 +1081,12 @@ def backward(ctx: Any, *args: torch.Tensor) -> Tuple[Optional[torch.Tensor], ... 
         # Unpack
         grid_spacing = ctx.grid_spacing
         ndim = len(grid_spacing)
+        gradient_mask_indices = ctx.gradient_mask_indices
+        gradient_mask_indices_ptr = (
+            0
+            if gradient_mask_indices is None
+            else gradient_mask_indices.data_ptr()
+        )
         grad_wavefields = list(args[: -ndim - 1])
         grad_r = list(args[-ndim - 1 :])
 
@@ -1151,6 +1263,8 @@ def backward(ctx: Any, *args: torch.Tensor) -> Tuple[Optional[torch.Tensor], ...
                 *[field.data_ptr() for field in grad_wavefields],
                 *[field.data_ptr() for field in aux_wavefields],
                 *storage_manager.storage_ptrs,
+                gradient_mask_indices_ptr,
+                storage_manager.num_elements_per_shot,
                 *[g.data_ptr() for g in grad_f_list],
                 *[g.data_ptr() for g in grad_models],
                 *grad_models_tmp_ptr,
@@ -1212,7 +1326,7 @@ def backward(ctx: Any, *args: torch.Tensor) -> Tuple[Optional[torch.Tensor], ...
             *(slice(fd_pad, shape - (fd_pad - 1)) for shape in model_shape),
         )
         return tuple(
-            [None] * 14
+            [None] * 16
             + grad_models
             + grad_f_list
             + [None] * num_source_types  # sources_i
@@ -1316,11 +1430,15 @@ def acoustic_python(
     storage_mode_str: str,
     storage_path: str,
     storage_compression: bool,
+    gradient_mask_indices: Optional[torch.Tensor],
+    gradient_mask_numel: int,
     *args: torch.Tensor,
 ) -> Tuple[torch.Tensor, ...]:
     """Python backend for acoustic wave propagation."""
     if backward_callback is not None:
         raise RuntimeError("backward_callback is not supported in the Python backend.")
+    if gradient_mask_indices is not None:
+        raise RuntimeError("gradient_mask is not supported in the Python backend.")
     if storage_mode_str != "device":
         raise RuntimeError(
             "Specifying storage mode is not supported in Python backend."
diff --git a/src/deepwave/backend_utils.py b/src/deepwave/backend_utils.py
index 6eb34b2..6b286af 100644
--- a/src/deepwave/backend_utils.py
+++ b/src/deepwave/backend_utils.py
@@ -451,6 +451,8 @@ def get_acoustic_forward_template(ndim: int) -> List[Any]:
     args += [c_void_p] * (1 + 3 * ndim)  # p, v, phi, psi
     args += [c_void_p] * 5  # k_store_1a, k_store_1b, k_store_2, k_store_3, k_filenames
     args += [c_void_p] * (5 * ndim)  # b_store...
+    args += [c_void_p]  # gradient_mask_indices
+    args += [c_int64]  # grad_shot_numel
     args += [c_void_p] * (1 + ndim)  # receiver_amplitudes
     args += [c_void_p] * (4 * ndim)  # a, b, ah, bh
     args += [c_void_p] * (1 + ndim)  # sources_i
@@ -482,6 +484,8 @@
     args += [c_void_p] * (ndim)  # psin
     args += [c_void_p] * 5  # k_store
     args += [c_void_p] * (5 * ndim)  # b_store
+    args += [c_void_p]  # gradient_mask_indices
+    args += [c_int64]  # grad_shot_numel
     args += [c_void_p] * (1 + ndim)  # grad_f
     args += [c_void_p] * (2 + 2 * ndim)  # grad_k, grad_b, grad_k_thread, grad_b_thread
     args += [c_void_p] * (4 * ndim)  # a, b, ah, bh
diff --git a/tests/test_acoustic.py b/tests/test_acoustic.py
index 50ebd04..c91bab4 100644
--- a/tests/test_acoustic.py
+++ b/tests/test_acoustic.py
@@ -1,5 +1,6 @@
 """Tests for deepwave.acoustic."""
 
+from pathlib import Path
 from typing import Any, Dict, List, Optional, Tuple, Union
 
 import pytest
@@ -54,6 +55,7 @@ def acousticprop(
     psi_x_0: Optional[torch.Tensor] = None,
     nt: Optional[int] = None,
     model_gradient_sampling_interval: int = 1,
+    gradient_mask: Optional[torch.Tensor] = None,
     functional: bool = True,
 ) -> Tuple[torch.Tensor, ...]:
     """Wraps the acoustic propagator."""
@@ -120,6 +122,7 @@ def acousticprop(
             psi_x_0=psi_x_0,
             nt=nt,
             model_gradient_sampling_interval=model_gradient_sampling_interval,
+            gradient_mask=gradient_mask,
             **prop_kwargs,
         )
 
@@ -150,6 +153,7 @@ def acousticprop(
             psi_x_0=psi_x_0,
             nt=nt,
             model_gradient_sampling_interval=model_gradient_sampling_interval,
+            gradient_mask=gradient_mask,
             **prop_kwargs,
         )
 
@@ -161,6 +165,115 @@ def test_python_backends() -> None:
     run_forward(propagator=acousticprop, prop_kwargs={"python_backend": "compile"})
 
 
+def _compute_gradient(mask: Optional[torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
+    nt = 4
+    v = torch.ones(5, 5, requires_grad=True)
+    rho = torch.ones_like(v, requires_grad=True)
+    source_amplitudes = torch.zeros(1, 1, nt)
+    source_amplitudes[0, 0, 0] = 1
+    locations = torch.tensor([[[2, 2]]], dtype=torch.long)
+    outputs = acoustic(
+        v,
+        rho,
+        1.0,
+        0.001,
+        source_amplitudes_p=source_amplitudes,
+        source_locations_p=locations,
+        receiver_locations_p=locations,
+        nt=nt,
+        pml_width=2,
+        gradient_mask=mask,
+    )
+    receivers = outputs[-3:]
+    loss = sum(o.sum() for o in receivers)
+    loss.backward()
+    assert v.grad is not None
+    assert rho.grad is not None
+    return v.grad.detach(), rho.grad.detach()
+
+
+def _run_forward_for_storage(storage_dir: Path, mask: Optional[torch.Tensor]) -> int:
+    storage_dir.mkdir()
+    nt = 3
+    v = torch.ones(5, 5, requires_grad=True)
+    rho = torch.ones_like(v)
+    source_amplitudes = torch.zeros(1, 1, nt)
+    source_amplitudes[0, 0, 0] = 1
+    locations = torch.tensor([[[2, 2]]], dtype=torch.long)
+    outputs = acoustic(
+        v,
+        rho,
+        1.0,
+        0.001,
+        source_amplitudes_p=source_amplitudes,
+        source_locations_p=locations,
+        receiver_locations_p=locations,
+        nt=nt,
+        pml_width=2,
+        storage_mode="disk",
+        storage_path=str(storage_dir),
+        storage_compression=False,
+        gradient_mask=mask,
+    )
+    assert outputs
+    return sum(p.stat().st_size for p in storage_dir.rglob("*") if p.is_file())
+
+
+def test_gradient_mask_reduces_storage(tmp_path: Path) -> None:
+    mask = torch.zeros(5, 5, dtype=torch.bool)
+    mask[2, 2] = True
+    masked_size = _run_forward_for_storage(tmp_path / "masked", mask)
+    unmasked_size = _run_forward_for_storage(tmp_path / "unmasked", None)
+    assert masked_size < (0.1 * unmasked_size)
+
+
+def test_gradient_mask_zeroes_outside_mask() -> None:
+    mask = torch.zeros(5, 5, dtype=torch.bool)
+    mask[2, 2] = True
+    grad_v, grad_rho = _compute_gradient(mask)
+    assert torch.count_nonzero(grad_v[mask]) == 1
+    assert torch.count_nonzero(grad_v[~mask]) == 0
+    assert torch.count_nonzero(grad_rho[mask]) == 1
+    assert torch.count_nonzero(grad_rho[~mask]) == 0
+
+
+def test_gradient_mask_default_computes_everywhere() -> None:
+    full_mask = torch.ones(5, 5, dtype=torch.bool)
+    grad_no_mask_v, grad_no_mask_rho = _compute_gradient(None)
+    grad_full_mask_v, grad_full_mask_rho = _compute_gradient(full_mask)
+    assert torch.count_nonzero(grad_no_mask_v) > 0
+    assert torch.count_nonzero(grad_no_mask_rho) > 0
+    assert torch.allclose(grad_no_mask_v, grad_full_mask_v, rtol=1e-30)
+    assert torch.allclose(grad_no_mask_rho, grad_full_mask_rho, rtol=1e-30)
+
+
+def test_gradient_mask_python_backend_raises() -> None:
+    nt = 4
+    v = torch.ones(5, 5)
+    rho = torch.ones_like(v)
+    source_amplitudes = torch.zeros(1, 1, nt)
+    source_amplitudes[0, 0, 0] = 1
+    locations = torch.tensor([[[2, 2]]], dtype=torch.long)
+    mask = torch.ones(5, 5, dtype=torch.bool)
+    with pytest.raises(
+        RuntimeError,
+        match=r"gradient_mask is not supported in the Python backend\.",
+    ):
+        acoustic(
+            v,
+            rho,
+            1.0,
+            0.001,
+            source_amplitudes_p=source_amplitudes,
+            source_locations_p=locations,
+            receiver_locations_p=locations,
+            nt=nt,
+            pml_width=2,
+            python_backend="eager",
+            gradient_mask=mask,
+        )
+
+
 @pytest.mark.parametrize(
     ("nx", "dx"),
     [
@@ -1052,6 +1165,7 @@ def wrap(python):
             psi_x_0,
             nt,
             1,
+            None,
             True,
         )
         out = propagator(*inputs)
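
Note on the indexing scheme (an editor's illustration, not part of the patch):
_prepare_gradient_indices converts the boolean gradient_mask into an int64
tensor that the C/CUDA backends consume directly. Masked-in cells receive
consecutive slots 0..N-1 in the compacted per-shot gradient snapshot storage
(N = gradient_mask_numel), and masked-out cells receive -1, which get_grad_idx
treats as a skip sentinel; a null pointer means "no mask", in which case the
grid index is passed through unchanged. A minimal self-contained sketch of the
same mapping, mirroring the cumsum trick in _prepare_gradient_indices:

    from typing import Tuple

    import torch

    def compact_indices(mask: torch.Tensor) -> Tuple[torch.Tensor, int]:
        """Maps a boolean mask to -1 (masked out) or 0..N-1 (masked in)."""
        flat = mask.flatten()
        # The running count of True entries assigns each masked-in cell its
        # slot in the compacted storage.
        slots = flat.cumsum(dim=0).to(torch.int64) - 1
        indices = torch.where(flat, slots, torch.full_like(slots, -1))
        return indices.reshape(mask.shape), int(flat.sum().item())

    mask = torch.tensor([[False, True], [True, False]])
    indices, n = compact_indices(mask)
    # indices == [[-1, 0], [1, -1]] and n == 2: snapshots for this mask need
    # only n values per shot instead of one value per grid cell, which is why
    # the kernels above size the *_grad_store_1 buffers by n_grad_grid_points
    # (grad_shot_numel on the GPU) rather than n_grid_points (shot_numel).

This compaction is what test_gradient_mask_reduces_storage exercises, and the
detach step in acoustic() (v = torch.where(gradient_mask, v, v.detach()))
complements it on the autograd side by cutting the graph in masked-out cells,
so neighbor-averaged buoyancy cannot leak gradients just outside the mask.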