Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 52 additions & 0 deletions tests/test_optimizers.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import copy
import warnings
from collections.abc import Callable
from functools import partial
from typing import Any, get_args
Expand Down Expand Up @@ -1086,6 +1087,57 @@ def test_frechet_cell_fire_optimization(
)


def test_frechet_lbfgs_clamps_extreme_deformation(
    ar_supercell_sim_state: SimState, lj_model: ModelInterface
) -> None:
    """LBFGS + Frechet cell filter clamps extreme log-space deformation.

    Pushes cell_positions so far that cell_positions / cell_factor >> 2.0,
    then checks that after one step: (1) the clamp warning is emitted,
    (2) the log-space deformation is bounded, and (3) positions and cell
    stay free of NaN/Inf.
    """
    state = ts.lbfgs_init(
        state=ar_supercell_sim_state,
        model=lj_model,
        cell_filter=ts.CellFilter.frechet,
    )

    # Drive the log-space deformation (cell_positions / cell_factor) to 10,
    # well beyond the MAX_LOG_DEFORM=2.0 clamp threshold, by adding a scaled
    # identity to each system's cell_positions.
    identity = torch.eye(3, device=state.cell.device, dtype=state.cell.dtype)
    factor = state.cell_factor.view(-1, 1, 1)
    state.cell_positions = state.cell_positions + 10.0 * factor * identity.unsqueeze(0)

    def _max_log_deform(s: SimState) -> float:
        # Largest |cell_positions / cell_factor| entry across all systems.
        return (s.cell_positions / s.cell_factor.view(-1, 1, 1)).abs().max().item()

    assert _max_log_deform(state) > 5.0, (
        "Setup: log-deform should be extreme before step"
    )

    with warnings.catch_warnings(record=True) as caught_warnings:
        warnings.simplefilter("always")
        state = ts.lbfgs_step(state=state, model=lj_model)

    # 1. Log-space deformation must be bounded after the clamp
    log_deform_after = _max_log_deform(state)
    assert log_deform_after <= 2.5, (
        f"Log-space deformation should be clamped to ~2.0, got {log_deform_after:.2f}"
    )

    # 2. Positions and cell must remain finite
    assert not torch.isnan(state.positions).any(), "Positions contain NaN"
    assert not torch.isinf(state.positions).any(), "Positions contain Inf"
    assert not torch.isnan(state.cell).any(), "Cell contains NaN"
    assert not torch.isinf(state.cell).any(), "Cell contains Inf"

    # 3. The clamp warning must have fired
    assert any(
        "Clamping log-space" in str(warn.message) for warn in caught_warnings
    ), "Expected clamping warning but none was emitted"


@pytest.mark.parametrize(
"filter_func",
[None, ts.CellFilter.unit, ts.CellFilter.frechet],
Expand Down
10 changes: 9 additions & 1 deletion torch_sim/optimizers/bfgs.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,11 @@

import torch_sim as ts
from torch_sim.optimizers import cell_filters
from torch_sim.optimizers.cell_filters import CellBFGSState, frechet_cell_filter_init
from torch_sim.optimizers.cell_filters import (
CellBFGSState,
_clamp_deform_grad_log,
frechet_cell_filter_init,
)
from torch_sim.state import SimState


Expand Down Expand Up @@ -507,6 +511,10 @@ def bfgs_step( # noqa: C901, PLR0915
# Frechet: deform_grad = exp(cell_positions / cell_factor)
cell_factor_reshaped = state.cell_factor.view(state.n_systems, 1, 1)
deform_grad_log_new = cell_positions_new / cell_factor_reshaped # [S, 3, 3]
deform_grad_log_new, cell_positions_new = _clamp_deform_grad_log(
deform_grad_log_new, cell_positions_new, cell_factor_reshaped
)
state.cell_positions = cell_positions_new # [S, 3, 3]
deform_grad_new = torch.matrix_exp(deform_grad_log_new) # [S, 3, 3]
else:
# UnitCell: deform_grad = cell_positions / cell_factor
Expand Down
48 changes: 48 additions & 0 deletions torch_sim/optimizers/cell_filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
during optimization.
"""

import warnings
from collections.abc import Callable
from dataclasses import dataclass, field
from enum import StrEnum
Expand All @@ -19,6 +20,46 @@
from torch_sim.state import SimState


MAX_LOG_DEFORM = 2.0


def _clamp_deform_grad_log(
deform_grad_log: torch.Tensor,
cell_positions: torch.Tensor,
cell_factor_reshaped: torch.Tensor,
*,
max_log_deform: float = MAX_LOG_DEFORM,
) -> tuple[torch.Tensor, torch.Tensor]:
"""Clamp log-space deformation gradient to prevent matrix_exp overflow.

When cell_positions grow unbounded (from diverging structures or extreme
steps), matrix_exp overflows to Inf/NaN. This clamps the log-space values
and writes back the clamped cell_positions so they don't re-accumulate.

Args:
deform_grad_log: Log of the deformation gradient, shape (S, 3, 3).
cell_positions: Current cell positions in log space, shape (S, 3, 3).
cell_factor_reshaped: Cell factor broadcast to (S, 1, 1).
max_log_deform: Maximum absolute value for log-space entries.

Returns:
Tuple of (clamped deform_grad_log, clamped cell_positions).
"""
exceeds = deform_grad_log.abs() > max_log_deform
if exceeds.any():
n_clamped = int(exceeds.any(dim=(-2, -1)).sum().item())
warnings.warn(
f"Clamping log-space deformation gradient for {n_clamped} "
f"system(s) to [-{max_log_deform}, {max_log_deform}] "
f"(max |log(F)| = {deform_grad_log.abs().max().item():.2f}). "
f"This prevents matrix_exp overflow from diverging cell optimization.",
stacklevel=3,
)
deform_grad_log = deform_grad_log.clamp(-max_log_deform, max_log_deform)
cell_positions = deform_grad_log * cell_factor_reshaped
return deform_grad_log, cell_positions


def _setup_cell_factor(
state: SimState,
cell_factor: float | torch.Tensor | None,
Expand Down Expand Up @@ -294,6 +335,9 @@ def frechet_cell_step[T: AnyCellState](state: T, cell_lr: float | torch.Tensor)
# Convert from log space to deformation gradient
cell_factor_reshaped = state.cell_factor.view(state.n_systems, 1, 1)
deform_grad_log_new = cell_positions_new / cell_factor_reshaped
deform_grad_log_new, cell_positions_new = _clamp_deform_grad_log(
deform_grad_log_new, cell_positions_new, cell_factor_reshaped
)
deform_grad_new = torch.matrix_exp(deform_grad_log_new)

# Update cell from new deformation gradient
Expand Down Expand Up @@ -336,6 +380,10 @@ def compute_cell_forces[T: AnyCellState](
deform_grad_log = tsm.matrix_log_33(
cur_deform_grad, sim_dtype=cur_deform_grad.dtype
)
# Clamp to the same limit used in lbfgs_step to prevent NaN from
# propagating into expm_frechet. Systems hitting the clamp have
# diverging cells; their cell forces will be approximate but finite.
deform_grad_log = deform_grad_log.clamp(-MAX_LOG_DEFORM, MAX_LOG_DEFORM)
frechet_method = getattr(state, "frechet_method", None)
cell_forces = _frechet_cell_forces(
deform_grad_log, ucf_cell_grad, frechet_method=frechet_method
Expand Down
5 changes: 5 additions & 0 deletions torch_sim/optimizers/fire.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import torch_sim.math as tsm
from torch_sim._duecredit import dcite
from torch_sim.optimizers import CellFireState, cell_filters
from torch_sim.optimizers.cell_filters import _clamp_deform_grad_log
from torch_sim.state import SimState


Expand Down Expand Up @@ -428,6 +429,10 @@ def _ase_fire_step[T: "FireState | CellFireState"]( # noqa: C901, PLR0915
if is_frechet: # Frechet: convert from log space to deformation gradient
cell_factor_reshaped = state.cell_factor.view(state.n_systems, 1, 1)
deform_grad_log_new = cell_positions_new / cell_factor_reshaped
deform_grad_log_new, cell_positions_new = _clamp_deform_grad_log(
deform_grad_log_new, cell_positions_new, cell_factor_reshaped
)
state.cell_positions = cell_positions_new
deform_grad_new = torch.matrix_exp(deform_grad_log_new)
else: # Unit cell: positions are scaled deformation gradient
cell_factor_expanded = state.cell_factor.expand(state.n_systems, 3, 1)
Expand Down
7 changes: 6 additions & 1 deletion torch_sim/optimizers/lbfgs.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import torch_sim as ts
from torch_sim.optimizers.cell_filters import (
CellLBFGSState,
_clamp_deform_grad_log,
compute_cell_forces,
deform_grad,
frechet_cell_filter_init,
Expand Down Expand Up @@ -486,7 +487,6 @@ def lbfgs_step( # noqa: PLR0915, C901
# Apply cell step
dr_cell = step_cell # [S, 3, 3]
cell_positions_new = state.cell_positions + dr_cell # [S, 3, 3]
state.cell_positions = cell_positions_new # [S, 3, 3]

# Determine if Frechet filter
init_fn, _step_fn = state.cell_filter
Expand All @@ -496,12 +496,17 @@ def lbfgs_step( # noqa: PLR0915, C901
# Frechet: deform_grad = exp(cell_positions / cell_factor)
cell_factor_reshaped = state.cell_factor.view(n_systems, 1, 1)
deform_grad_log_new = cell_positions_new / cell_factor_reshaped # [S, 3, 3]
deform_grad_log_new, cell_positions_new = _clamp_deform_grad_log(
deform_grad_log_new, cell_positions_new, cell_factor_reshaped
)
deform_grad_new = torch.matrix_exp(deform_grad_log_new) # [S, 3, 3]
else:
# UnitCell: deform_grad = cell_positions / cell_factor
cell_factor_expanded = state.cell_factor.expand(n_systems, 3, 1)
deform_grad_new = cell_positions_new / cell_factor_expanded # [S, 3, 3]

state.cell_positions = cell_positions_new # [S, 3, 3]

# Update cell: new_cell = reference_cell @ deform_grad^T
# Use set_constrained_cell to apply cell constraints (e.g. FixSymmetry)
new_col_vector_cell = torch.bmm(
Expand Down
Loading