diff --git a/dmff/torch_tools.py b/dmff/torch_tools.py index 66bd29c5..f812a2d9 100644 --- a/dmff/torch_tools.py +++ b/dmff/torch_tools.py @@ -41,37 +41,64 @@ def j2t_pytree(v): def wrap_torch_potential_kernel(potential_t): - @partial(jax.custom_jvp, nondiff_argnums=(2,)) + # jvp, good for push-forward mode + # @partial(jax.custom_jvp, nondiff_argnums=(2,)) + # def potential(positions, box, pairs, params): + # res = potential_t(j2t_pytree(positions), \ + # j2t_pytree(box), \ + # np.array(pairs), \ + # j2t_pytree(params)) + # return res + + # @potential.defjvp + # def potential_jvp(pairs, primals, tangents): + # positions, box, params = primals + # dpositions, dbox, dparams = tangents + # # convert inputs to torch + # positions_t = j2t_pytree(positions) + # box_t = j2t_pytree(box) + # params_t = j2t_pytree(params) + # # do fwd and bwd in torch + # primal_out_torch = potential_t(positions_t, box_t, np.array(pairs), params_t) + # primal_out_torch.backward() + # # read gradient in torch + # g_positions = t2j_extract_grad(positions_t) + # g_box = t2j_extract_grad(box_t) + # g_params = t2j_extract_grad(params_t) + # # prepare output + # primal_out = t2j(primal_out_torch) + # tangent_out = jnp.sum(g_positions * dpositions) + jnp.sum(g_box * box) + # tangents_leaves = jax.tree.leaves(dparams) + # grad_leaves = jax.tree.leaves(g_params) + # for x, y in zip(tangents_leaves, grad_leaves): + # tangent_out += jnp.sum(x * y) + # return primal_out, tangent_out + + # vjp: good for backward + @partial(jax.custom_vjp, nondiff_argnums=(2,)) def potential(positions, box, pairs, params): - res = potential_t(j2t_pytree(positions), \ - j2t_pytree(box), \ - np.array(pairs), \ + res = potential_t(j2t_pytree(positions), + j2t_pytree(box), + np.array(pairs), j2t_pytree(params)) return res - @potential.defjvp - def potential_jvp(pairs, primals, tangents): - positions, box, params = primals - dpositions, dbox, dparams = tangents - # convert inputs to torch - positions_t = j2t_pytree(positions) + def potential_fwd(positions, box, pairs, params): + pos_t = j2t_pytree(positions) box_t = j2t_pytree(box) + pairs = np.array(pairs) params_t = j2t_pytree(params) - # do fwd and bwd in torch - primal_out_torch = potential_t(positions_t, box_t, np.array(pairs), params_t) - primal_out_torch.backward() - # read gradient in torch - g_positions = t2j_extract_grad(positions_t) - g_box = t2j_extract_grad(box_t) - g_params = t2j_extract_grad(params_t) - # prepare output - primal_out = t2j(primal_out_torch) - tangent_out = jnp.sum(g_positions * dpositions) + jnp.sum(g_box * box) - tangents_leaves = jax.tree.leaves(dparams) - grad_leaves = jax.tree.leaves(g_params) - for x, y in zip(tangents_leaves, grad_leaves): - tangent_out += jnp.sum(x * y) - return primal_out, tangent_out + energy = potential_t(pos_t, box_t, pairs, params_t) + energy.backward() + grads = (t2j_extract_grad(pos_t), + t2j_extract_grad(box_t), + t2j_extract_grad(params_t)) + return t2j(energy), grads + + def potential_bwd(pairs, res, g): + return res[0]*g, res[1]*g, jax.tree.map(lambda x: x*g, res[2]) + + potential.defvjp(potential_fwd, potential_bwd) return potential diff --git a/examples/eann/eann_model.pickle b/examples/eann/eann_model.pickle index cf431d04..4747bb10 100644 Binary files a/examples/eann/eann_model.pickle and b/examples/eann/eann_model.pickle differ diff --git a/tests/data/eann_model.pickle b/tests/data/eann_model.pickle index cf431d04..4747bb10 100644 Binary files a/tests/data/eann_model.pickle and b/tests/data/eann_model.pickle differ diff --git a/tests/data/water_eann.pickle b/tests/data/water_eann.pickle index c5473060..d29cf930 100644 Binary files a/tests/data/water_eann.pickle and b/tests/data/water_eann.pickle differ