Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] Improved max_iters handling #3235

Open
wants to merge 8 commits into
base: master
Choose a base branch
from
Open
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
add max_iters tests for state dict
leej3 committed Apr 19, 2024
commit 0bc872cdcb68ac02ffa4ca0c9ffef7c76291b1b8
67 changes: 54 additions & 13 deletions tests/ignite/engine/test_engine_state_dict.py
Original file line number Diff line number Diff line change
@@ -14,6 +14,7 @@ def test_state_dict():
assert "iteration" in sd and sd["iteration"] == 0
assert "max_epochs" in sd and sd["max_epochs"] is None
assert "epoch_length" in sd and sd["epoch_length"] is None
assert "max_iters" in sd and sd["max_iters"] is None

def _test(state):
engine.state = state
@@ -23,8 +24,14 @@ def _test(state):
assert sd["epoch_length"] == engine.state.epoch_length
assert sd["max_epochs"] == engine.state.max_epochs

if state.max_epochs is not None:
assert sd["max_epochs"] == engine.state.max_epochs
else:
assert sd["max_iters"] == engine.state.max_iters

_test(State(iteration=500, epoch_length=1000, max_epochs=100))
_test(State(epoch=5, epoch_length=1000, max_epochs=100))
_test(State(epoch=5, epoch_length=1000, max_iters=500))


def test_state_dict_with_user_keys():
@@ -40,23 +47,40 @@ def _test(state):
)
assert sd["iteration"] == engine.state.iteration
assert sd["epoch_length"] == engine.state.epoch_length
assert sd["max_epochs"] == engine.state.max_epochs
if state.max_epochs is not None:
assert sd["max_epochs"] == engine.state.max_epochs
else:
assert sd["max_iters"] == engine.state.max_iters
assert sd["alpha"] == engine.state.alpha
assert sd["beta"] == engine.state.beta

_test(State(iteration=500, epoch_length=1000, max_epochs=100, alpha=0.01, beta="Good"))
_test(State(iteration=500, epoch_length=1000, max_iters=2000, alpha=0.01, beta="Good"))


def test_state_dict_integration():
engine = Engine(lambda e, b: 1)
data = range(100)
engine.run(data, max_epochs=10)
sd = engine.state_dict()
assert isinstance(sd, Mapping) and len(sd) == len(engine._state_dict_all_req_keys) + 1
assert sd["iteration"] == engine.state.iteration == 10 * 100
assert sd["epoch_length"] == engine.state.epoch_length == 100
assert sd["max_epochs"] == engine.state.max_epochs == 10

def test_state_dict_integration():
def _test(max_epochs, max_iters):
engine = Engine(lambda e, b: 1)
data = range(100)
engine.run(data, max_epochs=max_epochs, max_iters=max_iters)
sd = engine.state_dict()
assert isinstance(sd, Mapping)
assert len(sd) == len(engine._state_dict_all_req_keys) + 1

if max_epochs is None and max_iters is None:
max_epochs = 1
n_iters = max_iters if max_iters is not None else max_epochs * 100
assert sd["iteration"] == engine.state.iteration == n_iters
assert sd["epoch_length"] == engine.state.epoch_length == 100
if engine.state.max_epochs is not None:
assert sd["max_epochs"] == engine.state.max_epochs == max_epochs
else:
assert sd["max_iters"] == engine.state.max_iters == max_iters

_test(max_epochs=10, max_iters=None)
_test(max_epochs=None, max_iters=None)
_test(max_epochs=None, max_iters=10 * 100)

def test_load_state_dict_asserts():
engine = Engine(lambda e, b: 1)
@@ -93,11 +117,29 @@ def _test(sd):
elif "epoch" in sd:
assert sd["epoch"] == engine.state.epoch
assert sd["epoch_length"] == engine.state.epoch_length
assert sd["max_epochs"] == engine.state.max_epochs
if "max_epochs" in sd:
assert sd["max_epochs"] == engine.state.max_epochs
else:
assert sd["max_iters"] == engine.state.max_iters

_test({"max_epochs": 100, "epoch_length": 120, "iteration": 123})
_test({"max_epochs": 100, "epoch_length": 120, "epoch": 5})

with pytest.raises(ValueError, match=r"Argument max_epochs should be larger than"):
_test({"max_epochs": 10, "epoch_length": 120, "epoch": 50})

with pytest.raises(ValueError, match=r"Argument max_epochs should be larger than"):
_test({"max_epochs": 10, "epoch_length": 120, "iteration": 5000})

_test({"max_iters": 500, "epoch_length": 120, "iteration": 123})
_test({"max_iters": 500, "epoch_length": 120, "epoch": 3})

with pytest.raises(ValueError, match=r"Argument max_iters should be larger than"):
_test({"max_iters": 500, "epoch_length": 120, "epoch": 5})

with pytest.raises(ValueError, match=r"Argument max_iters should be larger than"):
_test({"max_iters": 500, "epoch_length": 120, "iteration": 501})


def test_load_state_dict_with_user_keys():
engine = Engine(lambda e, b: 1)
@@ -142,8 +184,7 @@ def test_load_state_dict_with_params_overriding_integration():
assert state.max_epochs == new_max_epochs
assert state.iteration == state_dict["epoch_length"] * new_max_epochs
assert state.epoch == new_max_epochs

with pytest.raises(ValueError, match=r"Argument max_epochs should be greater than or equal to the start epoch"):
with pytest.raises(ValueError, match=r"Argument max_epochs should be larger than the current epoch"):
engine.load_state_dict(state_dict)
engine.run(data, max_epochs=3)