Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 72 additions & 6 deletions examples/scripts/2_structural_optimization.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@

# Number of steps to run
SMOKE_TEST = os.getenv("CI") is not None
N_steps = 10 if SMOKE_TEST else 500
N_steps = 10 if SMOKE_TEST else 100


# ============================================================================
Expand Down Expand Up @@ -111,7 +111,7 @@

# Run optimization
for step in range(N_steps):
if step % 100 == 0:
if step % (N_steps // 5) == 0:
print(f"Step {step}: Potential energy: {state.energy[0].item()} eV")
state = ts.fire_step(state=state, model=lj_model, dt_max=0.01)

Expand Down Expand Up @@ -174,7 +174,7 @@

print("\nRunning FIRE:")
for step in range(N_steps):
if step % 20 == 0:
if step % (N_steps // 5) == 0:
print(f"Step {step}, Energy: {[energy.item() for energy in state.energy]}")

state = ts.fire_step(state=state, model=model, dt_max=0.01)
Expand Down Expand Up @@ -254,7 +254,7 @@

print("\nRunning batched unit cell gradient descent:")
for step in range(N_steps):
if step % 20 == 0:
if step % (N_steps // 5) == 0:
P1 = -torch.trace(state.stress[0]) * UnitConversion.eV_per_Ang3_to_GPa / 3
P2 = -torch.trace(state.stress[1]) * UnitConversion.eV_per_Ang3_to_GPa / 3
P3 = -torch.trace(state.stress[2]) * UnitConversion.eV_per_Ang3_to_GPa / 3
Expand Down Expand Up @@ -308,7 +308,7 @@

print("\nRunning batched unit cell FIRE:")
for step in range(N_steps):
if step % 20 == 0:
if step % (N_steps // 5) == 0:
P1 = -torch.trace(state.stress[0]) * UnitConversion.eV_per_Ang3_to_GPa / 3
P2 = -torch.trace(state.stress[1]) * UnitConversion.eV_per_Ang3_to_GPa / 3
P3 = -torch.trace(state.stress[2]) * UnitConversion.eV_per_Ang3_to_GPa / 3
Expand Down Expand Up @@ -360,7 +360,7 @@

print("\nRunning batched frechet cell filter with FIRE:")
for step in range(N_steps):
if step % 20 == 0:
if step % (N_steps // 5) == 0:
P1 = -torch.trace(state.stress[0]) * UnitConversion.eV_per_Ang3_to_GPa / 3
P2 = -torch.trace(state.stress[1]) * UnitConversion.eV_per_Ang3_to_GPa / 3
P3 = -torch.trace(state.stress[2]) * UnitConversion.eV_per_Ang3_to_GPa / 3
Expand All @@ -386,6 +386,72 @@
print(f"Initial pressure: {initial_pressure} GPa")
print(f"Final pressure: {final_pressure} GPa")

# ============================================================================
# SECTION 7: Batched MACE L-BFGS
# ============================================================================
print("\n" + "=" * 70)
print("SECTION 7: Batched MACE L-BFGS")
print("=" * 70)

# Recreate structures with perturbations
si_dc = bulk("Si", "diamond", a=5.21).repeat((2, 2, 2))
si_dc.positions += 0.2 * rng.standard_normal(si_dc.positions.shape)

cu_dc = bulk("Cu", "fcc", a=3.85).repeat((2, 2, 2))
cu_dc.positions += 0.2 * rng.standard_normal(cu_dc.positions.shape)

fe_dc = bulk("Fe", "bcc", a=2.95).repeat((2, 2, 2))
fe_dc.positions += 0.2 * rng.standard_normal(fe_dc.positions.shape)

atoms_list = [si_dc, cu_dc, fe_dc]

state = ts.io.atoms_to_state(atoms_list, device=device, dtype=dtype)
results = model(state)
state = ts.lbfgs_init(state=state, model=model, alpha=70.0, step_size=1.0)

print("\nRunning L-BFGS:")
for step in range(N_steps):
if step % (N_steps // 5) == 0:
print(f"Step {step}, Energy: {[energy.item() for energy in state.energy]}")
state = ts.lbfgs_step(state=state, model=model, max_history=100)

print(f"Initial energies: {[energy.item() for energy in results['energy']]} eV")
print(f"Final energies: {[energy.item() for energy in state.energy]} eV")


# ============================================================================
# SECTION 8: Batched MACE BFGS
# ============================================================================
print("\n" + "=" * 70)
print("SECTION 8: Batched MACE BFGS")
print("=" * 70)

# Recreate structures with perturbations
si_dc = bulk("Si", "diamond", a=5.21).repeat((2, 2, 2))
si_dc.positions += 0.2 * rng.standard_normal(si_dc.positions.shape)

cu_dc = bulk("Cu", "fcc", a=3.85).repeat((2, 2, 2))
cu_dc.positions += 0.2 * rng.standard_normal(cu_dc.positions.shape)

fe_dc = bulk("Fe", "bcc", a=2.95).repeat((2, 2, 2))
fe_dc.positions += 0.2 * rng.standard_normal(fe_dc.positions.shape)

atoms_list = [si_dc, cu_dc, fe_dc]

state = ts.io.atoms_to_state(atoms_list, device=device, dtype=dtype)
results = model(state)
state = ts.bfgs_init(state=state, model=model, alpha=70.0)

print("\nRunning BFGS:")
for step in range(N_steps):
if step % (N_steps // 5) == 0:
print(f"Step {step}, Energy: {[energy.item() for energy in state.energy]}")
state = ts.bfgs_step(state=state, model=model)

print(f"Initial energies: {[energy.item() for energy in results['energy']]} eV")
print(f"Final energies: {[energy.item() for energy in state.energy]} eV")


print("\n" + "=" * 70)
print("Structural optimization examples completed!")
print("=" * 70)
6 changes: 6 additions & 0 deletions torch_sim/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,13 +53,19 @@
from torch_sim.monte_carlo import SwapMCState, swap_mc_init, swap_mc_step
from torch_sim.optimizers import (
OPTIM_REGISTRY,
BFGSState,
FireState,
LBFGSState,
Optimizer,
OptimState,
bfgs_init,
bfgs_step,
fire_init,
fire_step,
gradient_descent_init,
gradient_descent_step,
lbfgs_init,
lbfgs_step,
)
from torch_sim.optimizers.cell_filters import (
CELL_FILTER_REGISTRY,
Expand Down
13 changes: 12 additions & 1 deletion torch_sim/optimizers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,20 @@
from enum import StrEnum
from typing import Any, Final, Literal, get_args

from torch_sim.optimizers.bfgs import bfgs_init, bfgs_step
from torch_sim.optimizers.cell_filters import CellFireState, CellOptimState # noqa: F401
from torch_sim.optimizers.fire import fire_init, fire_step
from torch_sim.optimizers.gradient_descent import (
gradient_descent_init,
gradient_descent_step,
)
from torch_sim.optimizers.state import FireState, OptimState # noqa: F401
from torch_sim.optimizers.lbfgs import lbfgs_init, lbfgs_step
from torch_sim.optimizers.state import ( # noqa: F401
BFGSState,
FireState,
LBFGSState,
OptimState,
)


FireFlavor = Literal["vv_fire", "ase_fire"]
Expand All @@ -28,9 +35,13 @@ class Optimizer(StrEnum):

gradient_descent = "gradient_descent"
fire = "fire"
lbfgs = "lbfgs"
bfgs = "bfgs"


OPTIM_REGISTRY: Final[dict[Optimizer, tuple[Callable[..., Any], Callable[..., Any]]]] = {
Optimizer.gradient_descent: (gradient_descent_init, gradient_descent_step),
Optimizer.fire: (fire_init, fire_step),
Optimizer.lbfgs: (lbfgs_init, lbfgs_step),
Optimizer.bfgs: (bfgs_init, bfgs_step),
}
Loading