diff --git a/tests/models/conftest.py b/tests/models/conftest.py index 0e481967..ca1add6e 100644 --- a/tests/models/conftest.py +++ b/tests/models/conftest.py @@ -123,7 +123,7 @@ def test_model_calculator_consistency( return test_model_calculator_consistency -def make_validate_model_outputs_test( # noqa: PLR0915 +def make_validate_model_outputs_test( model_fixture_name: str, device: torch.device = DEVICE, dtype: torch.dtype = torch.float64, @@ -131,108 +131,16 @@ def make_validate_model_outputs_test( # noqa: PLR0915 """Factory function to create model output validation tests. Args: - test_name: Name of the test (used in the function name and messages) model_fixture_name: Name of the model fixture to validate + device: Device to run validation on + dtype: Data type to use for validation """ + from torch_sim.models.interface import validate_model_outputs - def test_model_output_validation(request: pytest.FixtureRequest) -> None: # noqa: PLR0915 + def test_model_output_validation(request: pytest.FixtureRequest) -> None: """Test that a model implementation follows the ModelInterface contract.""" - # Get the model fixture dynamically model: ModelInterface = request.getfixturevalue(model_fixture_name) + validate_model_outputs(model, device, dtype) - from ase.build import bulk - - assert model.dtype is not None - assert model.device is not None - assert model.compute_stress is not None - assert model.compute_forces is not None - - try: - if not model.compute_stress: - model.compute_stress = True - stress_computed = True - except NotImplementedError: - stress_computed = False - - try: - if not model.compute_forces: - model.compute_forces = True - force_computed = True - except NotImplementedError: - force_computed = False - - si_atoms = bulk("Si", "diamond", a=5.43, cubic=True) - fe_atoms = bulk("Fe", "fcc", a=5.26, cubic=True).repeat([3, 1, 1]) - - sim_state = ts.io.atoms_to_state([si_atoms, fe_atoms], device, dtype) - - og_positions = sim_state.positions.clone() - og_cell = sim_state.cell.clone() - og_batch = sim_state.system_idx.clone() - og_atomic_nums = sim_state.atomic_numbers.clone() - - model_output = model.forward(sim_state) - - # assert model did not mutate the input - assert torch.allclose(og_positions, sim_state.positions) - assert torch.allclose(og_cell, sim_state.cell) - assert torch.allclose(og_batch, sim_state.system_idx) - assert torch.allclose(og_atomic_nums, sim_state.atomic_numbers) - - # assert model output has the correct keys - assert "energy" in model_output - assert "forces" in model_output if force_computed else True - assert "stress" in model_output if stress_computed else True - - # assert model output shapes are correct - assert model_output["energy"].shape == (2,) - assert model_output["forces"].shape == (20, 3) if force_computed else True - assert model_output["stress"].shape == (2, 3, 3) if stress_computed else True - - si_state = ts.io.atoms_to_state([si_atoms], device, dtype) - fe_state = ts.io.atoms_to_state([fe_atoms], device, dtype) - - si_model_output = model.forward(si_state) - assert torch.allclose( - si_model_output["energy"], model_output["energy"][0], atol=10e-3 - ) - assert torch.allclose( - si_model_output["forces"], - model_output["forces"][: si_state.n_atoms], - atol=10e-3, - ) - # assert torch.allclose( - # si_model_output["stress"], - # model_output["stress"][0], - # atol=10e-3, - # ) - - fe_model_output = model.forward(fe_state) - si_model_output = model.forward(si_state) - - assert torch.allclose( - fe_model_output["energy"], model_output["energy"][1], atol=10e-2 - ) - assert torch.allclose( - fe_model_output["forces"], - model_output["forces"][si_state.n_atoms :], - atol=10e-2, - ) - # assert torch.allclose( - # fe_model_output["stress"], - # model_output["stress"][1], - # atol=10e-3, - # ) - - # Test single system output - assert fe_model_output["energy"].shape == (1,) - # forces should be shape (n_atoms, 3) for n_atoms in the system - if force_computed: - assert fe_model_output["forces"].shape == (12, 3) - # stress should be shape (1, 3, 3) for 1 system - if stress_computed: - assert fe_model_output["stress"].shape == (1, 3, 3) - - # Rename the function to include the test name test_model_output_validation.__name__ = f"test_{model_fixture_name}_output_validation" return test_model_output_validation diff --git a/torch_sim/models/interface.py b/torch_sim/models/interface.py index 58f233e8..5c6a243a 100644 --- a/torch_sim/models/interface.py +++ b/torch_sim/models/interface.py @@ -259,30 +259,44 @@ def validate_model_outputs( # noqa: C901, PLR0915 raise ValueError(f"{model_output['stress'].shape=} != (2, 3, 3)") si_state = ts.io.atoms_to_state([si_atoms], device, dtype) - fe_state = ts.io.atoms_to_state([fe_atoms], device, dtype) si_model_output = model.forward(si_state) if not torch.allclose( - si_model_output["energy"], model_output["energy"][0], atol=10e-3 + si_model_output["energy"], model_output["energy"][0], atol=1e-3 ): raise ValueError(f"{si_model_output['energy']=} != {model_output['energy'][0]=}") if not torch.allclose( forces := si_model_output["forces"], expected_forces := model_output["forces"][: si_state.n_atoms], - atol=10e-3, + atol=1e-3, ): raise ValueError(f"{forces=} != {expected_forces=}") - fe_model_output = model.forward(fe_state) - si_model_output = model.forward(si_state) + # Test single Si system output shapes (8 atoms) + if si_model_output["energy"].shape != (1,): + raise ValueError(f"{si_model_output['energy'].shape=} != (1,)") + if force_computed and si_model_output["forces"].shape != (8, 3): + raise ValueError(f"{si_model_output['forces'].shape=} != (8, 3)") + if stress_computed and si_model_output["stress"].shape != (1, 3, 3): + raise ValueError(f"{si_model_output['stress'].shape=} != (1, 3, 3)") + fe_state = ts.io.atoms_to_state([fe_atoms], device, dtype) + fe_model_output = model.forward(fe_state) if not torch.allclose( - fe_model_output["energy"], model_output["energy"][1], atol=10e-2 + fe_model_output["energy"], model_output["energy"][1], atol=1e-3 ): raise ValueError(f"{fe_model_output['energy']=} != {model_output['energy'][1]=}") if not torch.allclose( forces := fe_model_output["forces"], expected_forces := model_output["forces"][si_state.n_atoms :], - atol=10e-2, + atol=1e-3, ): raise ValueError(f"{forces=} != {expected_forces=}") + + # Test single Fe system output shapes (12 atoms) + if fe_model_output["energy"].shape != (1,): + raise ValueError(f"{fe_model_output['energy'].shape=} != (1,)") + if force_computed and fe_model_output["forces"].shape != (12, 3): + raise ValueError(f"{fe_model_output['forces'].shape=} != (12, 3)") + if stress_computed and fe_model_output["stress"].shape != (1, 3, 3): + raise ValueError(f"{fe_model_output['stress'].shape=} != (1, 3, 3)")