Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 65 additions & 9 deletions dpdata/amber/md.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,70 @@
force_convert = energy_convert


def cell_lengths_angles_to_cell(
cell_lengths: np.ndarray, cell_angles: np.ndarray
) -> np.ndarray:
"""Convert cell lengths and angles to cell vectors.

Parameters
----------
cell_lengths
Cell lengths with shape ``(..., 3)`` where the last dimension is
``a, b, c``.
cell_angles
Cell angles in degrees with shape ``(..., 3)`` where the last dimension
is ``alpha, beta, gamma``.

Returns
-------
np.ndarray
Cell vectors with shape ``(..., 3, 3)``.
"""
alpha = np.deg2rad(cell_angles[..., 0])
beta = np.deg2rad(cell_angles[..., 1])
gamma = np.deg2rad(cell_angles[..., 2])

a = cell_lengths[..., 0]
b = cell_lengths[..., 1]
c = cell_lengths[..., 2]

if np.any(cell_lengths <= 0.0):
raise RuntimeError("Invalid AMBER cell lengths")
if np.any((cell_angles <= 0.0) | (cell_angles >= 180.0)):
raise RuntimeError("Invalid AMBER cell angles")

cos_alpha = np.cos(alpha)
cos_beta = np.cos(beta)
cos_gamma = np.cos(gamma)
sin_gamma = np.sin(gamma)
ly = b * sin_gamma
if np.any(ly <= 1e-8):
raise RuntimeError("Invalid AMBER cell angles")

z_factor = (
1
- cos_alpha**2
- cos_beta**2
- cos_gamma**2
+ 2 * cos_alpha * cos_beta * cos_gamma
)
lz2 = c**2 * z_factor / sin_gamma**2
if np.any(lz2 <= 1e-8):
raise RuntimeError("Invalid AMBER cell angles")

z = np.sqrt(z_factor) / sin_gamma

shape = (*cell_lengths.shape[:-1], 3, 3)
cells = np.zeros(shape)
cells[..., 0, 0] = a
cells[..., 1, 0] = b * cos_gamma
cells[..., 1, 1] = b * sin_gamma
cells[..., 2, 0] = c * cos_beta
cells[..., 2, 1] = c * (cos_alpha - cos_beta * cos_gamma) / sin_gamma
cells[..., 2, 2] = c * z
Comment thread
njzjz-bot marked this conversation as resolved.
return cells


def read_amber_traj(
parm7_file,
nc_file,
Expand Down Expand Up @@ -85,15 +149,7 @@ def read_amber_traj(
coords = np.array(f.variables["coordinates"][:])
cell_lengths = np.array(f.variables["cell_lengths"][:])
cell_angles = np.array(f.variables["cell_angles"][:])
if np.all(cell_angles > 89.99) and np.all(cell_angles < 90.01):
# only support 90
# TODO: support other angles
shape = cell_lengths.shape
cells = np.zeros((shape[0], 3, 3))
for ii in range(3):
cells[:, ii, ii] = cell_lengths[:, ii]
else:
raise RuntimeError("Unsupported cells")
cells = cell_lengths_angles_to_cell(cell_lengths, cell_angles)

if labeled:
with netcdf_file(mdfrc_file, "r") as f:
Expand Down
117 changes: 117 additions & 0 deletions tests/test_amber_md.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,15 @@

import os
import shutil
import tempfile
import unittest

import numpy as np
from comp_sys import CompLabeledSys, IsPBC
from context import dpdata

from dpdata.amber.md import cell_lengths_angles_to_cell

try:
import parmed # noqa: F401
except ModuleNotFoundError:
Expand All @@ -30,6 +34,119 @@ def tearDown(self):
shutil.rmtree("tmp.deepmd.npy")


class TestAmberMDNonOrthogonal(unittest.TestCase):
def test_cell_lengths_angles_to_cell(self):
cell_lengths = np.array([[10.0, 10.0, 15.0], [8.0, 10.0, 12.0]])
cell_angles = np.array([[90.0, 90.0, 120.0], [70.0, 80.0, 110.0]])

cells = cell_lengths_angles_to_cell(cell_lengths, cell_angles)

self.assertEqual(cells.shape, (2, 3, 3))
np.testing.assert_allclose(
cells[0],
np.array(
[
[10.0, 0.0, 0.0],
[-5.0, 5.0 * np.sqrt(3.0), 0.0],
[0.0, 0.0, 15.0],
]
),
atol=1e-12,
)
for frame, lengths, angles in zip(cells, cell_lengths, cell_angles):
np.testing.assert_allclose(np.linalg.norm(frame, axis=1), lengths)
alpha = np.rad2deg(
np.arccos(
np.dot(frame[1], frame[2])
/ (np.linalg.norm(frame[1]) * np.linalg.norm(frame[2]))
)
)
beta = np.rad2deg(
np.arccos(
np.dot(frame[0], frame[2])
/ (np.linalg.norm(frame[0]) * np.linalg.norm(frame[2]))
)
)
gamma = np.rad2deg(
np.arccos(
np.dot(frame[0], frame[1])
/ (np.linalg.norm(frame[0]) * np.linalg.norm(frame[1]))
)
)
np.testing.assert_allclose([alpha, beta, gamma], angles)

def test_invalid_cell_lengths(self):
cell_lengths = np.array([[0.0, 8.0, 12.0], [5.0, -8.0, 12.0]])
cell_angles = np.array([[90.0, 90.0, 90.0], [90.0, 90.0, 90.0]])

with self.assertRaisesRegex(RuntimeError, "Invalid AMBER cell lengths"):
cell_lengths_angles_to_cell(cell_lengths, cell_angles)

def test_invalid_cell_angles(self):
cell_lengths = np.array(
[
[5.0, 8.0, 12.0],
[5.0, 8.0, 12.0],
[5.0, 8.0, 12.0],
[5.0, 8.0, 12.0],
]
)
cell_angles = np.array(
[
[60.0, 70.0, 130.0],
[90.0, 90.0, 0.0],
[90.0, 90.0, 180.0],
[90.0, 90.0, 180.1],
]
)

with self.assertRaisesRegex(RuntimeError, "Invalid AMBER cell angles"):
cell_lengths_angles_to_cell(cell_lengths, cell_angles)

def test_read_amber_traj_with_nonorthogonal_cells(self):
from scipy.io import netcdf_file

cell_angles = np.array([90.0, 90.0, 120.0])
with tempfile.TemporaryDirectory() as tmpdir:
nc_file = os.path.join(tmpdir, "nonorthogonal.nc")
shutil.copy("amber/02_Heat.nc", nc_file)
with netcdf_file(nc_file, "a", mmap=False) as f:
cell_lengths = np.array(f.variables["cell_lengths"][:])
f.variables["cell_angles"].data[:] = cell_angles

system = dpdata.LabeledSystem(
"amber/02_Heat",
nc_file=nc_file,
fmt="amber/md",
)

cells = system.data["cells"]
self.assertEqual(system.get_nframes(), cell_lengths.shape[0])
np.testing.assert_allclose(
np.linalg.norm(cells, axis=2), cell_lengths, rtol=1e-7, atol=1e-7
)
dot_products = np.stack(
[
np.sum(cells[:, 1] * cells[:, 2], axis=1),
np.sum(cells[:, 0] * cells[:, 2], axis=1),
np.sum(cells[:, 0] * cells[:, 1], axis=1),
],
axis=1,
)
computed_angles = np.rad2deg(
np.arccos(
dot_products / cell_lengths[:, [1, 0, 0]] / cell_lengths[:, [2, 2, 1]]
)
)
np.testing.assert_allclose(
computed_angles,
np.broadcast_to(cell_angles, computed_angles.shape),
rtol=1e-7,
atol=1e-7,
)
self.assertTrue(np.any(np.abs(cells[:, 1, 0]) > 1e-7))


@unittest.skipIf(
skip_parmed_related_test, "skip parmed related test. install parmed to fix"
)
Expand Down
Loading