Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
104 changes: 6 additions & 98 deletions tests/models/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,116 +123,24 @@ def test_model_calculator_consistency(
return test_model_calculator_consistency


def make_validate_model_outputs_test( # noqa: PLR0915
def make_validate_model_outputs_test(
model_fixture_name: str,
device: torch.device = DEVICE,
dtype: torch.dtype = torch.float64,
):
"""Factory function to create model output validation tests.

Args:
test_name: Name of the test (used in the function name and messages)
model_fixture_name: Name of the model fixture to validate
device: Device to run validation on
dtype: Data type to use for validation
"""
from torch_sim.models.interface import validate_model_outputs

def test_model_output_validation(request: pytest.FixtureRequest) -> None: # noqa: PLR0915
def test_model_output_validation(request: pytest.FixtureRequest) -> None:
"""Test that a model implementation follows the ModelInterface contract."""
# Get the model fixture dynamically
model: ModelInterface = request.getfixturevalue(model_fixture_name)
validate_model_outputs(model, device, dtype)

from ase.build import bulk

assert model.dtype is not None
assert model.device is not None
assert model.compute_stress is not None
assert model.compute_forces is not None

try:
if not model.compute_stress:
model.compute_stress = True
stress_computed = True
except NotImplementedError:
stress_computed = False

try:
if not model.compute_forces:
model.compute_forces = True
force_computed = True
except NotImplementedError:
force_computed = False

si_atoms = bulk("Si", "diamond", a=5.43, cubic=True)
fe_atoms = bulk("Fe", "fcc", a=5.26, cubic=True).repeat([3, 1, 1])

sim_state = ts.io.atoms_to_state([si_atoms, fe_atoms], device, dtype)

og_positions = sim_state.positions.clone()
og_cell = sim_state.cell.clone()
og_batch = sim_state.system_idx.clone()
og_atomic_nums = sim_state.atomic_numbers.clone()

model_output = model.forward(sim_state)

# assert model did not mutate the input
assert torch.allclose(og_positions, sim_state.positions)
assert torch.allclose(og_cell, sim_state.cell)
assert torch.allclose(og_batch, sim_state.system_idx)
assert torch.allclose(og_atomic_nums, sim_state.atomic_numbers)

# assert model output has the correct keys
assert "energy" in model_output
assert "forces" in model_output if force_computed else True
assert "stress" in model_output if stress_computed else True

# assert model output shapes are correct
assert model_output["energy"].shape == (2,)
assert model_output["forces"].shape == (20, 3) if force_computed else True
assert model_output["stress"].shape == (2, 3, 3) if stress_computed else True

si_state = ts.io.atoms_to_state([si_atoms], device, dtype)
fe_state = ts.io.atoms_to_state([fe_atoms], device, dtype)

si_model_output = model.forward(si_state)
assert torch.allclose(
si_model_output["energy"], model_output["energy"][0], atol=10e-3
)
assert torch.allclose(
si_model_output["forces"],
model_output["forces"][: si_state.n_atoms],
atol=10e-3,
)
# assert torch.allclose(
# si_model_output["stress"],
# model_output["stress"][0],
# atol=10e-3,
# )

fe_model_output = model.forward(fe_state)
si_model_output = model.forward(si_state)

assert torch.allclose(
fe_model_output["energy"], model_output["energy"][1], atol=10e-2
)
assert torch.allclose(
fe_model_output["forces"],
model_output["forces"][si_state.n_atoms :],
atol=10e-2,
)
# assert torch.allclose(
# fe_model_output["stress"],
# model_output["stress"][1],
# atol=10e-3,
# )

# Test single system output
assert fe_model_output["energy"].shape == (1,)
# forces should be shape (n_atoms, 3) for n_atoms in the system
if force_computed:
assert fe_model_output["forces"].shape == (12, 3)
# stress should be shape (1, 3, 3) for 1 system
if stress_computed:
assert fe_model_output["stress"].shape == (1, 3, 3)

# Rename the function to include the test name
test_model_output_validation.__name__ = f"test_{model_fixture_name}_output_validation"
return test_model_output_validation
28 changes: 21 additions & 7 deletions torch_sim/models/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,30 +259,44 @@ def validate_model_outputs( # noqa: C901, PLR0915
raise ValueError(f"{model_output['stress'].shape=} != (2, 3, 3)")

si_state = ts.io.atoms_to_state([si_atoms], device, dtype)
fe_state = ts.io.atoms_to_state([fe_atoms], device, dtype)

si_model_output = model.forward(si_state)
if not torch.allclose(
si_model_output["energy"], model_output["energy"][0], atol=10e-3
si_model_output["energy"], model_output["energy"][0], atol=1e-3
):
raise ValueError(f"{si_model_output['energy']=} != {model_output['energy'][0]=}")
if not torch.allclose(
forces := si_model_output["forces"],
expected_forces := model_output["forces"][: si_state.n_atoms],
atol=10e-3,
atol=1e-3,
):
raise ValueError(f"{forces=} != {expected_forces=}")

fe_model_output = model.forward(fe_state)
si_model_output = model.forward(si_state)
# Test single Si system output shapes (8 atoms)
if si_model_output["energy"].shape != (1,):
raise ValueError(f"{si_model_output['energy'].shape=} != (1,)")
if force_computed and si_model_output["forces"].shape != (8, 3):
raise ValueError(f"{si_model_output['forces'].shape=} != (8, 3)")
if stress_computed and si_model_output["stress"].shape != (1, 3, 3):
raise ValueError(f"{si_model_output['stress'].shape=} != (1, 3, 3)")

fe_state = ts.io.atoms_to_state([fe_atoms], device, dtype)
fe_model_output = model.forward(fe_state)
if not torch.allclose(
fe_model_output["energy"], model_output["energy"][1], atol=10e-2
fe_model_output["energy"], model_output["energy"][1], atol=1e-3
):
raise ValueError(f"{fe_model_output['energy']=} != {model_output['energy'][1]=}")
if not torch.allclose(
forces := fe_model_output["forces"],
expected_forces := model_output["forces"][si_state.n_atoms :],
atol=10e-2,
atol=1e-3,
):
raise ValueError(f"{forces=} != {expected_forces=}")

# Test single Fe system output shapes (12 atoms)
if fe_model_output["energy"].shape != (1,):
raise ValueError(f"{fe_model_output['energy'].shape=} != (1,)")
if force_computed and fe_model_output["forces"].shape != (12, 3):
raise ValueError(f"{fe_model_output['forces'].shape=} != (12, 3)")
if stress_computed and fe_model_output["stress"].shape != (1, 3, 3):
raise ValueError(f"{fe_model_output['stress'].shape=} != (1, 3, 3)")