From e387e6a1c39afd5d4809268f66137db56334157d Mon Sep 17 00:00:00 2001 From: Rhys Goodall Date: Sat, 10 Jan 2026 12:37:34 -0500 Subject: [PATCH 1/3] fea: use validate_model_outputs in testing, only test stress if we have PBC true --- tests/models/conftest.py | 106 ++------------------------ tests/models/test_nequip_framework.py | 4 +- tests/models/test_sevennet.py | 4 +- torch_sim/models/interface.py | 28 +++++-- 4 files changed, 30 insertions(+), 112 deletions(-) diff --git a/tests/models/conftest.py b/tests/models/conftest.py index 0e481967..3f12f4a0 100644 --- a/tests/models/conftest.py +++ b/tests/models/conftest.py @@ -103,7 +103,7 @@ def test_model_calculator_consistency( atol=force_atol, ) - if "stress" in model_results: + if "stress" in model_results and all(sim_state.pbc): calc_stress = torch.tensor( atoms.get_stress(), device=device, @@ -123,7 +123,7 @@ def test_model_calculator_consistency( return test_model_calculator_consistency -def make_validate_model_outputs_test( # noqa: PLR0915 +def make_validate_model_outputs_test( model_fixture_name: str, device: torch.device = DEVICE, dtype: torch.dtype = torch.float64, @@ -131,108 +131,16 @@ def make_validate_model_outputs_test( # noqa: PLR0915 """Factory function to create model output validation tests. Args: - test_name: Name of the test (used in the function name and messages) model_fixture_name: Name of the model fixture to validate + device: Device to run validation on + dtype: Data type to use for validation """ + from torch_sim.models.interface import validate_model_outputs - def test_model_output_validation(request: pytest.FixtureRequest) -> None: # noqa: PLR0915 + def test_model_output_validation(request: pytest.FixtureRequest) -> None: """Test that a model implementation follows the ModelInterface contract.""" - # Get the model fixture dynamically model: ModelInterface = request.getfixturevalue(model_fixture_name) + validate_model_outputs(model, device, dtype) - from ase.build import bulk - - assert model.dtype is not None - assert model.device is not None - assert model.compute_stress is not None - assert model.compute_forces is not None - - try: - if not model.compute_stress: - model.compute_stress = True - stress_computed = True - except NotImplementedError: - stress_computed = False - - try: - if not model.compute_forces: - model.compute_forces = True - force_computed = True - except NotImplementedError: - force_computed = False - - si_atoms = bulk("Si", "diamond", a=5.43, cubic=True) - fe_atoms = bulk("Fe", "fcc", a=5.26, cubic=True).repeat([3, 1, 1]) - - sim_state = ts.io.atoms_to_state([si_atoms, fe_atoms], device, dtype) - - og_positions = sim_state.positions.clone() - og_cell = sim_state.cell.clone() - og_batch = sim_state.system_idx.clone() - og_atomic_nums = sim_state.atomic_numbers.clone() - - model_output = model.forward(sim_state) - - # assert model did not mutate the input - assert torch.allclose(og_positions, sim_state.positions) - assert torch.allclose(og_cell, sim_state.cell) - assert torch.allclose(og_batch, sim_state.system_idx) - assert torch.allclose(og_atomic_nums, sim_state.atomic_numbers) - - # assert model output has the correct keys - assert "energy" in model_output - assert "forces" in model_output if force_computed else True - assert "stress" in model_output if stress_computed else True - - # assert model output shapes are correct - assert model_output["energy"].shape == (2,) - assert model_output["forces"].shape == (20, 3) if force_computed else True - assert model_output["stress"].shape == (2, 3, 3) if stress_computed else True - - si_state = ts.io.atoms_to_state([si_atoms], device, dtype) - fe_state = ts.io.atoms_to_state([fe_atoms], device, dtype) - - si_model_output = model.forward(si_state) - assert torch.allclose( - si_model_output["energy"], model_output["energy"][0], atol=10e-3 - ) - assert torch.allclose( - si_model_output["forces"], - model_output["forces"][: si_state.n_atoms], - atol=10e-3, - ) - # assert torch.allclose( - # si_model_output["stress"], - # model_output["stress"][0], - # atol=10e-3, - # ) - - fe_model_output = model.forward(fe_state) - si_model_output = model.forward(si_state) - - assert torch.allclose( - fe_model_output["energy"], model_output["energy"][1], atol=10e-2 - ) - assert torch.allclose( - fe_model_output["forces"], - model_output["forces"][si_state.n_atoms :], - atol=10e-2, - ) - # assert torch.allclose( - # fe_model_output["stress"], - # model_output["stress"][1], - # atol=10e-3, - # ) - - # Test single system output - assert fe_model_output["energy"].shape == (1,) - # forces should be shape (n_atoms, 3) for n_atoms in the system - if force_computed: - assert fe_model_output["forces"].shape == (12, 3) - # stress should be shape (1, 3, 3) for 1 system - if stress_computed: - assert fe_model_output["stress"].shape == (1, 3, 3) - - # Rename the function to include the test name test_model_output_validation.__name__ = f"test_{model_fixture_name}_output_validation" return test_model_output_validation diff --git a/tests/models/test_nequip_framework.py b/tests/models/test_nequip_framework.py index 424ba972..6dda94c1 100644 --- a/tests/models/test_nequip_framework.py +++ b/tests/models/test_nequip_framework.py @@ -96,13 +96,11 @@ def nequip_calculator(compiled_ase_nequip_model_path: Path) -> NequIPCalculator: ) -# NOTE: we take [:-1] to skip benzene. This is because the stress calculation in NequIP -# for non-periodic systems gave infinity. test_nequip_consistency = make_model_calculator_consistency_test( test_name="nequip", model_fixture_name="nequip_model", calculator_fixture_name="nequip_calculator", - sim_state_names=consistency_test_simstate_fixtures[:-1], + sim_state_names=consistency_test_simstate_fixtures, energy_atol=5e-5, dtype=DTYPE, device=DEVICE, diff --git a/tests/models/test_sevennet.py b/tests/models/test_sevennet.py index 1e373b79..ce8a63e1 100644 --- a/tests/models/test_sevennet.py +++ b/tests/models/test_sevennet.py @@ -60,13 +60,11 @@ def test_sevennet_initialization(pretrained_sevenn_model: AtomGraphSequential) - assert model.device == DEVICE -# NOTE: we take [:-1] to skipbenzene due to eps volume giving numerically -# unstable stress off diagonal in xy. See: https://github.com/MDIL-SNU/SevenNet/issues/212 test_sevennet_consistency = make_model_calculator_consistency_test( test_name="sevennet", model_fixture_name="sevenn_model", calculator_fixture_name="sevenn_calculator", - sim_state_names=consistency_test_simstate_fixtures[:-1], + sim_state_names=consistency_test_simstate_fixtures, dtype=DTYPE, ) diff --git a/torch_sim/models/interface.py b/torch_sim/models/interface.py index 58f233e8..5c6a243a 100644 --- a/torch_sim/models/interface.py +++ b/torch_sim/models/interface.py @@ -259,30 +259,44 @@ def validate_model_outputs( # noqa: C901, PLR0915 raise ValueError(f"{model_output['stress'].shape=} != (2, 3, 3)") si_state = ts.io.atoms_to_state([si_atoms], device, dtype) - fe_state = ts.io.atoms_to_state([fe_atoms], device, dtype) si_model_output = model.forward(si_state) if not torch.allclose( - si_model_output["energy"], model_output["energy"][0], atol=10e-3 + si_model_output["energy"], model_output["energy"][0], atol=1e-3 ): raise ValueError(f"{si_model_output['energy']=} != {model_output['energy'][0]=}") if not torch.allclose( forces := si_model_output["forces"], expected_forces := model_output["forces"][: si_state.n_atoms], - atol=10e-3, + atol=1e-3, ): raise ValueError(f"{forces=} != {expected_forces=}") - fe_model_output = model.forward(fe_state) - si_model_output = model.forward(si_state) + # Test single Si system output shapes (8 atoms) + if si_model_output["energy"].shape != (1,): + raise ValueError(f"{si_model_output['energy'].shape=} != (1,)") + if force_computed and si_model_output["forces"].shape != (8, 3): + raise ValueError(f"{si_model_output['forces'].shape=} != (8, 3)") + if stress_computed and si_model_output["stress"].shape != (1, 3, 3): + raise ValueError(f"{si_model_output['stress'].shape=} != (1, 3, 3)") + fe_state = ts.io.atoms_to_state([fe_atoms], device, dtype) + fe_model_output = model.forward(fe_state) if not torch.allclose( - fe_model_output["energy"], model_output["energy"][1], atol=10e-2 + fe_model_output["energy"], model_output["energy"][1], atol=1e-3 ): raise ValueError(f"{fe_model_output['energy']=} != {model_output['energy'][1]=}") if not torch.allclose( forces := fe_model_output["forces"], expected_forces := model_output["forces"][si_state.n_atoms :], - atol=10e-2, + atol=1e-3, ): raise ValueError(f"{forces=} != {expected_forces=}") + + # Test single Fe system output shapes (12 atoms) + if fe_model_output["energy"].shape != (1,): + raise ValueError(f"{fe_model_output['energy'].shape=} != (1,)") + if force_computed and fe_model_output["forces"].shape != (12, 3): + raise ValueError(f"{fe_model_output['forces'].shape=} != (12, 3)") + if stress_computed and fe_model_output["stress"].shape != (1, 3, 3): + raise ValueError(f"{fe_model_output['stress'].shape=} != (1, 3, 3)") From 1bc34518991609abf56b2228ce415951cb79d854 Mon Sep 17 00:00:00 2001 From: Rhys Goodall Date: Sat, 10 Jan 2026 12:40:15 -0500 Subject: [PATCH 2/3] fox: revert stress test change as was based on wrong assumption. --- tests/models/conftest.py | 2 +- tests/models/test_nequip_framework.py | 4 +++- tests/models/test_sevennet.py | 2 ++ 3 files changed, 6 insertions(+), 2 deletions(-) diff --git a/tests/models/conftest.py b/tests/models/conftest.py index 3f12f4a0..ca1add6e 100644 --- a/tests/models/conftest.py +++ b/tests/models/conftest.py @@ -103,7 +103,7 @@ def test_model_calculator_consistency( atol=force_atol, ) - if "stress" in model_results and all(sim_state.pbc): + if "stress" in model_results: calc_stress = torch.tensor( atoms.get_stress(), device=device, diff --git a/tests/models/test_nequip_framework.py b/tests/models/test_nequip_framework.py index 6dda94c1..424ba972 100644 --- a/tests/models/test_nequip_framework.py +++ b/tests/models/test_nequip_framework.py @@ -96,11 +96,13 @@ def nequip_calculator(compiled_ase_nequip_model_path: Path) -> NequIPCalculator: ) +# NOTE: we take [:-1] to skip benzene. This is because the stress calculation in NequIP +# for non-periodic systems gave infinity. test_nequip_consistency = make_model_calculator_consistency_test( test_name="nequip", model_fixture_name="nequip_model", calculator_fixture_name="nequip_calculator", - sim_state_names=consistency_test_simstate_fixtures, + sim_state_names=consistency_test_simstate_fixtures[:-1], energy_atol=5e-5, dtype=DTYPE, device=DEVICE, diff --git a/tests/models/test_sevennet.py b/tests/models/test_sevennet.py index ce8a63e1..c6a4eafd 100644 --- a/tests/models/test_sevennet.py +++ b/tests/models/test_sevennet.py @@ -60,6 +60,8 @@ def test_sevennet_initialization(pretrained_sevenn_model: AtomGraphSequential) - assert model.device == DEVICE +# NOTE: we take [:-1] to skipbenzene due to eps volume giving numerically +# unstable stress off diagonal in xy. See: https://github.com/MDIL-SNU/SevenNet/issues/212 test_sevennet_consistency = make_model_calculator_consistency_test( test_name="sevennet", model_fixture_name="sevenn_model", From e34bdf61d3fb10fc5b69b251ffe4ef0fa3f22836 Mon Sep 17 00:00:00 2001 From: Rhys Goodall Date: Sat, 10 Jan 2026 12:58:01 -0500 Subject: [PATCH 3/3] fix: exclude benzene 7net again --- tests/models/test_sevennet.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/models/test_sevennet.py b/tests/models/test_sevennet.py index c6a4eafd..1e373b79 100644 --- a/tests/models/test_sevennet.py +++ b/tests/models/test_sevennet.py @@ -66,7 +66,7 @@ def test_sevennet_initialization(pretrained_sevenn_model: AtomGraphSequential) - test_name="sevennet", model_fixture_name="sevenn_model", calculator_fixture_name="sevenn_calculator", - sim_state_names=consistency_test_simstate_fixtures, + sim_state_names=consistency_test_simstate_fixtures[:-1], dtype=DTYPE, )