From 30825afe27964c5363436ace00b9307e66fc30b3 Mon Sep 17 00:00:00 2001 From: Kuang Yu Date: Sat, 13 Sep 2025 11:57:14 +0800 Subject: [PATCH 1/2] change torch wrapper from jvp to vjp --- dmff/torch_tools.py | 77 ++++++++++++++++++++++++++++++--------------- 1 file changed, 52 insertions(+), 25 deletions(-) diff --git a/dmff/torch_tools.py b/dmff/torch_tools.py index 66bd29c5..f812a2d9 100644 --- a/dmff/torch_tools.py +++ b/dmff/torch_tools.py @@ -41,37 +41,64 @@ def j2t_pytree(v): def wrap_torch_potential_kernel(potential_t): - @partial(jax.custom_jvp, nondiff_argnums=(2,)) + # jvp, good for push-forward mode + # @partial(jax.custom_jvp, nondiff_argnums=(2,)) + # def potential(positions, box, pairs, params): + # res = potential_t(j2t_pytree(positions), \ + # j2t_pytree(box), \ + # np.array(pairs), \ + # j2t_pytree(params)) + # return res + + # @potential.defjvp + # def potential_jvp(pairs, primals, tangents): + # positions, box, params = primals + # dpositions, dbox, dparams = tangents + # # convert inputs to torch + # positions_t = j2t_pytree(positions) + # box_t = j2t_pytree(box) + # params_t = j2t_pytree(params) + # # do fwd and bwd in torch + # primal_out_torch = potential_t(positions_t, box_t, np.array(pairs), params_t) + # primal_out_torch.backward() + # # read gradient in torch + # g_positions = t2j_extract_grad(positions_t) + # g_box = t2j_extract_grad(box_t) + # g_params = t2j_extract_grad(params_t) + # # prepare output + # primal_out = t2j(primal_out_torch) + # tangent_out = jnp.sum(g_positions * dpositions) + jnp.sum(g_box * box) + # tangents_leaves = jax.tree.leaves(dparams) + # grad_leaves = jax.tree.leaves(g_params) + # for x, y in zip(tangents_leaves, grad_leaves): + # tangent_out += jnp.sum(x * y) + # return primal_out, tangent_out + + # vjp: good for backward + @partial(jax.custom_vjp, nondiff_argnums=(2,)) def potential(positions, box, pairs, params): - res = potential_t(j2t_pytree(positions), \ - j2t_pytree(box), \ - np.array(pairs), \ + res = potential_t(j2t_pytree(positions), + j2t_pytree(box), + np.array(pairs), j2t_pytree(params)) return res - @potential.defjvp - def potential_jvp(pairs, primals, tangents): - positions, box, params = primals - dpositions, dbox, dparams = tangents - # convert inputs to torch - positions_t = j2t_pytree(positions) + def potential_fwd(positions, box, pairs, params): + pos_t = j2t_pytree(positions) box_t = j2t_pytree(box) + pairs = np.array(pairs) params_t = j2t_pytree(params) - # do fwd and bwd in torch - primal_out_torch = potential_t(positions_t, box_t, np.array(pairs), params_t) - primal_out_torch.backward() - # read gradient in torch - g_positions = t2j_extract_grad(positions_t) - g_box = t2j_extract_grad(box_t) - g_params = t2j_extract_grad(params_t) - # prepare output - primal_out = t2j(primal_out_torch) - tangent_out = jnp.sum(g_positions * dpositions) + jnp.sum(g_box * box) - tangents_leaves = jax.tree.leaves(dparams) - grad_leaves = jax.tree.leaves(g_params) - for x, y in zip(tangents_leaves, grad_leaves): - tangent_out += jnp.sum(x * y) - return primal_out, tangent_out + energy = potential_t(pos_t, box_t, pairs, params_t) + energy.backward() + grads = (t2j_extract_grad(pos_t), + t2j_extract_grad(box_t), + t2j_extract_grad(params_t)) + return t2j(energy), grads + + def potential_bwd(pairs, res, g): + return res[0]*g, res[1]*g, jax.tree.map(lambda x: x*g, res[2]) + + potential.defvjp(potential_fwd, potential_bwd) return potential From e634dc43f1efd45c2ff48e3866d3c54e4a1586f9 Mon Sep 17 00:00:00 2001 From: KuangYu Date: Tue, 2 Dec 2025 23:41:01 +0800 Subject: [PATCH 2/2] repack the eann pickle file to be compatible with new jax version --- examples/eann/eann_model.pickle | Bin 96288 -> 95959 bytes tests/data/eann_model.pickle | Bin 96288 -> 95959 bytes tests/data/water_eann.pickle | Bin 58254 -> 58026 bytes 3 files changed, 0 insertions(+), 0 deletions(-) diff --git a/examples/eann/eann_model.pickle b/examples/eann/eann_model.pickle index cf431d04519d203e97d1d01d57501167d8f98bab..4747bb102178b0b3f8cca1a119d269f1b6954ed8 100644 GIT binary patch delta 787 zcmZ4Rf%W=ZR@nxYsZ$ad89<8G@U6Cvq}_RF&%);!))~1)?em zud3SVoXkpKvy8Bs#pvw-beQ1g?Xw{2lJxMa`#2Y7o-tPQ9H8cDFXTL+nKV^j3>-KN zATQ(40P^zooJ2;LpG~lv3s$xy1*$Ab51+C-=|E)=Uz=hz4eIOdV%boINqYDd2IRp_ zG{b9RS0PYgEI5jBCt@HuHL?PVf&K3Q%a^J_3)*n<`rj_RO%HZ7A58u zW3@#9qiZm}ycEZ(d$*eu`dZ bUS>%_ehC)!jKG8o^ae|tcZW5|`K5XQ7W?~K delta 1120 zcmccqmUY1gR@VlWshWw53?NWDMWctgT+d)i4_{VdgpkI zS&%xVvxhq`F*h|OzBnTh$gQ1HS~4YQN)K}q(8dfFjSTh-E^mg`DH%NNQ=n$GP00}6 zd}snEvp&$M49U(6DY&L`Jwtq&UQb5Sl%Wc@39lxR>72}lFq`0-%Jqz}+r;SY0Q9Qx zX8T!4+TadMBBU*AE|RtkN2n9a^^CFG<^Z*A>H^LK7E>RKfny95WvJnjgi}8#Ft;C1 zV8j+M<$5MKYz6E3kc?GV5`JBh=|ElB!qXJHUC{8{?vjO~4;C^>dIa?4=VH}Yu4jhd zzF7r8ef6-AgvUe@PJN&->2&6Kc5C_Gm8-Y!UD&>4@01c?dPa>3zLeCw;>?mty@JG| z#N1-+ZV-T&bztUnmJ-H8X6R;c!OSSar+@m}62@Z|FufV`AW4Y_rZ+RMBoUA8)9;ir o9t%R(&6}5(o1dbWnU`5okY9pLJ0mcw0)v62&AY=I)BB}*0A>JsC;$Ke diff --git a/tests/data/eann_model.pickle b/tests/data/eann_model.pickle index cf431d04519d203e97d1d01d57501167d8f98bab..4747bb102178b0b3f8cca1a119d269f1b6954ed8 100644 GIT binary patch delta 787 zcmZ4Rf%W=ZR@nxYsZ$ad89<8G@U6Cvq}_RF&%);!))~1)?em zud3SVoXkpKvy8Bs#pvw-beQ1g?Xw{2lJxMa`#2Y7o-tPQ9H8cDFXTL+nKV^j3>-KN zATQ(40P^zooJ2;LpG~lv3s$xy1*$Ab51+C-=|E)=Uz=hz4eIOdV%boINqYDd2IRp_ zG{b9RS0PYgEI5jBCt@HuHL?PVf&K3Q%a^J_3)*n<`rj_RO%HZ7A58u zW3@#9qiZm}ycEZ(d$*eu`dZ bUS>%_ehC)!jKG8o^ae|tcZW5|`K5XQ7W?~K delta 1120 zcmccqmUY1gR@VlWshWw53?NWDMWctgT+d)i4_{VdgpkI zS&%xVvxhq`F*h|OzBnTh$gQ1HS~4YQN)K}q(8dfFjSTh-E^mg`DH%NNQ=n$GP00}6 zd}snEvp&$M49U(6DY&L`Jwtq&UQb5Sl%Wc@39lxR>72}lFq`0-%Jqz}+r;SY0Q9Qx zX8T!4+TadMBBU*AE|RtkN2n9a^^CFG<^Z*A>H^LK7E>RKfny95WvJnjgi}8#Ft;C1 zV8j+M<$5MKYz6E3kc?GV5`JBh=|ElB!qXJHUC{8{?vjO~4;C^>dIa?4=VH}Yu4jhd zzF7r8ef6-AgvUe@PJN&->2&6Kc5C_Gm8-Y!UD&>4@01c?dPa>3zLeCw;>?mty@JG| z#N1-+ZV-T&bztUnmJ-H8X6R;c!OSSar+@m}62@Z|FufV`AW4Y_rZ+RMBoUA8)9;ir o9t%R(&6}5(o1dbWnU`5okY9pLJ0mcw0)v62&AY=I)BB}*0A>JsC;$Ke diff --git a/tests/data/water_eann.pickle b/tests/data/water_eann.pickle index c5473060c90978080df952597e03a32872882944..d29cf9309f97e046f7963b28afa5f02c66197097 100644 GIT binary patch delta 756 zcmeA>&b;au!X zG@O#blELoH&^jf9vwaG(BEikSnK+r0z^aU}s$%qZXrGcHxY>Y}6Ka|T^-2IPG^ntAD*! zz&^3yim?U_4=5lfuX&iJnbgArQ=6Grl88sO>!UR7q#iz) z>Vm|g#9Vx)KX{a8n$*LamzSHLqL-PMSyGT+g2i}7U?K&^DT}_5(IHTLmg)fjF3RvF delta 984 zcmZ2=l)3LXvuy**)aJ(wU{E_nqlY<3&tOUqUshs;UVL#;vR-0QQDWtk9-;W6)a3lU z;*z4$R~R*;O1mXL3NQC4jmvDbvkn_uxYf@&TX`t zQZfZuUpX)eu!ohm3fS)lc{x2${9cYfbYF#{`K4TsK6@Hh?)HUL7Fj=E*_Y! z%)F9BeERAhrkNt?