Skip to content

Commit

Permalink
Dataset.get_differences now uses assert_almost_equal for variables th…
Browse files Browse the repository at this point in the history
…at represent numeric data
  • Loading branch information
gmatteo committed Jan 17, 2025
1 parent 7b99eb1 commit dc6d097
Show file tree
Hide file tree
Showing 3 changed files with 52 additions and 11 deletions.
54 changes: 48 additions & 6 deletions abipy/abio/abivars.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def is_abivar(varname: str) -> bool:
s.lower() for s in (
"au", "nm",
"Angstr", "Angstrom", "Angstroms", "Bohr", "Bohrs",
"eV", "Ha", "Hartree", "Hartrees", "K", "Ry", "Rydberg", "Rydbergs",
"eV", "meV", "Ha", "Hartree", "Hartrees", "K", "Ry", "Rydberg", "Rydbergs",
"T", "Tesla",)
}

Expand Down Expand Up @@ -462,10 +462,18 @@ def write_notebook(self, nbpath=None) -> str:

return self._write_nb_nbpath(nb, nbpath)

def get_differences(self, other, ignore_vars=None) -> list[str]:
def get_differences(self, other, ignore_vars=None, decimal=7) -> list[str]:
"""
Get the differences between this AbinitInputFile and another.
Get the differences between this Dataset and another.
Args:
other: Other Dataset instance.
ignore_vars: List of variable names that should be ignored
decimal: Number of decimal places to check for equality. Default is 7.
"""
from numpy.testing import assert_almost_equal
abivars_database = get_codevars()["abinit"]

diffs = []
to_ignore = {"acell", "angdeg", "rprim", "ntypat", "natom", "znucl", "typat", "xred", "xcart", "xangst"}
if ignore_vars is not None:
Expand Down Expand Up @@ -494,11 +502,45 @@ def get_differences(self, other, ignore_vars=None) -> list[str]:
if other_only_keys:
diffs.append(f"The following variables are in other file but not in this one: "
f"{', '.join([str(k) for k in other_only_keys])}")

for k in common_keys:
if self_dataset_dict[k] != other_dataset_dict[k]:
v1, v2 = self_dataset_dict[k], other_dataset_dict[k]

ok = True
exc_msg = ""
try:
if abivars_database[k].vartype == "string":
# Strict equality for abinit variables whose type is string.
ok = v1 == v2

elif isinstance(v1, str):
# v1 is a string representing a number/vector with/without units.
tokens_1, tokens_2 = v1.split(), v2.split()
if is_abiunit(tokens_1[-1]):
assert_almost_equal(np.array(tokens_1[:-1], dtype=float),
np.array(tokens_2[:-1], dtype=float), decimal=decimal)
# Check units string as well
ok = tokens_1[-1] == tokens_2[-1]
else:
assert_almost_equal(np.array(tokens_1, dtype=float),
np.array(tokens_2, dtype=float), decimal=decimal)
else:
# Assume v1 and v2 are numpy arrays, list, tuple.
assert_almost_equal(v1, v2, decimal)

except Exception as exc:
ok = False
exc_msg = str(exc)

if not ok:
diffs.append(f"The variable '{k}' is different in the two files:\n"
f" - this file: '{self_dataset_dict[k]}'\n"
f" - other file: '{other_dataset_dict[k]}'")
f" - this file: '{v1}'\n"
f" - other file: '{v2}'")
#if exc_msg:
# diffs[-1] += f"\npython exception: {exc_msg}"

print(diffs[-1])

return diffs


Expand Down
3 changes: 3 additions & 0 deletions abipy/abio/tests/test_abivars.py
Original file line number Diff line number Diff line change
Expand Up @@ -399,3 +399,6 @@ def test_get_differences(self):
assert diffs == ["The following variables are in other file but not in this one: ecut"]
diffs = inp2.get_differences(inp4, ignore_vars=["ngkpt", "ecut"])
assert diffs == []

#inp1 = AbinitInputFile.from_string(s1)
#inp2 = AbinitInputFile.from_string(s2)
6 changes: 1 addition & 5 deletions abipy/core/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,7 @@
import json
import tempfile
import unittest
try:
import numpy.testing as nptu
except ImportError:
import numpy.testing.utils as nptu
import abipy.data as abidata
import numpy.testing as nptu

from typing import Optional
from functools import wraps
Expand Down

0 comments on commit dc6d097

Please sign in to comment.