Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 68 additions & 18 deletions src/deepwave/acoustic.c
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,14 @@
#define SET_I int64_t const i = x
#endif

static inline int64_t get_grad_idx(
int64_t const *__restrict const gradient_mask_indices, int64_t i) {
if (gradient_mask_indices == NULL) {
return i;
}
return gradient_mask_indices[i];
}

#ifdef _WIN32
__declspec(dllexport)
#endif
Expand Down Expand Up @@ -173,6 +181,8 @@ __declspec(dllexport)
void *__restrict const bx_grad_store_2,
void *__restrict const bx_grad_store_3,
char const *__restrict const *__restrict const bx_grad_filenames_ptr,
int64_t const *__restrict const gradient_mask_indices,
int64_t const n_grad_grid_points,
DW_DTYPE *__restrict const r_p,
#if DW_NDIM >= 3
DW_DTYPE *__restrict const r_vz,
Expand Down Expand Up @@ -298,6 +308,7 @@ __declspec(dllexport)
#endif

int64_t const si = shot * n_grid_points;
int64_t const grad_si = shot * n_grad_grid_points;
int64_t const src_i_p = shot * n_sources_p_per_shot;
#if DW_NDIM >= 3
int64_t const src_i_vz = shot * n_sources_vz_per_shot;
Expand Down Expand Up @@ -372,9 +383,9 @@ __declspec(dllexport)
for (t = start_t; t < start_t + nt; ++t) {
#define SETUP_STORE_SAVE(name, grad_cond) \
DW_DTYPE *__restrict const name##_store_1_t = \
name##_store_1 + si + \
name##_store_1 + grad_si + \
((storage_mode == STORAGE_DEVICE && !storage_compression) \
? (t / step_ratio) * n_shots * n_grid_points \
? (t / step_ratio) * n_shots * n_grad_grid_points \
: 0); \
void *__restrict const name##_store_2_t = \
(uint8_t *)name##_store_2 + \
Expand Down Expand Up @@ -469,7 +480,11 @@ __declspec(dllexport)
}

if (b_requires_grad_t) {
bz_grad_store_1_t[i] = term_z;
int64_t const grad_idx =
get_grad_idx(gradient_mask_indices, i);
if (grad_idx >= 0) {
bz_grad_store_1_t[grad_idx] = term_z;
}
}
vz_shot[i] -= dt * buoyancy_z_shot[i] * term_z;
}
Expand All @@ -492,7 +507,11 @@ __declspec(dllexport)
}

if (b_requires_grad_t) {
by_grad_store_1_t[i] = term_y;
int64_t const grad_idx =
get_grad_idx(gradient_mask_indices, i);
if (grad_idx >= 0) {
by_grad_store_1_t[grad_idx] = term_y;
}
}
vy_shot[i] -= dt * buoyancy_y_shot[i] * term_y;
}
Expand All @@ -516,7 +535,11 @@ __declspec(dllexport)
}

if (b_requires_grad_t) {
bx_grad_store_1_t[i] = term_x;
int64_t const grad_idx =
get_grad_idx(gradient_mask_indices, i);
if (grad_idx >= 0) {
bx_grad_store_1_t[grad_idx] = term_x;
}
}
vx_shot[i] -= dt * buoyancy_x_shot[i] * term_x;
}
Expand Down Expand Up @@ -599,7 +622,11 @@ __declspec(dllexport)
div_v += d_x;

if (k_requires_grad_t) {
k_grad_store_1_t[i] = div_v;
int64_t const grad_idx =
get_grad_idx(gradient_mask_indices, i);
if (grad_idx >= 0) {
k_grad_store_1_t[grad_idx] = div_v;
}
}
p_shot[i] -= dt * k_shot[i] * div_v;
}
Expand All @@ -619,7 +646,7 @@ __declspec(dllexport)
storage_save_snapshot_cpu( \
name##_store_1_t, name##_store_2_t, fp_##name, storage_mode, \
storage_compression, step_idx, shot_bytes_uncomp, shot_bytes_comp, \
n_grid_points, sizeof(DW_DTYPE) == sizeof(double)); \
n_grad_grid_points, sizeof(DW_DTYPE) == sizeof(double)); \
}

SAVE_SNAPSHOT(k_grad, k_requires_grad_t)
Expand Down Expand Up @@ -721,6 +748,8 @@ __declspec(dllexport)
void *__restrict const bx_grad_store_2,
void *__restrict const bx_grad_store_3,
char const *__restrict const *__restrict const bx_grad_filenames_ptr,
int64_t const *__restrict const gradient_mask_indices,
int64_t const n_grad_grid_points,
DW_DTYPE *__restrict const grad_f_p,
#if DW_NDIM >= 3
DW_DTYPE *__restrict const grad_f_vz,
Expand Down Expand Up @@ -859,6 +888,7 @@ __declspec(dllexport)
#endif /* _OPENMP */

int64_t const si = shot * n_grid_points;
int64_t const grad_si = shot * n_grad_grid_points;
int64_t const src_i_p = shot * n_sources_p_per_shot;
#if DW_NDIM >= 3
int64_t const src_i_vz = shot * n_sources_vz_per_shot;
Expand Down Expand Up @@ -941,9 +971,9 @@ __declspec(dllexport)
for (t = start_t - 1; t >= start_t - nt; --t) {
#define SETUP_STORE_LOAD(name, grad_cond) \
DW_DTYPE *__restrict const name##_store_1_t = \
name##_store_1 + si + \
name##_store_1 + grad_si + \
((storage_mode == STORAGE_DEVICE && !storage_compression) \
? (t / step_ratio) * n_shots * n_grid_points \
? (t / step_ratio) * n_shots * n_grad_grid_points \
: 0); \
void *__restrict const name##_store_2_t = \
(uint8_t *)name##_store_2 + \
Expand All @@ -956,7 +986,7 @@ __declspec(dllexport)
storage_load_snapshot_cpu( \
(void *)name##_store_1_t, name##_store_2_t, fp_##name, storage_mode, \
storage_compression, step_idx, shot_bytes_uncomp, shot_bytes_comp, \
n_grid_points, sizeof(DW_DTYPE) == sizeof(double)); \
n_grad_grid_points, sizeof(DW_DTYPE) == sizeof(double)); \
}

SETUP_STORE_LOAD(k_grad, k_requires_grad)
Expand Down Expand Up @@ -1039,8 +1069,13 @@ __declspec(dllexport)
dt * buoyancy_z_shot[i] * azh[z] * vz_shot[i];

if (b_requires_grad_t) {
grad_bz_shot[i] -= dt * vz_shot[i] * bz_grad_store_1_t[i] *
(DW_DTYPE)step_ratio;
int64_t const grad_idx =
get_grad_idx(gradient_mask_indices, i);
if (grad_idx >= 0) {
grad_bz_shot[i] -=
dt * vz_shot[i] *
bz_grad_store_1_t[grad_idx] * (DW_DTYPE)step_ratio;
}
}
}
#endif
Expand All @@ -1061,8 +1096,13 @@ __declspec(dllexport)
dt * buoyancy_y_shot[i] * ayh[y] * vy_shot[i];

if (b_requires_grad_t) {
grad_by_shot[i] -= dt * vy_shot[i] * by_grad_store_1_t[i] *
(DW_DTYPE)step_ratio;
int64_t const grad_idx =
get_grad_idx(gradient_mask_indices, i);
if (grad_idx >= 0) {
grad_by_shot[i] -=
dt * vy_shot[i] *
by_grad_store_1_t[grad_idx] * (DW_DTYPE)step_ratio;
}
}
}
#endif
Expand All @@ -1084,8 +1124,13 @@ __declspec(dllexport)
dt * buoyancy_x_shot[i] * axh[x] * vx_shot[i];

if (b_requires_grad_t) {
grad_bx_shot[i] -= dt * vx_shot[i] * bx_grad_store_1_t[i] *
(DW_DTYPE)step_ratio;
int64_t const grad_idx =
get_grad_idx(gradient_mask_indices, i);
if (grad_idx >= 0) {
grad_bx_shot[i] -=
dt * vx_shot[i] *
bx_grad_store_1_t[grad_idx] * (DW_DTYPE)step_ratio;
}
}
}
}
Expand Down Expand Up @@ -1162,8 +1207,13 @@ __declspec(dllexport)
dt * k_shot[i] * ax[x] * p_shot[i];

if (k_requires_grad_t) {
grad_k_shot[i] -= dt * p_shot[i] * k_grad_store_1_t[i] *
(DW_DTYPE)step_ratio;
int64_t const grad_idx =
get_grad_idx(gradient_mask_indices, i);
if (grad_idx >= 0) {
grad_k_shot[i] -=
dt * p_shot[i] *
k_grad_store_1_t[grad_idx] * (DW_DTYPE)step_ratio;
}
}

// Update P
Expand Down
Loading