Skip to content

Commit

Permalink
Add comprehensive tests for HelmTable features and inversion error ha…
Browse files Browse the repository at this point in the history
…ndling
  • Loading branch information
msbc committed Nov 30, 2024
1 parent d5a6c36 commit 99a2e7e
Show file tree
Hide file tree
Showing 3 changed files with 78 additions and 3 deletions.
30 changes: 27 additions & 3 deletions tests/test_helm_table.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Test accuracy of code"""
import numpy as np
from helmeos import HelmTable
from helmeos import HelmTable, default
from helmeos.helm_table import OldStyleInputs, _translate, _DelayedTable
import helmeos.table_param as tab
try:
import helmholtz
Expand All @@ -10,7 +11,7 @@
raise


def test_table(nrand=100, vars_to_test=None, silent=False, err_dict=None, tol=1e-14):
def test_table_values(nrand=100, vars_to_test=None, silent=False, err_dict=None, tol=1e-14):
if err_dict is None:
err_dict = {'etot': 4e-15, 'ptot': 6e-16, 'cs': 3e-13, 'sele': 0}
if vars_to_test is None:
Expand Down Expand Up @@ -60,5 +61,28 @@ def test_table(nrand=100, vars_to_test=None, silent=False, err_dict=None, tol=1e
print(fmt.format(var, dif[i]))


def test_table_features():
ht = HelmTable()
ot = OldStyleInputs()
dt = _DelayedTable()
assert default.fn == ht.fn
h_data = ht.full_table(overwite=True)
o_data = ot.full_table()
assert h_data.keys() == o_data.keys()
for key in h_data:
assert np.all(h_data[key] == o_data[key])
# test cache
h_data = ht.full_table()
for key in h_data:
assert np.all(h_data[key] == o_data[key])
data = ot.eos_DT(1e-7, 1e4, 1.0, 1.0)
assert ht.eos_DT(1e-7, 1e4, 1.0, 1.0) == data
assert dt.eos_DT(1e-7, 1e4, 1.0, 1.0) == data
for key, val in _translate.items():
assert ot.eos_DT(1e-7, 1e4, 1.0, 1.0, outvar=key)[key] == data[val]
assert ot.eos_DT(1e-7, 1e4, 1.0, 1.0, outvar=val)[val] == data[val]


if __name__ == "__main__":
test_table(silent=False)
test_table_values()
test_table_features()
16 changes: 16 additions & 0 deletions tests/test_inversion.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import numpy as np
import helmeos
import pytest


def test_inversion(tol=2e-10, vars_to_print=None):
Expand All @@ -18,6 +19,21 @@ def test_inversion(tol=2e-10, vars_to_print=None):
assert diff.max() <= tol, f"Test failed for {i}. Relative error: {diff.max()}"
print(f"Test passed for {i}. Max relative error: {diff.max()}")

# test inversion error handling and bracketing
dens = 1e-7
var = 165585 # ptot for dens=1e-7, temp=1e4
var_name = "ptot"
der_name = "none"
args = (dens, abar, zbar, var, var_name, der_name)
with pytest.raises(ValueError, match="Did not converge."):
helmeos.default._adaptive_root_find(1e7, args=args, tol=1e-14, maxiter=3)
with pytest.raises(ValueError, match="Bracket must be strictly monotonically increasing."):
helmeos.default._adaptive_root_find(1e7, args=args, bracket=(1e7, 1e7))
# test initial bracketing correction
with pytest.warns(RuntimeWarning, match="invalid value encountered in scalar divide"):
temp = helmeos.default._adaptive_root_find(9e3, args=args, bracket=(1e3, 1e5))
assert np.isclose(temp, 1e4, rtol=1e-5), f"Test failed for {var_name}. Temp: {temp}"


if __name__ == "__main__":
test_inversion()
35 changes: 35 additions & 0 deletions tests/test_plotting.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
from helmeos import default
import pytest


def test_import_error(monkeypatch):
# Save the original `__import__` method
original_import = __import__

# Mock `__import__` to raise ImportError for `matplotlib.pyplot`
def mock_import(name, *args, **kwargs):
if name == "matplotlib.pyplot":
raise ImportError("No module named 'matplotlib.pyplot'")
return original_import(name, *args, **kwargs)

# Apply the monkeypatch
monkeypatch.setattr("builtins.__import__", mock_import)

with pytest.raises(ImportError):
default.plot_var('cs')


def test_plotting(monkeypatch):
import matplotlib.pyplot as plt
default.plot_var('cs')
plt.close()
default.plot_var('cs', log=True)
plt.close()
ax = plt.subplot()
default.plot_var('cs', ax=ax)
plt.close()


if __name__ == "__main__":
test_import_error(pytest.MonkeyPatch)
test_plotting()

0 comments on commit 99a2e7e

Please sign in to comment.