Skip to content

Commit

Permalink
Add functionality to read ASE *.traj file in Trajectory class method …
Browse files Browse the repository at this point in the history
…from_file() (materialsproject#3422)

* Add functionality to read ASE *.traj files in Trajectory.from_file().

* fix Trajectory.from_file turning abs path into file name

---------

Signed-off-by: Jingyang Wang <[email protected]>
Signed-off-by: Janosh Riebesell <[email protected]>
Co-authored-by: Janosh Riebesell <[email protected]>
  • Loading branch information
exenGT and janosh authored Dec 23, 2023
1 parent ed6da26 commit c67739a
Show file tree
Hide file tree
Showing 5 changed files with 48 additions and 20 deletions.
4 changes: 2 additions & 2 deletions pymatgen/analysis/elasticity/elastic.py
Original file line number Diff line number Diff line change
Expand Up @@ -927,8 +927,8 @@ def get_strain_state_dict(strains, stresses, eq_stress=None, tol: float = 1e-10,
stresses (Nx3x3 array-like): stress matrices
eq_stress (Nx3x3 array-like): equilibrium stress
tol (float): tolerance for sorting strain states
add_eq (bool): flag for whether to add eq_strain
to stress-strain sets for each strain state
add_eq (bool): Whether to add eq_strain to stress-strain sets for each strain state.
Defaults to True.
sort (bool): flag for whether to sort strain states
Returns:
Expand Down
2 changes: 1 addition & 1 deletion pymatgen/analysis/magnetism/heisenberg.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,7 @@ def _get_nn_dict(self):
all_dists += [0] * (3 - len(all_dists))

all_dists = all_dists[:3]
labels = ["nn", "nnn", "nnnn"]
labels = ("nn", "nnn", "nnnn")
dists = dict(zip(labels, all_dists))

# Get dictionary keys for interactions
Expand Down
46 changes: 31 additions & 15 deletions pymatgen/core/trajectory.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from monty.json import MSONable

from pymatgen.core.structure import Composition, DummySpecies, Element, Lattice, Molecule, Species, Structure
from pymatgen.io.ase import AseAtomsAdaptor
from pymatgen.io.vasp.outputs import Vasprun, Xdatcar

__author__ = "Eric Sivonxay, Shyam Dwaraknath, Mingjian Wen, Evan Spotte-Smith"
Expand Down Expand Up @@ -525,32 +526,47 @@ def from_molecules(cls, molecules: list[Molecule], **kwargs) -> Trajectory:

@classmethod
def from_file(cls, filename: str | Path, constant_lattice: bool = True, **kwargs) -> Trajectory:
"""Create trajectory from XDATCAR or vasprun.xml file.
"""Create trajectory from XDATCAR, vasprun.xml file, or ASE trajectory (.traj) file.
Args:
filename: Path to the file to read from.
constant_lattice: Whether the lattice changes during the simulation,
such as in an NPT MD simulation.
filename (str | Path): Path to the file to read from.
constant_lattice (bool): Whether the lattice changes during the simulation,
such as in an NPT MD simulation. Defaults to True.
**kwargs: Additional kwargs passed to Trajectory constructor.
Returns:
A trajectory from the file.
Trajectory: containing the structures or molecules in the file.
"""
fname = Path(filename).expanduser().resolve().name
filename = str(Path(filename).expanduser().resolve())
is_mol = False

if fnmatch(fname, "*XDATCAR*"):
if fnmatch(filename, "*XDATCAR*"):
structures = Xdatcar(filename).structures
elif fnmatch(fname, "vasprun*.xml*"):
elif fnmatch(filename, "vasprun*.xml*"):
structures = Vasprun(filename).structures
elif fnmatch(filename, "*.traj"):
try:
from ase.io.trajectory import Trajectory as AseTrajectory

ase_traj = AseTrajectory(filename)
# periodic boundary conditions should be the same for all frames so just check the first
pbc = ase_traj[0].pbc
if any(pbc):
structures = [AseAtomsAdaptor.get_structure(atoms) for atoms in ase_traj]
else:
molecules = [AseAtomsAdaptor.get_molecule(atoms) for atoms in ase_traj]
is_mol = True

except ImportError as exc:
raise exc

else:
supported = ("XDATCAR", "vasprun.xml")
raise ValueError(f"Expect file to be one of {supported}; got {filename}.")
supported_file_types = ("XDATCAR", "vasprun.xml", "*.traj")
raise ValueError(f"Expect file to be one of {supported_file_types}; got {filename}.")

return cls.from_structures(
structures,
constant_lattice=constant_lattice,
**kwargs,
)
if is_mol:
return cls.from_molecules(molecules, **kwargs)
return cls.from_structures(structures, constant_lattice=constant_lattice, **kwargs)

@staticmethod
def _combine_lattice(lat1: np.ndarray, lat2: np.ndarray, len1: int, len2: int) -> tuple[np.ndarray, bool]:
Expand Down
16 changes: 14 additions & 2 deletions tests/core/test_trajectory.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,10 @@ def setUp(self):
coords = out.data["geometries"]

self.molecules = []
for c in coords:
mol = Molecule(species, c, charge=int(last_mol.charge), spin_multiplicity=int(last_mol.spin_multiplicity))
for coord in coords:
mol = Molecule(
species, coord, charge=int(last_mol.charge), spin_multiplicity=int(last_mol.spin_multiplicity)
)
self.molecules.append(mol)

self.traj_mols = Trajectory(
Expand Down Expand Up @@ -468,3 +470,13 @@ def test_xdatcar_write(self):
# Load trajectory from written XDATCAR and compare to original
written_traj = Trajectory.from_file(f"{self.tmp_path}/traj_test_XDATCAR")
self._check_traj_equality(self.traj, written_traj)

def test_from_file(self):
traj = Trajectory.from_file(f"{TEST_FILES_DIR}/LiMnO2_chgnet_relax.traj")
assert isinstance(traj, Trajectory)

# Check length of the trajectory
assert len(traj) == 2

# Check composition of the first frame of the trajectory
assert traj[0].formula == "Li2 Mn2 O4"
Binary file added tests/files/LiMnO2_chgnet_relax.traj
Binary file not shown.

0 comments on commit c67739a

Please sign in to comment.