From e29e074eeb07880aa282c0374d7d4be12d538e2c Mon Sep 17 00:00:00 2001 From: SchrodingersCattt Date: Tue, 23 Jun 2026 10:04:44 +0000 Subject: [PATCH 1/4] feat(extxyz): support stress header, improve robustness and ASE compatibility MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Add stress= header → virial conversion (virial = stress_sign * V * stress), supporting both 3×3 and Voigt (6-component) formats. Default stress_sign=-1 follows ASE convention. User can pass stress_sign=1 for opposite convention. Fixes #973. - Thread **kwargs from MultiSystems.from_file() through to the parser so that stress_sign (and future options) reach handle_single_xyz_frame(). - Tolerate unknown per-atom properties (magmom, charges, tags, etc.) with a warning instead of crashing with RuntimeError. - Recognize both 'force' and 'forces' as per-atom property keys. Writer now outputs 'forces' for ASE SinglePointCalculator compatibility (addresses qchempku2017's comment on #973). - Support flexible energy key lookup: try energy, Energy, free_energy, REF_energy, energies in order. - Handle missing Lattice gracefully (nopbc=True + 100Å dummy box) and parse pbc="F F F" / pbc="T T T" header field. - Writer now emits pbc field and stress field (stress = -virial/volume) alongside virial for better ASE interoperability. - Fix from_labeled_system to actually read files when called directly (previously was a broken passthrough causing TypeError). - Add comprehensive tests for all new functionality. --- dpdata/formats/xyz/quip_gap_xyz.py | 185 ++++++++++++++++------ dpdata/plugins/xyz.py | 12 +- tests/test_quip_gap_xyz.py | 237 +++++++++++++++++++++++++++++ tests/xyz/energy_key_variants.xyz | 5 + tests/xyz/forces_key.xyz | 5 + tests/xyz/nopbc.xyz | 5 + tests/xyz/pbc_false.xyz | 5 + tests/xyz/stress_only.xyz | 5 + tests/xyz/stress_voigt.xyz | 5 + tests/xyz/unknown_props.xyz | 5 + 10 files changed, 421 insertions(+), 48 deletions(-) create mode 100644 tests/xyz/energy_key_variants.xyz create mode 100644 tests/xyz/forces_key.xyz create mode 100644 tests/xyz/nopbc.xyz create mode 100644 tests/xyz/pbc_false.xyz create mode 100644 tests/xyz/stress_only.xyz create mode 100644 tests/xyz/stress_voigt.xyz create mode 100644 tests/xyz/unknown_props.xyz diff --git a/dpdata/formats/xyz/quip_gap_xyz.py b/dpdata/formats/xyz/quip_gap_xyz.py index 71e976de6..847878bd6 100644 --- a/dpdata/formats/xyz/quip_gap_xyz.py +++ b/dpdata/formats/xyz/quip_gap_xyz.py @@ -3,25 +3,68 @@ from __future__ import annotations import re +import warnings from collections import OrderedDict import numpy as np from dpdata.periodic_table import Element +# Possible keys for the energy field in the extxyz comment line, +# checked in order of priority. +_ENERGY_KEYS = ("energy", "Energy", "free_energy", "REF_energy", "energies") + + +def _parse_stress_to_virials(stress_str, cell, stress_sign=-1): + """Convert a stress field string to virial tensor. + + Parameters + ---------- + stress_str : str + Space-separated stress values. Accepts either 9 values (3×3 matrix, + row-major) or 6 values (Voigt notation: xx yy zz yz xz xy). + cell : np.ndarray + 3×3 cell matrix (Å). + stress_sign : int + Sign convention for ``virial = stress_sign * volume * stress``. + Default ``-1`` follows the ASE convention where + ``virial = -V * stress`` (stress in eV/ų). + + Returns + ------- + np.ndarray + Virial tensor with shape ``(1, 3, 3)`` in eV. + """ + vals = list(filter(bool, stress_str.split(" "))) + vals = np.array(vals, dtype=np.float64) + if len(vals) == 9: + stress = vals.reshape(3, 3) + elif len(vals) == 6: + # Voigt order: xx yy zz yz xz xy + xx, yy, zz, yz, xz, xy = vals + stress = np.array([[xx, xy, xz], [xy, yy, yz], [xz, yz, zz]]) + else: + raise ValueError( + f"stress field must have 6 (Voigt) or 9 (3×3) values, got {len(vals)}" + ) + volume = abs(np.linalg.det(cell)) + virials = stress_sign * volume * stress + return np.array([virials]) + class QuipGapxyzSystems: - """deal with QuipGapxyzFile.""" + """Parse an extended XYZ (QUIP/GAP) file frame by frame.""" - def __init__(self, file_name): + def __init__(self, file_name, **kwargs): self.file_object = open(file_name) + self.kwargs = kwargs self.block_generator = self.get_block_generator() def __iter__(self): return self def __next__(self): - return self.handle_single_xyz_frame(next(self.block_generator)) + return self.handle_single_xyz_frame(next(self.block_generator), **self.kwargs) def __del__(self): self.file_object.close() @@ -45,7 +88,20 @@ def get_block_generator(self): yield lines @staticmethod - def handle_single_xyz_frame(lines): + def handle_single_xyz_frame(lines, stress_sign=-1, **kwargs): + """Parse a single extended XYZ frame. + + Parameters + ---------- + lines : list[str] + Raw lines for one frame (atom count + comment + atom lines). + stress_sign : int, optional + Sign convention for stress→virial conversion. + ``-1`` (default) follows the ASE convention: + ``virial = -V * stress``. + **kwargs : dict + Additional keyword arguments (reserved for future use). + """ atom_num = int(lines[0].strip("\n").strip()) if len(lines) != atom_num + 2: raise RuntimeError( @@ -82,59 +138,53 @@ def handle_single_xyz_frame(lines): coords_array = None Z_array = None force_array = None - virials = None for kv_dict in prop_list: - if kv_dict["key"] == "species": + field_length = int(kv_dict["value"]) + key = kv_dict["key"] + + if key == "species": if kv_dict["datatype"] != "S": raise RuntimeError( - "datatype for species must be 'S' instead of {}".format( - kv_dict["datatype"] - ) + f"datatype for species must be 'S' instead of {kv_dict['datatype']}" ) - field_length = int(kv_dict["value"]) type_array = data_array[ :, used_colomn : used_colomn + field_length ].flatten() used_colomn += field_length - continue - elif kv_dict["key"] == "pos": + elif key == "pos": if kv_dict["datatype"] != "R": raise RuntimeError( - "datatype for pos must be 'R' instead of {}".format( - kv_dict["datatype"] - ) + f"datatype for pos must be 'R' instead of {kv_dict['datatype']}" ) - field_length = int(kv_dict["value"]) coords_array = data_array[:, used_colomn : used_colomn + field_length] used_colomn += field_length - continue - elif kv_dict["key"] == "Z": + elif key == "Z": if kv_dict["datatype"] != "I": raise RuntimeError( - "datatype for pos must be 'R' instead of {}".format( - kv_dict["datatype"] - ) + f"datatype for Z must be 'I' instead of {kv_dict['datatype']}" ) - field_length = int(kv_dict["value"]) Z_array = data_array[ :, used_colomn : used_colomn + field_length ].flatten() used_colomn += field_length - continue - elif kv_dict["key"] == "force": + elif key in ("force", "forces"): if kv_dict["datatype"] != "R": raise RuntimeError( - "datatype for pos must be 'R' instead of {}".format( - kv_dict["datatype"] - ) + f"datatype for {key} must be 'R' instead of {kv_dict['datatype']}" ) - field_length = int(kv_dict["value"]) force_array = data_array[:, used_colomn : used_colomn + field_length] used_colomn += field_length - continue else: - raise RuntimeError("unknown field {}".format(kv_dict["key"])) + # Skip unknown per-atom properties (e.g. magmom, charges, + # tags, local_energy) instead of crashing. + warnings.warn( + f"Skipping unknown per-atom property '{key}' " + f"(type={kv_dict['datatype']}, width={field_length})", + stacklevel=2, + ) + used_colomn += field_length + # --- atom type bookkeeping --- type_num_dict = OrderedDict() atom_type_list = [] type_map = {} @@ -156,7 +206,30 @@ def handle_single_xyz_frame(lines): for atom_type, atom_num in type_num_dict.items(): type_num_list.append((atom_type, atom_num)) type_num_array = np.array(type_num_list) - if field_dict.get("virial", None): + + # --- cells / Lattice (parsed early so volume is available for stress→virial) --- + info_dict = {} + if "Lattice" in field_dict and field_dict["Lattice"].strip(): + lattice_values = list(filter(bool, field_dict["Lattice"].split(" "))) + cells = np.array(lattice_values, dtype=np.float64).reshape(3, 3) + info_dict["cells"] = np.array([cells]) + info_dict["nopbc"] = False + else: + cells = np.diag([100.0, 100.0, 100.0]) + info_dict["cells"] = np.array([cells]) + info_dict["nopbc"] = True + + # Override nopbc if explicit pbc field is present + if "pbc" in field_dict: + pbc_flags = field_dict["pbc"].replace('"', "").replace("'", "").split() + if all(f.upper() in ("F", "FALSE", "0") for f in pbc_flags): + info_dict["nopbc"] = True + elif all(f.upper() in ("T", "TRUE", "1") for f in pbc_flags): + info_dict["nopbc"] = False + + # --- virial / stress --- + virials = None + if field_dict.get("virial"): virials = np.array( [ np.array( @@ -164,22 +237,29 @@ def handle_single_xyz_frame(lines): ).reshape(3, 3) ] ).astype(np.float64) - else: - virials = None + elif field_dict.get("stress"): + virials = _parse_stress_to_virials( + field_dict["stress"], cells, stress_sign=stress_sign + ) - info_dict = {} + # --- energy (try several common keys) --- + energy_value = None + for ekey in _ENERGY_KEYS: + if ekey in field_dict: + energy_value = field_dict[ekey] + break + if energy_value is None: + raise ValueError( + f"No energy field found in extxyz comment line. " + f"Tried: {_ENERGY_KEYS}. Available keys: {list(field_dict.keys())}" + ) + + # --- assemble output --- info_dict["atom_names"] = list(type_num_array[:, 0]) info_dict["atom_numbs"] = list(type_num_array[:, 1].astype(int)) info_dict["atom_types"] = np.array(atom_type_list).astype(int) - info_dict["cells"] = np.array( - [ - np.array(list(filter(bool, field_dict["Lattice"].split(" ")))).reshape( - 3, 3 - ) - ] - ).astype(np.float64) info_dict["coords"] = np.array([coords_array]).astype(np.float64) - info_dict["energies"] = np.array([field_dict["energy"]]).astype(np.float64) + info_dict["energies"] = np.array([energy_value]).astype(np.float64) info_dict["forces"] = np.array([force_array]).astype(np.float64) if virials is not None: info_dict["virials"] = virials @@ -188,7 +268,7 @@ def handle_single_xyz_frame(lines): def format_single_frame(data, frame_idx): - """Format a single frame of system data into QUIP/GAP XYZ format lines. + """Format a single frame of system data into extended XYZ format lines. Parameters ---------- @@ -212,19 +292,32 @@ def format_single_frame(data, frame_idx): energy = data["energies"][frame_idx] header_parts.append(f"energy={energy:.12e}") - # Virial (if present) + # Virial and stress (if present) if "virials" in data: virial = data["virials"][frame_idx] virial_str = " ".join(f"{v:.12e}" for v in virial.flatten()) header_parts.append(f'virial="{virial_str}"') + # Also write stress for ASE compatibility: stress = -virial / volume + cell = data["cells"][frame_idx] + volume = abs(np.linalg.det(cell)) + if volume > 0: + stress = -virial / volume + stress_str = " ".join(f"{s:.12e}" for s in stress.flatten()) + header_parts.append(f'stress="{stress_str}"') # Lattice cell = data["cells"][frame_idx] lattice_str = " ".join(f"{c:.12e}" for c in cell.flatten()) header_parts.append(f'Lattice="{lattice_str}"') - # Properties - header_parts.append("Properties=species:S:1:pos:R:3:Z:I:1:force:R:3") + # pbc + if data.get("nopbc", False): + header_parts.append('pbc="F F F"') + else: + header_parts.append('pbc="T T T"') + + # Properties — use "forces" for ASE compatibility (not "force") + header_parts.append("Properties=species:S:1:pos:R:3:Z:I:1:forces:R:3") header_line = " ".join(header_parts) diff --git a/dpdata/plugins/xyz.py b/dpdata/plugins/xyz.py index 63aaeabe3..14f6bac2a 100644 --- a/dpdata/plugins/xyz.py +++ b/dpdata/plugins/xyz.py @@ -56,11 +56,19 @@ def from_system(self, file_name: FileType, **kwargs): @Format.register("mace/xyz") class QuipGapXYZFormat(Format): def from_labeled_system(self, data, **kwargs): - return data + # When called via from_multi_systems iteration, data is already + # a parsed info_dict — return as-is. + if isinstance(data, dict): + return data + # When called directly with a filename, read the first frame. + file_name = data + for frame in QuipGapxyzSystems(file_name, **kwargs): + return frame + raise RuntimeError(f"No frames found in {file_name}") def from_multi_systems(self, file_name, **kwargs): # here directory is the file_name - return QuipGapxyzSystems(file_name) + return QuipGapxyzSystems(file_name, **kwargs) def to_labeled_system(self, data, file_name: FileType, **kwargs): """Write LabeledSystem data to QUIP/GAP XYZ format file. diff --git a/tests/test_quip_gap_xyz.py b/tests/test_quip_gap_xyz.py index a265544ce..efc9483dc 100644 --- a/tests/test_quip_gap_xyz.py +++ b/tests/test_quip_gap_xyz.py @@ -1,7 +1,10 @@ from __future__ import annotations +import tempfile import unittest +import warnings +import numpy as np from comp_sys import CompLabeledSys, IsPBC from context import dpdata @@ -116,5 +119,239 @@ def setUp(self): self.f_places = 6 +# ---------- stress / virial conversion (fixes #973) ---------- + + +class TestStressToVirial(unittest.TestCase): + """Read extxyz with stress= header → virials via -V*stress.""" + + def setUp(self): + self.ms = dpdata.MultiSystems.from_file("xyz/stress_only.xyz", fmt="extxyz") + self.system = list(self.ms.systems.values())[0] + + def test_has_virials(self): + self.assertIn("virials", self.system.data) + + def test_virial_values(self): + """Virial = -V * stress, V=27, stress=diag(0.01,0.02,0.03).""" + expected = np.array([[[-0.27, 0, 0], [0, -0.54, 0], [0, 0, -0.81]]]) + np.testing.assert_allclose(self.system.data["virials"], expected, atol=1e-10) + + def test_energy(self): + np.testing.assert_allclose(self.system.data["energies"], [-1.5]) + + def test_forces(self): + expected = np.array([[[0.1, 0, 0], [-0.1, 0, 0], [0, 0, 0]]]) + np.testing.assert_allclose(self.system.data["forces"], expected) + + +class TestStressVoigt(unittest.TestCase): + """Read extxyz with 6-component Voigt stress.""" + + def setUp(self): + self.ms = dpdata.MultiSystems.from_file("xyz/stress_voigt.xyz", fmt="extxyz") + self.system = list(self.ms.systems.values())[0] + + def test_has_virials(self): + self.assertIn("virials", self.system.data) + + def test_virial_values(self): + """Voigt stress=[0.01,0.02,0.03,0.004,0.005,0.006] → 3×3 → virial=-V*stress.""" + stress = np.array( + [[0.01, 0.006, 0.005], [0.006, 0.02, 0.004], [0.005, 0.004, 0.03]] + ) + expected = np.array([-27.0 * stress]) + np.testing.assert_allclose(self.system.data["virials"], expected, atol=1e-10) + + +class TestStressSignPositive(unittest.TestCase): + """Test stress_sign=1 (opposite convention).""" + + def setUp(self): + self.ms = dpdata.MultiSystems.from_file( + "xyz/stress_only.xyz", fmt="extxyz", stress_sign=1 + ) + self.system = list(self.ms.systems.values())[0] + + def test_virial_values_positive(self): + """Virial = +V * stress with stress_sign=1.""" + expected = np.array([[[0.27, 0, 0], [0, 0.54, 0], [0, 0, 0.81]]]) + np.testing.assert_allclose(self.system.data["virials"], expected, atol=1e-10) + + +# ---------- robustness ---------- + + +class TestUnknownProperties(unittest.TestCase): + """Read extxyz with extra per-atom props (magmom) — should warn, not crash.""" + + def test_parses_without_crash(self): + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + ms = dpdata.MultiSystems.from_file("xyz/unknown_props.xyz", fmt="extxyz") + system = list(ms.systems.values())[0] + self.assertEqual(system.get_nframes(), 1) + warn_msgs = [str(x.message) for x in w] + self.assertTrue( + any("magmom" in msg for msg in warn_msgs), + f"Expected warning about 'magmom', got: {warn_msgs}", + ) + + def test_forces_correct(self): + ms = dpdata.MultiSystems.from_file("xyz/unknown_props.xyz", fmt="extxyz") + system = list(ms.systems.values())[0] + expected = np.array([[[0.1, 0, 0], [-0.1, 0, 0], [0, 0, 0]]]) + np.testing.assert_allclose(system.data["forces"], expected) + + +class TestForcesKey(unittest.TestCase): + """Read extxyz using 'forces' (ASE style) instead of 'force'.""" + + def setUp(self): + self.ms = dpdata.MultiSystems.from_file("xyz/forces_key.xyz", fmt="extxyz") + self.system = list(self.ms.systems.values())[0] + + def test_forces_parsed(self): + expected = np.array([[[0.1, 0, 0], [-0.1, 0, 0], [0, 0, 0]]]) + np.testing.assert_allclose(self.system.data["forces"], expected) + + +class TestNoPBC(unittest.TestCase): + """Read extxyz without Lattice → nopbc=True, dummy 100Å box.""" + + def setUp(self): + self.ms = dpdata.MultiSystems.from_file("xyz/nopbc.xyz", fmt="extxyz") + self.system = list(self.ms.systems.values())[0] + + def test_nopbc(self): + self.assertTrue(self.system.data.get("nopbc", False)) + + def test_dummy_cell(self): + expected = np.diag([100.0, 100.0, 100.0]) + np.testing.assert_allclose(self.system.data["cells"][0], expected) + + def test_energy(self): + np.testing.assert_allclose(self.system.data["energies"], [-1.5]) + + +class TestEnergyKeyVariants(unittest.TestCase): + """Read extxyz with 'Energy' instead of 'energy'.""" + + def setUp(self): + self.ms = dpdata.MultiSystems.from_file( + "xyz/energy_key_variants.xyz", fmt="extxyz" + ) + self.system = list(self.ms.systems.values())[0] + + def test_energy_parsed(self): + np.testing.assert_allclose(self.system.data["energies"], [-1.5]) + + +class TestPBCFalseHeader(unittest.TestCase): + """Read extxyz with pbc='F F F' → nopbc=True even with Lattice present.""" + + def setUp(self): + self.ms = dpdata.MultiSystems.from_file("xyz/pbc_false.xyz", fmt="extxyz") + self.system = list(self.ms.systems.values())[0] + + def test_nopbc(self): + self.assertTrue(self.system.data.get("nopbc", False)) + + def test_cell_still_parsed(self): + """Lattice should still be read even if pbc=F.""" + expected = np.diag([3.0, 3.0, 3.0]) + np.testing.assert_allclose(self.system.data["cells"][0], expected) + + +# ---------- writer ---------- + + +class TestWriteForcesKey(unittest.TestCase): + """Writer should output 'forces' (not 'force') for ASE compat.""" + + def test_output_has_forces_key(self): + ms = dpdata.MultiSystems.from_file("xyz/stress_only.xyz", fmt="extxyz") + system = list(ms.systems.values())[0] + with tempfile.NamedTemporaryFile(mode="w", suffix=".xyz", delete=False) as f: + system.to("extxyz", f.name) + f.flush() + with open(f.name) as fread: + content = fread.read() + self.assertIn("forces:R:3", content) + self.assertNotIn("force:R:3", content) + + +class TestWriteStressField(unittest.TestCase): + """Writer should output stress= alongside virial= when virials present.""" + + def test_output_has_stress(self): + ms = dpdata.MultiSystems.from_file("xyz/xyz_unittest.xyz", fmt="quip/gap/xyz") + system = list(ms.systems.values())[0] + with tempfile.NamedTemporaryFile(mode="w", suffix=".xyz", delete=False) as f: + system.to("extxyz", f.name) + f.flush() + with open(f.name) as fread: + content = fread.read() + self.assertIn("virial=", content) + self.assertIn("stress=", content) + + +class TestWritePBC(unittest.TestCase): + """Writer should output pbc field.""" + + def test_pbc_true(self): + ms = dpdata.MultiSystems.from_file("xyz/xyz_unittest.xyz", fmt="quip/gap/xyz") + system = list(ms.systems.values())[0] + with tempfile.NamedTemporaryFile(mode="w", suffix=".xyz", delete=False) as f: + system.to("extxyz", f.name) + f.flush() + with open(f.name) as fread: + content = fread.read() + self.assertIn('pbc="T T T"', content) + + def test_pbc_false(self): + ms = dpdata.MultiSystems.from_file("xyz/nopbc.xyz", fmt="extxyz") + system = list(ms.systems.values())[0] + with tempfile.NamedTemporaryFile(mode="w", suffix=".xyz", delete=False) as f: + system.to("extxyz", f.name) + f.flush() + with open(f.name) as fread: + content = fread.read() + self.assertIn('pbc="F F F"', content) + + +# ---------- roundtrip ---------- + + +class TestRoundtripStress(unittest.TestCase): + """Write extxyz with stress → read back → virials preserved.""" + + def test_roundtrip(self): + ms1 = dpdata.MultiSystems.from_file("xyz/stress_only.xyz", fmt="extxyz") + sys1 = list(ms1.systems.values())[0] + with tempfile.NamedTemporaryFile(mode="w", suffix=".xyz", delete=False) as f: + sys1.to("extxyz", f.name) + f.flush() + ms2 = dpdata.MultiSystems.from_file(f.name, fmt="extxyz") + sys2 = list(ms2.systems.values())[0] + np.testing.assert_allclose( + sys1.data["virials"], sys2.data["virials"], atol=1e-6 + ) + np.testing.assert_allclose( + sys1.data["energies"], sys2.data["energies"], atol=1e-10 + ) + np.testing.assert_allclose(sys1.data["forces"], sys2.data["forces"], atol=1e-10) + + +class TestFromLabeledSystemDirect(unittest.TestCase): + """LabeledSystem('file.xyz', fmt='extxyz') should work directly.""" + + def test_direct_read(self): + system = dpdata.LabeledSystem("xyz/stress_only.xyz", fmt="extxyz") + self.assertEqual(system.get_nframes(), 1) + np.testing.assert_allclose(system.data["energies"], [-1.5]) + self.assertIn("virials", system.data) + + if __name__ == "__main__": unittest.main() diff --git a/tests/xyz/energy_key_variants.xyz b/tests/xyz/energy_key_variants.xyz new file mode 100644 index 000000000..3cf18f630 --- /dev/null +++ b/tests/xyz/energy_key_variants.xyz @@ -0,0 +1,5 @@ +3 +Lattice="3.0 0.0 0.0 0.0 3.0 0.0 0.0 0.0 3.0" Energy=-1.5 Properties=species:S:1:pos:R:3:Z:I:1:force:R:3 +H 0.0 0.0 0.0 1 0.1 0.0 0.0 +H 1.0 0.0 0.0 1 -0.1 0.0 0.0 +O 2.0 0.0 0.0 8 0.0 0.0 0.0 diff --git a/tests/xyz/forces_key.xyz b/tests/xyz/forces_key.xyz new file mode 100644 index 000000000..ddbda21a5 --- /dev/null +++ b/tests/xyz/forces_key.xyz @@ -0,0 +1,5 @@ +3 +Lattice="3.0 0.0 0.0 0.0 3.0 0.0 0.0 0.0 3.0" energy=-1.5 Properties=species:S:1:pos:R:3:Z:I:1:forces:R:3 +H 0.0 0.0 0.0 1 0.1 0.0 0.0 +H 1.0 0.0 0.0 1 -0.1 0.0 0.0 +O 2.0 0.0 0.0 8 0.0 0.0 0.0 diff --git a/tests/xyz/nopbc.xyz b/tests/xyz/nopbc.xyz new file mode 100644 index 000000000..a11ce6a4e --- /dev/null +++ b/tests/xyz/nopbc.xyz @@ -0,0 +1,5 @@ +3 +energy=-1.5 Properties=species:S:1:pos:R:3:Z:I:1:force:R:3 +H 0.0 0.0 0.0 1 0.1 0.0 0.0 +H 1.0 0.0 0.0 1 -0.1 0.0 0.0 +O 2.0 0.0 0.0 8 0.0 0.0 0.0 diff --git a/tests/xyz/pbc_false.xyz b/tests/xyz/pbc_false.xyz new file mode 100644 index 000000000..c14cda148 --- /dev/null +++ b/tests/xyz/pbc_false.xyz @@ -0,0 +1,5 @@ +3 +Lattice="3.0 0.0 0.0 0.0 3.0 0.0 0.0 0.0 3.0" energy=-1.5 pbc="F F F" Properties=species:S:1:pos:R:3:Z:I:1:force:R:3 +H 0.0 0.0 0.0 1 0.1 0.0 0.0 +H 1.0 0.0 0.0 1 -0.1 0.0 0.0 +O 2.0 0.0 0.0 8 0.0 0.0 0.0 diff --git a/tests/xyz/stress_only.xyz b/tests/xyz/stress_only.xyz new file mode 100644 index 000000000..c7e5adc24 --- /dev/null +++ b/tests/xyz/stress_only.xyz @@ -0,0 +1,5 @@ +3 +Lattice="3.0 0.0 0.0 0.0 3.0 0.0 0.0 0.0 3.0" energy=-1.5 stress="0.01 0.0 0.0 0.0 0.02 0.0 0.0 0.0 0.03" Properties=species:S:1:pos:R:3:Z:I:1:force:R:3 +H 0.0 0.0 0.0 1 0.1 0.0 0.0 +H 1.0 0.0 0.0 1 -0.1 0.0 0.0 +O 2.0 0.0 0.0 8 0.0 0.0 0.0 diff --git a/tests/xyz/stress_voigt.xyz b/tests/xyz/stress_voigt.xyz new file mode 100644 index 000000000..455cac52b --- /dev/null +++ b/tests/xyz/stress_voigt.xyz @@ -0,0 +1,5 @@ +3 +Lattice="3.0 0.0 0.0 0.0 3.0 0.0 0.0 0.0 3.0" energy=-1.5 stress="0.01 0.02 0.03 0.004 0.005 0.006" Properties=species:S:1:pos:R:3:Z:I:1:force:R:3 +H 0.0 0.0 0.0 1 0.1 0.0 0.0 +H 1.0 0.0 0.0 1 -0.1 0.0 0.0 +O 2.0 0.0 0.0 8 0.0 0.0 0.0 diff --git a/tests/xyz/unknown_props.xyz b/tests/xyz/unknown_props.xyz new file mode 100644 index 000000000..aed8c4e8e --- /dev/null +++ b/tests/xyz/unknown_props.xyz @@ -0,0 +1,5 @@ +3 +Lattice="3.0 0.0 0.0 0.0 3.0 0.0 0.0 0.0 3.0" energy=-1.5 Properties=species:S:1:pos:R:3:Z:I:1:magmom:R:1:force:R:3 +H 0.0 0.0 0.0 1 0.5 0.1 0.0 0.0 +H 1.0 0.0 0.0 1 -0.5 -0.1 0.0 0.0 +O 2.0 0.0 0.0 8 0.0 0.0 0.0 0.0 From e57ad413dbed677e6698c13f3f03a2126343ff35 Mon Sep 17 00:00:00 2001 From: SchrodingersCattt Date: Tue, 23 Jun 2026 10:17:28 +0000 Subject: [PATCH 2/4] style(extxyz): replace unicode chars with ASCII for RUF002 compliance --- dpdata/formats/xyz/quip_gap_xyz.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/dpdata/formats/xyz/quip_gap_xyz.py b/dpdata/formats/xyz/quip_gap_xyz.py index 847878bd6..9caa29f90 100644 --- a/dpdata/formats/xyz/quip_gap_xyz.py +++ b/dpdata/formats/xyz/quip_gap_xyz.py @@ -21,14 +21,14 @@ def _parse_stress_to_virials(stress_str, cell, stress_sign=-1): Parameters ---------- stress_str : str - Space-separated stress values. Accepts either 9 values (3×3 matrix, + Space-separated stress values. Accepts either 9 values (3x3 matrix, row-major) or 6 values (Voigt notation: xx yy zz yz xz xy). cell : np.ndarray - 3×3 cell matrix (Å). + 3x3 cell matrix (angstrom). stress_sign : int Sign convention for ``virial = stress_sign * volume * stress``. Default ``-1`` follows the ASE convention where - ``virial = -V * stress`` (stress in eV/ų). + ``virial = -V * stress`` (stress in eV/angstrom^3). Returns ------- @@ -45,7 +45,7 @@ def _parse_stress_to_virials(stress_str, cell, stress_sign=-1): stress = np.array([[xx, xy, xz], [xy, yy, yz], [xz, yz, zz]]) else: raise ValueError( - f"stress field must have 6 (Voigt) or 9 (3×3) values, got {len(vals)}" + f"stress field must have 6 (Voigt) or 9 (3x3) values, got {len(vals)}" ) volume = abs(np.linalg.det(cell)) virials = stress_sign * volume * stress From 4d8048b51ee0cc00caa4854844c28b7f892fd908 Mon Sep 17 00:00:00 2001 From: SchrodingersCattt Date: Tue, 23 Jun 2026 10:19:36 +0000 Subject: [PATCH 3/4] refactor(extxyz): extract _FORCE_KEYS constant for consistency with _ENERGY_KEYS --- dpdata/formats/xyz/quip_gap_xyz.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/dpdata/formats/xyz/quip_gap_xyz.py b/dpdata/formats/xyz/quip_gap_xyz.py index 9caa29f90..b35b99572 100644 --- a/dpdata/formats/xyz/quip_gap_xyz.py +++ b/dpdata/formats/xyz/quip_gap_xyz.py @@ -14,6 +14,9 @@ # checked in order of priority. _ENERGY_KEYS = ("energy", "Energy", "free_energy", "REF_energy", "energies") +# Accepted per-atom property names for forces. +_FORCE_KEYS = ("force", "forces") + def _parse_stress_to_virials(stress_str, cell, stress_sign=-1): """Convert a stress field string to virial tensor. @@ -167,7 +170,7 @@ def handle_single_xyz_frame(lines, stress_sign=-1, **kwargs): :, used_colomn : used_colomn + field_length ].flatten() used_colomn += field_length - elif key in ("force", "forces"): + elif key in _FORCE_KEYS: if kv_dict["datatype"] != "R": raise RuntimeError( f"datatype for {key} must be 'R' instead of {kv_dict['datatype']}" From dd585ffea85f978907107f432a58c8366829a2d2 Mon Sep 17 00:00:00 2001 From: SchrodingersCattt Date: Tue, 23 Jun 2026 10:23:32 +0000 Subject: [PATCH 4/4] feat(extxyz): support 'virials'/'stresses' key variants with _VIRIAL_KEYS/_STRESS_KEYS constants Recognize both singular and plural forms for virial and stress header keys, consistent with _ENERGY_KEYS and _FORCE_KEYS patterns. --- dpdata/formats/xyz/quip_gap_xyz.py | 29 ++++++++++++++++++++-------- tests/test_quip_gap_xyz.py | 31 ++++++++++++++++++++++++++++++ tests/xyz/stresses_key.xyz | 5 +++++ tests/xyz/virials_key.xyz | 5 +++++ 4 files changed, 62 insertions(+), 8 deletions(-) create mode 100644 tests/xyz/stresses_key.xyz create mode 100644 tests/xyz/virials_key.xyz diff --git a/dpdata/formats/xyz/quip_gap_xyz.py b/dpdata/formats/xyz/quip_gap_xyz.py index b35b99572..fe4158728 100644 --- a/dpdata/formats/xyz/quip_gap_xyz.py +++ b/dpdata/formats/xyz/quip_gap_xyz.py @@ -17,6 +17,12 @@ # Accepted per-atom property names for forces. _FORCE_KEYS = ("force", "forces") +# Accepted header keys for virial tensor. +_VIRIAL_KEYS = ("virial", "virials") + +# Accepted header keys for stress tensor. +_STRESS_KEYS = ("stress", "stresses") + def _parse_stress_to_virials(stress_str, cell, stress_sign=-1): """Convert a stress field string to virial tensor. @@ -232,17 +238,24 @@ def handle_single_xyz_frame(lines, stress_sign=-1, **kwargs): # --- virial / stress --- virials = None - if field_dict.get("virial"): + virial_raw = None + for vkey in _VIRIAL_KEYS: + if field_dict.get(vkey): + virial_raw = field_dict[vkey] + break + stress_raw = None + for skey in _STRESS_KEYS: + if field_dict.get(skey): + stress_raw = field_dict[skey] + break + + if virial_raw is not None: virials = np.array( - [ - np.array( - list(filter(bool, field_dict["virial"].split(" "))) - ).reshape(3, 3) - ] + [np.array(list(filter(bool, virial_raw.split(" ")))).reshape(3, 3)] ).astype(np.float64) - elif field_dict.get("stress"): + elif stress_raw is not None: virials = _parse_stress_to_virials( - field_dict["stress"], cells, stress_sign=stress_sign + stress_raw, cells, stress_sign=stress_sign ) # --- energy (try several common keys) --- diff --git a/tests/test_quip_gap_xyz.py b/tests/test_quip_gap_xyz.py index efc9483dc..67cb1a715 100644 --- a/tests/test_quip_gap_xyz.py +++ b/tests/test_quip_gap_xyz.py @@ -179,6 +179,37 @@ def test_virial_values_positive(self): np.testing.assert_allclose(self.system.data["virials"], expected, atol=1e-10) +class TestVirialsKey(unittest.TestCase): + """Read extxyz with 'virials' (plural) instead of 'virial'.""" + + def setUp(self): + self.ms = dpdata.MultiSystems.from_file("xyz/virials_key.xyz", fmt="extxyz") + self.system = list(self.ms.systems.values())[0] + + def test_has_virials(self): + self.assertIn("virials", self.system.data) + + def test_virial_values(self): + expected = np.array([[[0.27, 0, 0], [0, 0.54, 0], [0, 0, 0.81]]]) + np.testing.assert_allclose(self.system.data["virials"], expected, atol=1e-10) + + +class TestStressesKey(unittest.TestCase): + """Read extxyz with 'stresses' (plural) instead of 'stress'.""" + + def setUp(self): + self.ms = dpdata.MultiSystems.from_file("xyz/stresses_key.xyz", fmt="extxyz") + self.system = list(self.ms.systems.values())[0] + + def test_has_virials(self): + self.assertIn("virials", self.system.data) + + def test_virial_values(self): + """Same as stress_only: virial = -27 * diag(0.01,0.02,0.03).""" + expected = np.array([[[-0.27, 0, 0], [0, -0.54, 0], [0, 0, -0.81]]]) + np.testing.assert_allclose(self.system.data["virials"], expected, atol=1e-10) + + # ---------- robustness ---------- diff --git a/tests/xyz/stresses_key.xyz b/tests/xyz/stresses_key.xyz new file mode 100644 index 000000000..f3d12d2e6 --- /dev/null +++ b/tests/xyz/stresses_key.xyz @@ -0,0 +1,5 @@ +3 +Lattice="3.0 0.0 0.0 0.0 3.0 0.0 0.0 0.0 3.0" energy=-1.5 stresses="0.01 0.0 0.0 0.0 0.02 0.0 0.0 0.0 0.03" Properties=species:S:1:pos:R:3:Z:I:1:force:R:3 +H 0.0 0.0 0.0 1 0.1 0.0 0.0 +H 1.0 0.0 0.0 1 -0.1 0.0 0.0 +O 2.0 0.0 0.0 8 0.0 0.0 0.0 diff --git a/tests/xyz/virials_key.xyz b/tests/xyz/virials_key.xyz new file mode 100644 index 000000000..51d218620 --- /dev/null +++ b/tests/xyz/virials_key.xyz @@ -0,0 +1,5 @@ +3 +Lattice="3.0 0.0 0.0 0.0 3.0 0.0 0.0 0.0 3.0" energy=-1.5 virials="0.27 0 0 0 0.54 0 0 0 0.81" Properties=species:S:1:pos:R:3:Z:I:1:force:R:3 +H 0.0 0.0 0.0 1 0.1 0.0 0.0 +H 1.0 0.0 0.0 1 -0.1 0.0 0.0 +O 2.0 0.0 0.0 8 0.0 0.0 0.0