From f20b6a41ba76ef4d05969d702b2ce18d29d05de8 Mon Sep 17 00:00:00 2001 From: leej3 Date: Wed, 17 Apr 2024 12:01:46 +0100 Subject: [PATCH 1/8] add tests for max_iters --- tests/ignite/engine/test_engine.py | 554 +++++++++++++++++++++++------ 1 file changed, 442 insertions(+), 112 deletions(-) diff --git a/tests/ignite/engine/test_engine.py b/tests/ignite/engine/test_engine.py index 13021242650..c1a3c466c14 100644 --- a/tests/ignite/engine/test_engine.py +++ b/tests/ignite/engine/test_engine.py @@ -1,3 +1,4 @@ +import math import os import time from unittest.mock import call, MagicMock, Mock @@ -10,7 +11,7 @@ from ignite.engine import Engine, Events, State from ignite.engine.deterministic import keep_random_state from ignite.metrics import Average -from tests.ignite.engine import BatchChecker, EpochCounter, IterationCounter +from tests.ignite.engine import BatchChecker, EpochCounter, IterationCounter, get_iterable_dataset class RecordedEngine(Engine): @@ -503,6 +504,15 @@ def test__is_done(self): state = State(iteration=1000, max_epochs=10, epoch_length=100) assert Engine._is_done(state) + state = State(iteration=11, epoch=2, max_epochs=None, epoch_length=11, max_iters=22) + assert not Engine._is_done(state) + + state = State(iteration=100, epoch=1, max_epochs=None, epoch_length=100, max_iters=250) + assert not Engine._is_done(state) + + state = State(iteration=250, epoch=1, max_epochs=None, epoch_length=100, max_iters=250) + assert Engine._is_done(state) + def test__setup_engine(self): engine = Engine(lambda e, b: 1) engine.state = State(iteration=10, epoch=1, max_epochs=100, epoch_length=100) @@ -510,13 +520,26 @@ def test__setup_engine(self): data = list(range(100)) engine.state.dataloader = data engine._setup_engine() - assert engine._init_iter == 10 + assert len(engine._init_iter) == 1 and engine._init_iter[0] == 10 def test_run_asserts(self): engine = Engine(lambda e, b: 1) - with pytest.raises(ValueError, match=r"Input data has zero size. Please provide non-empty data"): + with pytest.raises( + ValueError, + match=r"Argument epoch_length is invalid. Please, either set a correct epoch_length " + r"value or check if input data has non-zero size.", + ): engine.run([]) + with pytest.raises(ValueError, match=r"Argument max_epochs should be larger than"): + engine.state.max_epochs = 5 + engine.state.epoch = 5 + engine.run([0, 1], max_epochs=3) + + with pytest.raises(ValueError, match=r"Argument max_iters should be larger than"): + engine.state.max_iters = 100 + engine.state.iteration = 100 + engine.run([0, 1], max_iters=50) def test_state_get_event_attrib_value(self): state = State() @@ -573,7 +596,21 @@ def check_completed_time(): >= (sleep_time * epoch_length + extra_sleep_time) * max_epochs + extra_sleep_time ) - def _test_check_triggered_events(self, data, max_epochs, epoch_length, exp_iter_stops=None): + def _test_check_triggered_events( + self, + data, + max_epochs=None, + epoch_length=None, + max_iters=None, + n_epoch_started=None, + n_epoch_completed=None, + n_iter_started=None, + n_iter_completed=None, + n_batch_started=None, + n_batch_completed=None, + exp_iter_stops=None, + n_terminate=None, + ): engine = Engine(lambda e, b: 1) events = [ Events.STARTED, @@ -585,6 +622,8 @@ def _test_check_triggered_events(self, data, max_epochs, epoch_length, exp_iter_ Events.GET_BATCH_STARTED, Events.GET_BATCH_COMPLETED, Events.DATALOADER_STOP_ITERATION, + Events.TERMINATE, + Events.TERMINATE_SINGLE_EPOCH, ] handlers = {e: MagicMock() for e in events} @@ -592,18 +631,41 @@ def _test_check_triggered_events(self, data, max_epochs, epoch_length, exp_iter_ for e, handler in handlers.items(): engine.add_event_handler(e, handler) - engine.run(data, max_epochs=max_epochs, epoch_length=epoch_length) + engine.run(data, max_epochs=max_epochs, max_iters=max_iters, epoch_length=epoch_length) + + if epoch_length is None: + epoch_length = engine.state.epoch_length + + assert epoch_length is not None + + if n_iter_started is None: + n_iter_started = max_epochs * epoch_length if max_epochs is not None else max_iters + + if n_iter_completed is None: + n_iter_completed = max_epochs * epoch_length if max_epochs is not None else max_iters + + if n_batch_started is None: + n_batch_started = max_epochs * epoch_length if max_epochs is not None else max_iters + + if n_batch_completed is None: + n_batch_completed = max_epochs * epoch_length if max_epochs is not None else max_iters + + if n_terminate is None: + n_terminate = int(n_epoch_started != n_epoch_completed) if max_iters is not None else 0 + expected_num_calls = { Events.STARTED: 1, Events.COMPLETED: 1, - Events.EPOCH_STARTED: max_epochs, - Events.EPOCH_COMPLETED: max_epochs, - Events.ITERATION_STARTED: max_epochs * epoch_length, - Events.ITERATION_COMPLETED: max_epochs * epoch_length, - Events.GET_BATCH_STARTED: max_epochs * epoch_length if data is not None else 0, - Events.GET_BATCH_COMPLETED: max_epochs * epoch_length if data is not None else 0, + Events.EPOCH_STARTED: n_epoch_started if n_epoch_started is not None else max_epochs, + Events.EPOCH_COMPLETED: n_epoch_completed if n_epoch_completed is not None else max_epochs, + Events.ITERATION_STARTED: n_iter_started, + Events.ITERATION_COMPLETED: n_iter_completed, + Events.GET_BATCH_STARTED: n_batch_started, + Events.GET_BATCH_COMPLETED: n_batch_completed, Events.DATALOADER_STOP_ITERATION: (max_epochs - 1) if exp_iter_stops is None else exp_iter_stops, + Events.TERMINATE: n_terminate, + Events.TERMINATE_SINGLE_EPOCH: 0, } for n, handler in handlers.items(): @@ -619,6 +681,13 @@ def _test_run_check_triggered_events(self): ) self._test_check_triggered_events(None, max_epochs=5, epoch_length=150, exp_iter_stops=0) + kwargs = dict(exp_iter_stops=4, n_epoch_started=5, n_epoch_completed=5) + self._test_check_triggered_events(list(range(20)), max_iters=100, epoch_length=20, **kwargs) + kwargs = dict(exp_iter_stops=2, n_epoch_started=5, n_epoch_completed=5) + self._test_check_triggered_events(list(range(20)), max_iters=50, epoch_length=10, **kwargs) + kwargs = dict(exp_iter_stops=2, n_epoch_started=3, n_epoch_completed=2) + self._test_check_triggered_events(list(range(20)), max_iters=55, epoch_length=25, **kwargs) + def test_run_check_triggered_events_list(self): self._test_run_check_triggered_events() @@ -632,6 +701,13 @@ def infinite_data_iterator(): self._test_check_triggered_events(infinite_data_iterator(), max_epochs=5, epoch_length=50, exp_iter_stops=0) self._test_check_triggered_events(infinite_data_iterator(), max_epochs=5, epoch_length=150, exp_iter_stops=0) + kwargs = dict(exp_iter_stops=0, n_epoch_started=5, n_epoch_completed=5) + self._test_check_triggered_events(infinite_data_iterator(), max_iters=100, epoch_length=20, **kwargs) + kwargs = dict(exp_iter_stops=0, n_epoch_started=1, n_epoch_completed=0) + self._test_check_triggered_events(infinite_data_iterator(), max_iters=10, epoch_length=20, **kwargs) + kwargs = dict(exp_iter_stops=0, n_epoch_started=2, n_epoch_completed=1) + self._test_check_triggered_events(infinite_data_iterator(), max_iters=30, epoch_length=20, **kwargs) + def limited_data_iterator(): for i in range(100): yield i @@ -639,20 +715,75 @@ def limited_data_iterator(): self._test_check_triggered_events(limited_data_iterator(), max_epochs=1, epoch_length=100, exp_iter_stops=0) self._test_check_triggered_events(limited_data_iterator(), max_epochs=10, epoch_length=10, exp_iter_stops=0) - # These tests should fail - with pytest.raises(AssertionError): - with pytest.warns(UserWarning, match=r"Data iterator can not provide data anymore"): - self._test_check_triggered_events(limited_data_iterator(), max_epochs=3, epoch_length=100) - - with pytest.raises(AssertionError): - with pytest.warns(UserWarning, match=r"Data iterator can not provide data anymore"): - self._test_check_triggered_events(limited_data_iterator(), max_epochs=3, epoch_length=75) + kwargs = dict(exp_iter_stops=0, n_epoch_started=1, n_epoch_completed=1) + self._test_check_triggered_events(limited_data_iterator(), max_iters=20, epoch_length=20, **kwargs) + kwargs = dict(exp_iter_stops=0, n_epoch_started=2, n_epoch_completed=1) + self._test_check_triggered_events(limited_data_iterator(), max_iters=19, epoch_length=10, **kwargs) + + with pytest.warns(UserWarning, match=r"Data iterator can not provide data anymore"): + kwargs = dict( + exp_iter_stops=1, + n_epoch_started=2, + n_epoch_completed=1, + n_iter_started=20, + n_iter_completed=20, + n_batch_started=21, + n_batch_completed=20, + n_terminate=1, + ) + self._test_check_triggered_events(limited_data_iterator(), max_epochs=3, epoch_length=20, **kwargs) + + with pytest.warns(UserWarning, match=r"Data iterator can not provide data anymore"): + kwargs = dict( + exp_iter_stops=1, + n_epoch_started=2, + n_epoch_completed=1, + n_iter_started=20, + n_iter_completed=20, + n_batch_started=22, # 22 and not 21. GET_BATCH_STARTED is called once more to epoch_length + n_batch_completed=20, + n_terminate=1, + ) + self._test_check_triggered_events(limited_data_iterator(), max_epochs=3, **kwargs) + + with pytest.warns(UserWarning, match=r"Data iterator can not provide data anymore"): + kwargs = dict( + exp_iter_stops=1, + n_epoch_started=2, + n_epoch_completed=1, + n_iter_started=20, + n_iter_completed=20, + n_batch_started=21, + n_batch_completed=20, + n_terminate=1, + ) + self._test_check_triggered_events(limited_data_iterator(), max_epochs=3, epoch_length=15, **kwargs) + + kwargs = dict( + exp_iter_stops=1, + n_epoch_started=1, + n_epoch_completed=0, + n_iter_started=20, + n_iter_completed=20, + n_batch_started=21, + n_batch_completed=20, + n_terminate=1, + ) + self._test_check_triggered_events(limited_data_iterator(), max_epochs=1, epoch_length=21, **kwargs) + + with pytest.warns(UserWarning, match=r"Data iterator can not provide data anymore"): + kwargs = dict( + exp_iter_stops=1, + n_epoch_started=2, + n_epoch_completed=1, + n_iter_started=20, + n_iter_completed=20, + n_batch_started=21, + n_batch_completed=20, + n_terminate=1, + ) + self._test_check_triggered_events(limited_data_iterator(), max_iters=21, epoch_length=12, **kwargs) - with pytest.raises(AssertionError): - # Below test does not raise "Data iterator can not provide data anymore" warning as the last - # epoch is equal max_epochs - # with pytest.warns(UserWarning, match=r"Data iterator can not provide data anymore"): - self._test_check_triggered_events(limited_data_iterator(), max_epochs=1, epoch_length=101) def test_run_check_triggered_events_on_iterator(self): self._test_run_check_triggered_events_on_iterator() @@ -760,7 +891,13 @@ def run_evaluation(_): assert train_batches[epoch_length + i] != train_batches[2 * epoch_length + i] assert train_batches[i] == train_only_batches[i] - def test_engine_with_dataloader_no_auto_batching(self): + @pytest.mark.parametrize( + "kwargs", [ + {"max_epochs": None, "epoch_length": 10, "max_iters": 25}, + {"max_epochs": 5, "epoch_length": 10, "max_iters": None}, + ] + ) + def test_engine_with_dataloader_no_auto_batching(self, kwargs): # tests https://github.com/pytorch/ignite/issues/941 from torch.utils.data import BatchSampler, DataLoader, RandomSampler @@ -775,9 +912,12 @@ def foo(e, b): counter[0] += 1 engine = Engine(foo) - engine.run(data_loader, epoch_length=10, max_epochs=5) + engine.run(data_loader, **kwargs) - assert counter[0] == 50 + if kwargs["max_epochs"]: + assert counter[0] == kwargs["epoch_length"] * kwargs["max_epochs"] + else: + assert counter[0] == kwargs["max_iters"] def test_run_once_finite_iterator_no_epoch_length(self): # FR: https://github.com/pytorch/ignite/issues/871 @@ -788,19 +928,43 @@ def finite_unk_size_data_iter(): for i in range(unknown_size): yield i - bc = BatchChecker(data=list(range(unknown_size))) + def _test(**kwargs): + bc = BatchChecker(data=list(range(unknown_size))) - engine = Engine(lambda e, b: bc.check(b)) + engine = Engine(lambda e, b: bc.check(b)) - completed_handler = MagicMock() - engine.add_event_handler(Events.COMPLETED, completed_handler) + epoch_completed_handler = MagicMock() + engine.add_event_handler(Events.EPOCH_COMPLETED, epoch_completed_handler) + + completed_handler = MagicMock() + engine.add_event_handler(Events.COMPLETED, completed_handler) + + data_iter = finite_unk_size_data_iter() + engine.run(data_iter, **kwargs) + + assert bc.counter == engine.state.iteration + if len(kwargs) == 0: + assert engine.state.epoch == 1 + assert engine.state.iteration == unknown_size + assert epoch_completed_handler.call_count == 1 + else: + max_iters = kwargs["max_iters"] + if max_iters <= unknown_size: + assert engine.state.epoch == 1 + assert engine.state.iteration == max_iters + else: + assert engine.state.epoch == 2 + assert engine.state.iteration == unknown_size + + assert completed_handler.call_count == 1 + + _test() + _test(max_iters=unknown_size) + _test(max_iters=unknown_size // 2) + with pytest.warns(UserWarning, match=r"Data iterator can not provide data anymore"): + _test(max_iters=unknown_size * 2) - data_iter = finite_unk_size_data_iter() - engine.run(data_iter) - assert engine.state.epoch == 1 - assert engine.state.iteration == unknown_size - assert completed_handler.call_count == 1 def test_run_finite_iterator_no_epoch_length(self): # FR: https://github.com/pytorch/ignite/issues/871 @@ -824,6 +988,7 @@ def restart_iter(): assert engine.state.epoch == 5 assert engine.state.iteration == unknown_size * 5 + def test_run_finite_iterator_no_epoch_length_2(self): # FR: https://github.com/pytorch/ignite/issues/871 known_size = 11 @@ -832,78 +997,159 @@ def finite_size_data_iter(size): for i in range(size): yield i - bc = BatchChecker(data=list(range(known_size))) + def _test(**kwargs): + bc = BatchChecker(data=list(range(known_size))) - engine = Engine(lambda e, b: bc.check(b)) + engine = Engine(lambda e, b: bc.check(b)) - @engine.on(Events.ITERATION_COMPLETED(every=known_size)) - def restart_iter(): - engine.state.dataloader = finite_size_data_iter(known_size) + @engine.on(Events.ITERATION_COMPLETED(every=known_size)) + def restart_iter(): + engine.state.dataloader = finite_size_data_iter(known_size) - data_iter = finite_size_data_iter(known_size) - engine.run(data_iter, max_epochs=5) + data_iter = finite_size_data_iter(known_size) + engine.run(data_iter, **kwargs) - assert engine.state.epoch == 5 - assert engine.state.iteration == known_size * 5 + assert bc.counter == engine.state.iteration + if "max_epochs" in kwargs: + assert engine.state.epoch == kwargs["max_epochs"] + assert engine.state.iteration == known_size * kwargs["max_epochs"] + else: + max_iters = kwargs["max_iters"] + if max_iters <= known_size: + assert engine.state.epoch == math.ceil(max_iters / known_size) + assert engine.state.iteration == max_iters - def test_faq_inf_iterator_with_epoch_length(self): - # Code snippet from FAQ - # import torch + _test(max_epochs=5) + _test(max_iters=known_size) + _test(max_iters=known_size // 2) - torch.manual_seed(12) - def infinite_iterator(batch_size): - while True: - batch = torch.rand(batch_size, 3, 32, 32) - yield batch + def test_faq_inf_iterator_with_epoch_length(): + def _test(max_epochs, max_iters): + # Code snippet from FAQ + # import torch - def train_step(trainer, batch): - # ... - s = trainer.state - print(f"{s.epoch}/{s.max_epochs} : {s.iteration} - {batch.norm():.3f}") + torch.manual_seed(12) - trainer = Engine(train_step) - # We need to specify epoch_length to define the epoch - trainer.run(infinite_iterator(4), epoch_length=5, max_epochs=3) + def infinite_iterator(batch_size): + while True: + batch = torch.rand(batch_size, 3, 32, 32) + yield batch + + def train_step(trainer, batch): + # ... + s = trainer.state + print(f"{s.epoch}/{s.max_epochs} : {s.iteration} - {batch.norm():.3f}") + + trainer = Engine(train_step) + # We need to specify epoch_length to define the epoch + trainer.run(infinite_iterator(4), epoch_length=5, max_epochs=max_epochs, max_iters=max_iters) + + assert trainer.state.epoch == 3 + assert trainer.state.iteration == 3 * 5 + + _test(max_epochs=3, max_iters=None) + _test(max_epochs=None, max_iters=3 * 5) - assert trainer.state.epoch == 3 - assert trainer.state.iteration == 3 * 5 def test_faq_inf_iterator_no_epoch_length(self): - # Code snippet from FAQ - # import torch + def _test(max_epochs, max_iters): + # Code snippet from FAQ + # import torch - torch.manual_seed(12) + torch.manual_seed(12) - def infinite_iterator(batch_size): - while True: - batch = torch.rand(batch_size, 3, 32, 32) - yield batch + def infinite_iterator(batch_size): + while True: + batch = torch.rand(batch_size, 3, 32, 32) + yield batch - def train_step(trainer, batch): - # ... - s = trainer.state - print(f"{s.epoch}/{s.max_epochs} : {s.iteration} - {batch.norm():.3f}") + def train_step(trainer, batch): + # ... + s = trainer.state + print(f"{s.epoch}/{s.max_epochs} : {s.iteration} - {batch.norm():.3f}") - trainer = Engine(train_step) + trainer = Engine(train_step) + + @trainer.on(Events.ITERATION_COMPLETED(once=15)) + def stop_training(): + trainer.terminate() + + trainer.run(infinite_iterator(4), max_epochs=max_epochs, max_iters=max_iters) - @trainer.on(Events.ITERATION_COMPLETED(once=15)) - def stop_training(): - trainer.terminate() + assert trainer.state.epoch == 1 + assert trainer.state.iteration == 15 - trainer.run(infinite_iterator(4)) + _test(max_epochs=None, max_iters=None) + _test(max_epochs=None, max_iters=100) - assert trainer.state.epoch == 1 - assert trainer.state.iteration == 15 def test_faq_fin_iterator_unknw_size(self): + def _test(max_epochs, max_iters): + # Code snippet from FAQ + # import torch + torch.manual_seed(12) + + def finite_unk_size_data_iter(): + for i in range(11): + yield i + + def train_step(trainer, batch): + # ... + s = trainer.state + print(f"{s.epoch}/{s.max_epochs} : {s.iteration} - {batch:.3f}") + + trainer = Engine(train_step) + + @trainer.on(Events.DATALOADER_STOP_ITERATION) + def restart_iter(): + trainer.state.dataloader = finite_unk_size_data_iter() + + data_iter = finite_unk_size_data_iter() + trainer.run(data_iter, max_epochs=max_epochs, max_iters=max_iters) + + assert trainer.state.epoch == 5 if max_iters is None else math.ceil(max_iters // 11) + assert trainer.state.iteration == 5 * 11 if max_iters is None else max_iters + + _test(max_epochs=5, max_iters=None) + _test(max_epochs=None, max_iters=60) + + # # # # # + + def _test(max_epochs, max_iters): + # import torch + torch.manual_seed(12) + + def finite_unk_size_data_iter(): + for i in range(11): + yield i + + def val_step(evaluator, batch): + # ... + s = evaluator.state + print(f"{s.epoch}/{s.max_epochs} : {s.iteration} - {batch:.3f}") + + evaluator = Engine(val_step) + + data_iter = finite_unk_size_data_iter() + evaluator.run(data_iter, max_epochs=max_epochs, max_iters=max_iters) + + assert evaluator.state.epoch == 1 + assert evaluator.state.iteration == 1 * 11 + + _test(max_epochs=None, max_iters=None) + + + def test_faq_fin_iterator(self): # Code snippet from FAQ # import torch torch.manual_seed(12) - def finite_unk_size_data_iter(): - for i in range(11): + size = 11 + + def finite_size_data_iter(size): + for i in range(size): yield i def train_step(trainer, batch): @@ -913,23 +1159,25 @@ def train_step(trainer, batch): trainer = Engine(train_step) - @trainer.on(Events.DATALOADER_STOP_ITERATION) + @trainer.on(Events.ITERATION_COMPLETED(every=size)) def restart_iter(): - trainer.state.dataloader = finite_unk_size_data_iter() + trainer.state.dataloader = finite_size_data_iter(size) - data_iter = finite_unk_size_data_iter() + data_iter = finite_size_data_iter(size) trainer.run(data_iter, max_epochs=5) assert trainer.state.epoch == 5 - assert trainer.state.iteration == 5 * 11 + assert trainer.state.iteration == 5 * size # Code snippet from FAQ # import torch torch.manual_seed(12) - def finite_unk_size_data_iter(): - for i in range(11): + size = 11 + + def finite_size_data_iter(size): + for i in range(size): yield i def val_step(evaluator, batch): @@ -939,46 +1187,48 @@ def val_step(evaluator, batch): evaluator = Engine(val_step) - data_iter = finite_unk_size_data_iter() + data_iter = finite_size_data_iter(size) evaluator.run(data_iter) assert evaluator.state.epoch == 1 - assert evaluator.state.iteration == 1 * 11 + assert evaluator.state.iteration == size def test_faq_fin_iterator(self): - # Code snippet from FAQ - # import torch + def _test(max_epochs, max_iters): + # Code snippet from FAQ - torch.manual_seed(12) + # import torch + torch.manual_seed(12) + size = 11 - size = 11 + def finite_size_data_iter(size): + for i in range(size): + yield i - def finite_size_data_iter(size): - for i in range(size): - yield i + def train_step(trainer, batch): + # ... + s = trainer.state + print(f"{s.epoch}/{s.max_epochs} : {s.iteration} - {batch:.3f}") - def train_step(trainer, batch): - # ... - s = trainer.state - print(f"{s.epoch}/{s.max_epochs} : {s.iteration} - {batch:.3f}") + trainer = Engine(train_step) - trainer = Engine(train_step) + @trainer.on(Events.ITERATION_COMPLETED(every=size)) + def restart_iter(): + trainer.state.dataloader = finite_size_data_iter(size) - @trainer.on(Events.ITERATION_COMPLETED(every=size)) - def restart_iter(): - trainer.state.dataloader = finite_size_data_iter(size) + data_iter = finite_size_data_iter(size) + trainer.run(data_iter, max_epochs=max_epochs, max_iters=max_iters) - data_iter = finite_size_data_iter(size) - trainer.run(data_iter, max_epochs=5) + assert trainer.state.epoch == 5 + assert trainer.state.iteration == 5 * size - assert trainer.state.epoch == 5 - assert trainer.state.iteration == 5 * size + _test(max_epochs=5, max_iters=None) + _test(max_epochs=None, max_iters=5 * 11) - # Code snippet from FAQ - # import torch + # # # # # + # import torch torch.manual_seed(12) - size = 11 def finite_size_data_iter(size): @@ -998,6 +1248,7 @@ def val_step(evaluator, batch): assert evaluator.state.epoch == 1 assert evaluator.state.iteration == size + def test_set_data(self): # tests FR https://github.com/pytorch/ignite/issues/833 from torch.utils.data import DataLoader @@ -1228,6 +1479,85 @@ def check_iter_epoch(first_epoch_iter): assert engine.state.iteration == 10 * real_epoch_length + def test_restart_training(self): + data = range(10) + engine = Engine(lambda e, b: 1) + state = engine.run(data, max_epochs=5) + with pytest.raises( + ValueError, + match=r"Argument max_epochs should be larger than the current epoch defined in the state: 2 vs 5. " + r"Please, .+ " + r"before calling engine.run\(\) in order to restart the training from the beginning.", + ): + engine.run(data, max_epochs=2) + state.max_epochs = None + engine.run(data, max_epochs=2) + + + def test_engine_multiple_runs(self): + engine = Engine(lambda e, b: 1) + engine.debug() + + init_epoch = 0 + init_iter = 0 + epoch_length = None + + @engine.on(Events.STARTED) + def assert_resume(): + assert engine.state.epoch == init_epoch + assert engine.state.iteration == init_iter + assert engine.state.epoch_length == epoch_length + + data = range(10) + epoch_length = len(data) + engine.run(data, max_epochs=2) + assert engine.state.epoch == 2 + assert engine.state.iteration == 2 * epoch_length + + engine.debug(False) + + # Continue run with max_epochs + data = range(15) + init_epoch = 2 + init_iter = 2 * epoch_length + engine.run(data, max_epochs=5) + + assert engine.state.epoch == 5 + assert engine.state.iteration == 5 * epoch_length + + # Continue run with max_iters + data = range(15) + init_epoch = 5 + init_iter = 5 * epoch_length + with pytest.raises(ValueError, match=r"State attributes max_iters and max_epochs are mutually exclusive"): + engine.run(data, max_iters=6 * epoch_length) + + engine.state.max_epochs = None + engine.run(data, max_iters=6 * epoch_length) + + assert engine.state.epoch == 6 + assert engine.state.iteration == 6 * epoch_length + + + def test_engine_multiple_runs_2(self): + + e = Engine(lambda _, b: None) + data = iter(range(100)) + + e.run(data, max_iters=50) + assert e.state.iteration == 50 + assert e.state.epoch == 1 + e.run(data, max_iters=52) + assert e.state.iteration == 52 + # should be 1 and if 2 this is a bug : https://github.com/pytorch/ignite/issues/1386 + assert e.state.epoch == 2 + e.run(data, max_iters=100) + assert e.state.iteration == 100 + # should be 1 and if 3 this is a bug : https://github.com/pytorch/ignite/issues/1386 + assert e.state.epoch == 3 + + + @pytest.mark.parametrize( "interrupt_event, e, i", [ From 0bc872cdcb68ac02ffa4ca0c9ffef7c76291b1b8 Mon Sep 17 00:00:00 2001 From: leej3 Date: Wed, 17 Apr 2024 13:39:40 +0100 Subject: [PATCH 2/8] add max_iters tests for state dict --- tests/ignite/engine/test_engine_state_dict.py | 67 +++++++++++++++---- 1 file changed, 54 insertions(+), 13 deletions(-) diff --git a/tests/ignite/engine/test_engine_state_dict.py b/tests/ignite/engine/test_engine_state_dict.py index 4ccfb7ea772..dc5c2b0deda 100644 --- a/tests/ignite/engine/test_engine_state_dict.py +++ b/tests/ignite/engine/test_engine_state_dict.py @@ -14,6 +14,7 @@ def test_state_dict(): assert "iteration" in sd and sd["iteration"] == 0 assert "max_epochs" in sd and sd["max_epochs"] is None assert "epoch_length" in sd and sd["epoch_length"] is None + assert "max_iters" in sd and sd["max_iters"] is None def _test(state): engine.state = state @@ -23,8 +24,14 @@ def _test(state): assert sd["epoch_length"] == engine.state.epoch_length assert sd["max_epochs"] == engine.state.max_epochs + if state.max_epochs is not None: + assert sd["max_epochs"] == engine.state.max_epochs + else: + assert sd["max_iters"] == engine.state.max_iters + _test(State(iteration=500, epoch_length=1000, max_epochs=100)) _test(State(epoch=5, epoch_length=1000, max_epochs=100)) + _test(State(epoch=5, epoch_length=1000, max_iters=500)) def test_state_dict_with_user_keys(): @@ -40,23 +47,40 @@ def _test(state): ) assert sd["iteration"] == engine.state.iteration assert sd["epoch_length"] == engine.state.epoch_length - assert sd["max_epochs"] == engine.state.max_epochs + if state.max_epochs is not None: + assert sd["max_epochs"] == engine.state.max_epochs + else: + assert sd["max_iters"] == engine.state.max_iters assert sd["alpha"] == engine.state.alpha assert sd["beta"] == engine.state.beta _test(State(iteration=500, epoch_length=1000, max_epochs=100, alpha=0.01, beta="Good")) + _test(State(iteration=500, epoch_length=1000, max_iters=2000, alpha=0.01, beta="Good")) -def test_state_dict_integration(): - engine = Engine(lambda e, b: 1) - data = range(100) - engine.run(data, max_epochs=10) - sd = engine.state_dict() - assert isinstance(sd, Mapping) and len(sd) == len(engine._state_dict_all_req_keys) + 1 - assert sd["iteration"] == engine.state.iteration == 10 * 100 - assert sd["epoch_length"] == engine.state.epoch_length == 100 - assert sd["max_epochs"] == engine.state.max_epochs == 10 +def test_state_dict_integration(): + def _test(max_epochs, max_iters): + engine = Engine(lambda e, b: 1) + data = range(100) + engine.run(data, max_epochs=max_epochs, max_iters=max_iters) + sd = engine.state_dict() + assert isinstance(sd, Mapping) + assert len(sd) == len(engine._state_dict_all_req_keys) + 1 + + if max_epochs is None and max_iters is None: + max_epochs = 1 + n_iters = max_iters if max_iters is not None else max_epochs * 100 + assert sd["iteration"] == engine.state.iteration == n_iters + assert sd["epoch_length"] == engine.state.epoch_length == 100 + if engine.state.max_epochs is not None: + assert sd["max_epochs"] == engine.state.max_epochs == max_epochs + else: + assert sd["max_iters"] == engine.state.max_iters == max_iters + + _test(max_epochs=10, max_iters=None) + _test(max_epochs=None, max_iters=None) + _test(max_epochs=None, max_iters=10 * 100) def test_load_state_dict_asserts(): engine = Engine(lambda e, b: 1) @@ -93,11 +117,29 @@ def _test(sd): elif "epoch" in sd: assert sd["epoch"] == engine.state.epoch assert sd["epoch_length"] == engine.state.epoch_length - assert sd["max_epochs"] == engine.state.max_epochs + if "max_epochs" in sd: + assert sd["max_epochs"] == engine.state.max_epochs + else: + assert sd["max_iters"] == engine.state.max_iters _test({"max_epochs": 100, "epoch_length": 120, "iteration": 123}) _test({"max_epochs": 100, "epoch_length": 120, "epoch": 5}) + with pytest.raises(ValueError, match=r"Argument max_epochs should be larger than"): + _test({"max_epochs": 10, "epoch_length": 120, "epoch": 50}) + + with pytest.raises(ValueError, match=r"Argument max_epochs should be larger than"): + _test({"max_epochs": 10, "epoch_length": 120, "iteration": 5000}) + + _test({"max_iters": 500, "epoch_length": 120, "iteration": 123}) + _test({"max_iters": 500, "epoch_length": 120, "epoch": 3}) + + with pytest.raises(ValueError, match=r"Argument max_iters should be larger than"): + _test({"max_iters": 500, "epoch_length": 120, "epoch": 5}) + + with pytest.raises(ValueError, match=r"Argument max_iters should be larger than"): + _test({"max_iters": 500, "epoch_length": 120, "iteration": 501}) + def test_load_state_dict_with_user_keys(): engine = Engine(lambda e, b: 1) @@ -142,8 +184,7 @@ def test_load_state_dict_with_params_overriding_integration(): assert state.max_epochs == new_max_epochs assert state.iteration == state_dict["epoch_length"] * new_max_epochs assert state.epoch == new_max_epochs - - with pytest.raises(ValueError, match=r"Argument max_epochs should be greater than or equal to the start epoch"): + with pytest.raises(ValueError, match=r"Argument max_epochs should be larger than the current epoch"): engine.load_state_dict(state_dict) engine.run(data, max_epochs=3) From a09d55c0e25a07bd488fc7d6b5699f0d9141f878 Mon Sep 17 00:00:00 2001 From: leej3 Date: Wed, 17 Apr 2024 12:43:08 +0000 Subject: [PATCH 3/8] autopep8 fix --- tests/ignite/engine/test_engine.py | 21 ++++--------------- tests/ignite/engine/test_engine_state_dict.py | 2 +- 2 files changed, 5 insertions(+), 18 deletions(-) diff --git a/tests/ignite/engine/test_engine.py b/tests/ignite/engine/test_engine.py index c1a3c466c14..9365386efc8 100644 --- a/tests/ignite/engine/test_engine.py +++ b/tests/ignite/engine/test_engine.py @@ -11,7 +11,7 @@ from ignite.engine import Engine, Events, State from ignite.engine.deterministic import keep_random_state from ignite.metrics import Average -from tests.ignite.engine import BatchChecker, EpochCounter, IterationCounter, get_iterable_dataset +from tests.ignite.engine import BatchChecker, EpochCounter, get_iterable_dataset, IterationCounter class RecordedEngine(Engine): @@ -653,7 +653,6 @@ def _test_check_triggered_events( if n_terminate is None: n_terminate = int(n_epoch_started != n_epoch_completed) if max_iters is not None else 0 - expected_num_calls = { Events.STARTED: 1, Events.COMPLETED: 1, @@ -784,7 +783,6 @@ def limited_data_iterator(): ) self._test_check_triggered_events(limited_data_iterator(), max_iters=21, epoch_length=12, **kwargs) - def test_run_check_triggered_events_on_iterator(self): self._test_run_check_triggered_events_on_iterator() @@ -892,10 +890,11 @@ def run_evaluation(_): assert train_batches[i] == train_only_batches[i] @pytest.mark.parametrize( - "kwargs", [ + "kwargs", + [ {"max_epochs": None, "epoch_length": 10, "max_iters": 25}, {"max_epochs": 5, "epoch_length": 10, "max_iters": None}, - ] + ], ) def test_engine_with_dataloader_no_auto_batching(self, kwargs): # tests https://github.com/pytorch/ignite/issues/941 @@ -964,8 +963,6 @@ def _test(**kwargs): with pytest.warns(UserWarning, match=r"Data iterator can not provide data anymore"): _test(max_iters=unknown_size * 2) - - def test_run_finite_iterator_no_epoch_length(self): # FR: https://github.com/pytorch/ignite/issues/871 unknown_size = 11 @@ -988,7 +985,6 @@ def restart_iter(): assert engine.state.epoch == 5 assert engine.state.iteration == unknown_size * 5 - def test_run_finite_iterator_no_epoch_length_2(self): # FR: https://github.com/pytorch/ignite/issues/871 known_size = 11 @@ -1023,7 +1019,6 @@ def restart_iter(): _test(max_iters=known_size) _test(max_iters=known_size // 2) - def test_faq_inf_iterator_with_epoch_length(): def _test(max_epochs, max_iters): # Code snippet from FAQ @@ -1051,7 +1046,6 @@ def train_step(trainer, batch): _test(max_epochs=3, max_iters=None) _test(max_epochs=None, max_iters=3 * 5) - def test_faq_inf_iterator_no_epoch_length(self): def _test(max_epochs, max_iters): # Code snippet from FAQ @@ -1083,7 +1077,6 @@ def stop_training(): _test(max_epochs=None, max_iters=None) _test(max_epochs=None, max_iters=100) - def test_faq_fin_iterator_unknw_size(self): def _test(max_epochs, max_iters): # Code snippet from FAQ @@ -1139,7 +1132,6 @@ def val_step(evaluator, batch): _test(max_epochs=None, max_iters=None) - def test_faq_fin_iterator(self): # Code snippet from FAQ # import torch @@ -1248,7 +1240,6 @@ def val_step(evaluator, batch): assert evaluator.state.epoch == 1 assert evaluator.state.iteration == size - def test_set_data(self): # tests FR https://github.com/pytorch/ignite/issues/833 from torch.utils.data import DataLoader @@ -1478,7 +1469,6 @@ def check_iter_epoch(first_epoch_iter): assert engine.state.epoch == 10 assert engine.state.iteration == 10 * real_epoch_length - def test_restart_training(self): data = range(10) engine = Engine(lambda e, b: 1) @@ -1493,7 +1483,6 @@ def test_restart_training(self): state.max_epochs = None engine.run(data, max_epochs=2) - def test_engine_multiple_runs(self): engine = Engine(lambda e, b: 1) engine.debug() @@ -1538,7 +1527,6 @@ def assert_resume(): assert engine.state.epoch == 6 assert engine.state.iteration == 6 * epoch_length - def test_engine_multiple_runs_2(self): e = Engine(lambda _, b: None) @@ -1557,7 +1545,6 @@ def test_engine_multiple_runs_2(self): assert e.state.epoch == 3 - @pytest.mark.parametrize( "interrupt_event, e, i", [ diff --git a/tests/ignite/engine/test_engine_state_dict.py b/tests/ignite/engine/test_engine_state_dict.py index dc5c2b0deda..06a0451acea 100644 --- a/tests/ignite/engine/test_engine_state_dict.py +++ b/tests/ignite/engine/test_engine_state_dict.py @@ -58,7 +58,6 @@ def _test(state): _test(State(iteration=500, epoch_length=1000, max_iters=2000, alpha=0.01, beta="Good")) - def test_state_dict_integration(): def _test(max_epochs, max_iters): engine = Engine(lambda e, b: 1) @@ -82,6 +81,7 @@ def _test(max_epochs, max_iters): _test(max_epochs=None, max_iters=None) _test(max_epochs=None, max_iters=10 * 100) + def test_load_state_dict_asserts(): engine = Engine(lambda e, b: 1) From e41497817b2b95f5341735d1187bde189cac8aec Mon Sep 17 00:00:00 2001 From: leej3 Date: Wed, 17 Apr 2024 13:48:22 +0100 Subject: [PATCH 4/8] add mixins test --- tests/ignite/base/test_mixins.py | 31 ++++++++++++++++++++++++++++--- 1 file changed, 28 insertions(+), 3 deletions(-) diff --git a/tests/ignite/base/test_mixins.py b/tests/ignite/base/test_mixins.py index 0f3a39811fb..3bec57a873d 100644 --- a/tests/ignite/base/test_mixins.py +++ b/tests/ignite/base/test_mixins.py @@ -3,12 +3,37 @@ from ignite.base import Serializable +class ExampleSerializable(Serializable): + _state_dict_all_req_keys = ("a", "b") + _state_dict_one_of_opt_keys = (("c", "d"), ("e", "f")) + + def test_state_dict(): s = Serializable() with pytest.raises(NotImplementedError): s.state_dict() - def test_load_state_dict(): - s = Serializable() - s.load_state_dict({}) + + s = ExampleSerializable() + with pytest.raises(TypeError, match=r"Argument state_dict should be a dictionary"): + s.load_state_dict("abc") + + with pytest.raises(ValueError, match=r"is absent in provided state_dict"): + s.load_state_dict({}) + + with pytest.raises(ValueError, match=r"is absent in provided state_dict"): + s.load_state_dict({"a": 1}) + + with pytest.raises(ValueError, match=r"state_dict should contain only one of"): + s.load_state_dict({"a": 1, "b": 2}) + + with pytest.raises(ValueError, match=r"state_dict should contain only one of"): + s.load_state_dict({"a": 1, "b": 2, "c": 3, "d": 4, "e": 5}) + + with pytest.raises(ValueError, match=r"state_dict should contain only one of"): + s.load_state_dict({"a": 1, "b": 2, "c": 3, "e": 5, "f": 5}) + + s.state_dict_user_keys.append("alpha") + with pytest.raises(ValueError, match=r"Required user state attribute"): + s.load_state_dict({"a": 1, "b": 2, "c": 3, "e": 4}) From 3f1ef202e3323821256dab960fb6372786036454 Mon Sep 17 00:00:00 2001 From: leej3 Date: Wed, 17 Apr 2024 15:54:59 +0000 Subject: [PATCH 5/8] update mixins and engine --- ignite/base/mixins.py | 26 ++++-- ignite/engine/engine.py | 126 ++++++++++++++++++++++++++++- tests/ignite/base/test_mixins.py | 1 + tests/ignite/engine/test_engine.py | 70 +++------------- 4 files changed, 152 insertions(+), 71 deletions(-) diff --git a/ignite/base/mixins.py b/ignite/base/mixins.py index 3ecb2922f03..c1b47d1e1ce 100644 --- a/ignite/base/mixins.py +++ b/ignite/base/mixins.py @@ -1,11 +1,18 @@ from collections import OrderedDict from collections.abc import Mapping -from typing import Tuple +from typing import List, Tuple class Serializable: - _state_dict_all_req_keys: Tuple = () - _state_dict_one_of_opt_keys: Tuple = () + _state_dict_all_req_keys: Tuple[str, ...] = () + _state_dict_one_of_opt_keys: Tuple[Tuple[str, ...], ...] = ((),) + + def __init__(self) -> None: + self._state_dict_user_keys: List[str] = [] + + @property + def state_dict_user_keys(self) -> List: + return self._state_dict_user_keys def state_dict(self) -> OrderedDict: raise NotImplementedError @@ -19,6 +26,13 @@ def load_state_dict(self, state_dict: Mapping) -> None: raise ValueError( f"Required state attribute '{k}' is absent in provided state_dict '{state_dict.keys()}'" ) - opts = [k in state_dict for k in self._state_dict_one_of_opt_keys] - if len(opts) > 0 and ((not any(opts)) or (all(opts))): - raise ValueError(f"state_dict should contain only one of '{self._state_dict_one_of_opt_keys}' keys") + for one_of_opt_keys in self._state_dict_one_of_opt_keys: + opts = [k in state_dict for k in one_of_opt_keys] + if len(opts) > 0 and (not any(opts)) or (all(opts)): + raise ValueError(f"state_dict should contain only one of '{one_of_opt_keys}' keys") + + for k in self._state_dict_user_keys: + if k not in state_dict: + raise ValueError( + f"Required user state attribute '{k}' is absent in provided state_dict '{state_dict.keys()}'" + ) diff --git a/ignite/engine/engine.py b/ignite/engine/engine.py index 865218af359..da07a8061ed 100644 --- a/ignite/engine/engine.py +++ b/ignite/engine/engine.py @@ -129,7 +129,16 @@ def compute_mean_std(engine, batch): """ _state_dict_all_req_keys = ("epoch_length", "max_epochs") - _state_dict_one_of_opt_keys = ("iteration", "epoch") + _state_dict_one_of_opt_keys = ( + ( + "iteration", + "epoch", + ), + ( + "max_epochs", + "max_iters", + ), + ) # Flag to disable engine._internal_run as generator feature for BC interrupt_resume_enabled = True @@ -310,6 +319,7 @@ def execute_something(): for e in event_name: self.add_event_handler(e, handler, *args, **kwargs) return RemovableEventHandle(event_name, handler, self) + if isinstance(event_name, CallableEventWithFilter) and event_name.filter is not None: event_filter = event_name.filter handler = self._handler_wrapper(handler, event_name, event_filter) @@ -332,6 +342,16 @@ def execute_something(): return RemovableEventHandle(event_name, handler, self) + @staticmethod + def _assert_non_filtered_event(event_name: Any) -> None: + if ( + isinstance(event_name, CallableEventWithFilter) + and event_name.filter != CallableEventWithFilter.default_event_filter + ): + raise TypeError( + "Argument event_name should not be a filtered event, " "please use event without any event filtering" + ) + def has_event_handler(self, handler: Callable, event_name: Optional[Any] = None) -> bool: """Check if the specified event has the specified handler. @@ -675,7 +695,12 @@ def save_engine(_): a dictionary containing engine's state """ - keys: Tuple[str, ...] = self._state_dict_all_req_keys + (self._state_dict_one_of_opt_keys[0],) + keys: Tuple[str, ...] = self._state_dict_all_req_keys + keys += ("iteration",) + if self.state.max_epochs is not None: + keys += ("max_epochs",) + else: + keys += ("max_iters",) keys += tuple(self._state_dict_user_keys) return OrderedDict([(k, getattr(self.state, k)) for k in keys]) @@ -728,6 +753,8 @@ def load_state_dict(self, state_dict: Mapping) -> None: f"Input state_dict: {state_dict}" ) self.state.iteration = self.state.epoch_length * self.state.epoch + self._check_and_set_max_epochs(state_dict.get("max_epochs", None)) + self._check_and_set_max_iters(state_dict.get("max_iters", None)) @staticmethod def _is_done(state: State) -> bool: @@ -864,12 +891,26 @@ def switch_batch(engine): epoch_length = self._get_data_length(data) if epoch_length is not None and epoch_length < 1: - raise ValueError("Input data has zero size. Please provide non-empty data") + raise ValueError( + "Argument epoch_length is invalid. Please, either set a" + " correct epoch_length value or check if input data has" + " non-zero size." + ) if max_iters is None: if max_epochs is None: max_epochs = 1 else: + if max_iters < 1: + raise ValueError("Argument max_iters is invalid. Please, set a correct max_iters positive value") + if (self.state.max_iters is not None) and max_iters <= self.state.iteration: + raise ValueError( + "Argument max_iters should be larger than the current iteration " + f"defined in the state: {max_iters} vs {self.state.iteration}. " + "Please, set engine.state.max_iters = None " + "before calling engine.run() in order to restart the training from the beginning." + ) + self.state.max_iters = max_iters if max_epochs is not None: raise ValueError( "Arguments max_iters and max_epochs are mutually exclusive." @@ -932,6 +973,53 @@ def _setup_dataloader_iter(self) -> None: else: self._dataloader_iter = iter(self.state.dataloader) + def _check_and_set_max_epochs(self, max_epochs: Optional[int] = None) -> None: + if max_epochs is not None: + if max_epochs < 1: + raise ValueError("Argument max_epochs is invalid. Please, set a correct max_epochs positive value") + if self.state.max_epochs is not None and max_epochs <= self.state.epoch: + raise ValueError( + "Argument max_epochs should be larger than the current epoch " + f"defined in the state: {max_epochs} vs {self.state.epoch}. " + "Please, set engine.state.max_epochs = None " + "before calling engine.run() in order to restart the training from the beginning." + ) + self.state.max_epochs = max_epochs + + def _check_and_set_max_iters(self, max_iters: Optional[int] = None) -> None: + if max_iters is not None: + if max_iters < 1: + raise ValueError("Argument max_iters is invalid. Please, set a correct max_iters positive value") + if (self.state.max_iters is not None) and max_iters <= self.state.iteration: + raise ValueError( + "Argument max_iters should be larger than the current iteration " + f"defined in the state: {max_iters} vs {self.state.iteration}. " + "Please, set engine.state.max_iters = None " + "before calling engine.run() in order to restart the training from the beginning." + ) + self.state.max_iters = max_iters + + def _check_and_set_epoch_length(self, data: Iterable, epoch_length: Optional[int] = None) -> None: + # Can't we accept a redefinition ? + if self.state.epoch_length is not None: + if epoch_length is not None: + if epoch_length != self.state.epoch_length: + raise ValueError( + "Argument epoch_length should be same as in the state, " + f"but given {epoch_length} vs {self.state.epoch_length}" + ) + else: + if epoch_length is None: + epoch_length = self._get_data_length(data) + + if epoch_length is not None and epoch_length < 1: + raise ValueError( + "Argument epoch_length is invalid. Please, either set a correct epoch_length value or " + "check if input data has non-zero size." + ) + + self.state.epoch_length = epoch_length + def _setup_engine(self) -> None: self._setup_dataloader_iter() @@ -1064,6 +1152,13 @@ def _run_once_on_dataset_as_gen(self) -> Generator[State, None, float]: self.state.epoch_length = iter_counter if self.state.max_iters is not None: self.state.max_epochs = math.ceil(self.state.max_iters / self.state.epoch_length) + # Warn but will continue until max iters is reached + warnings.warn( + "Data iterator can not provide data anymore but required total number of " + "iterations to run is not reached. " + f"Current iteration: {self.state.iteration} vs Total iterations to run :" + f" {self.state.max_iters}" + ) break # Should exit while loop if we can not iterate @@ -1106,7 +1201,13 @@ def _run_once_on_dataset_as_gen(self) -> Generator[State, None, float]: if self.state.max_iters is not None and self.state.iteration == self.state.max_iters: self.should_terminate = True - raise _EngineTerminateException() + warnings.warn( + "Data iterator can not provide data anymore but required total number of " + "iterations to run is not reached. " + f"Current iteration: {self.state.iteration} vs Total iterations to run : ? total_iters" + ) + break + # raise _EngineTerminateException() except _EngineTerminateSingleEpochException: self._fire_event(Events.TERMINATE_SINGLE_EPOCH, iter_counter=iter_counter) @@ -1231,6 +1332,13 @@ def _run_once_on_dataset_legacy(self) -> float: self.state.epoch_length = iter_counter if self.state.max_iters is not None: self.state.max_epochs = math.ceil(self.state.max_iters / self.state.epoch_length) + # Warn but will continue until max iters is reached + warnings.warn( + "Data iterator can not provide data anymore but required total number of " + "iterations to run is not reached. " + f"Current iteration: {self.state.iteration} vs Total iterations to run :" + f" {self.state.max_iters}" + ) break # Should exit while loop if we can not iterate @@ -1291,6 +1399,16 @@ def _run_once_on_dataset_legacy(self) -> float: return time.time() - start_time + def debug(self, enabled: bool = True) -> None: + """Enables/disables engine's logging debug mode""" + from ignite.utils import setup_logger + + if enabled: + setattr(self, "_stored_logger", self.logger) + self.logger = setup_logger(level=logging.DEBUG) + elif hasattr(self, "_stored_logger"): + self.logger = getattr(self, "_stored_logger") + def _get_none_data_iter(size: int) -> Iterator: # Sized iterator for data as None diff --git a/tests/ignite/base/test_mixins.py b/tests/ignite/base/test_mixins.py index 3bec57a873d..734384f3847 100644 --- a/tests/ignite/base/test_mixins.py +++ b/tests/ignite/base/test_mixins.py @@ -13,6 +13,7 @@ def test_state_dict(): with pytest.raises(NotImplementedError): s.state_dict() + def test_load_state_dict(): s = ExampleSerializable() diff --git a/tests/ignite/engine/test_engine.py b/tests/ignite/engine/test_engine.py index 9365386efc8..91e6d29a7fb 100644 --- a/tests/ignite/engine/test_engine.py +++ b/tests/ignite/engine/test_engine.py @@ -11,7 +11,7 @@ from ignite.engine import Engine, Events, State from ignite.engine.deterministic import keep_random_state from ignite.metrics import Average -from tests.ignite.engine import BatchChecker, EpochCounter, get_iterable_dataset, IterationCounter +from tests.ignite.engine import BatchChecker, EpochCounter, IterationCounter class RecordedEngine(Engine): @@ -520,7 +520,7 @@ def test__setup_engine(self): data = list(range(100)) engine.state.dataloader = data engine._setup_engine() - assert len(engine._init_iter) == 1 and engine._init_iter[0] == 10 + assert engine._init_iter == 10 def test_run_asserts(self): engine = Engine(lambda e, b: 1) @@ -531,7 +531,7 @@ def test_run_asserts(self): r"value or check if input data has non-zero size.", ): engine.run([]) - with pytest.raises(ValueError, match=r"Argument max_epochs should be larger than"): + with pytest.raises(ValueError, match=r"Argument max_epochs should be greater than or equal to the start epoch"): engine.state.max_epochs = 5 engine.state.epoch = 5 engine.run([0, 1], max_epochs=3) @@ -707,8 +707,8 @@ def infinite_data_iterator(): kwargs = dict(exp_iter_stops=0, n_epoch_started=2, n_epoch_completed=1) self._test_check_triggered_events(infinite_data_iterator(), max_iters=30, epoch_length=20, **kwargs) - def limited_data_iterator(): - for i in range(100): + def limited_data_iterator(length=100): + for i in range(length): yield i self._test_check_triggered_events(limited_data_iterator(), max_epochs=1, epoch_length=100, exp_iter_stops=0) @@ -730,7 +730,7 @@ def limited_data_iterator(): n_batch_completed=20, n_terminate=1, ) - self._test_check_triggered_events(limited_data_iterator(), max_epochs=3, epoch_length=20, **kwargs) + self._test_check_triggered_events(limited_data_iterator(length=20), max_epochs=3, epoch_length=20, **kwargs) with pytest.warns(UserWarning, match=r"Data iterator can not provide data anymore"): kwargs = dict( @@ -1019,7 +1019,7 @@ def restart_iter(): _test(max_iters=known_size) _test(max_iters=known_size // 2) - def test_faq_inf_iterator_with_epoch_length(): + def test_faq_inf_iterator_with_epoch_length(self): def _test(max_epochs, max_iters): # Code snippet from FAQ # import torch @@ -1132,59 +1132,6 @@ def val_step(evaluator, batch): _test(max_epochs=None, max_iters=None) - def test_faq_fin_iterator(self): - # Code snippet from FAQ - # import torch - - torch.manual_seed(12) - - size = 11 - - def finite_size_data_iter(size): - for i in range(size): - yield i - - def train_step(trainer, batch): - # ... - s = trainer.state - print(f"{s.epoch}/{s.max_epochs} : {s.iteration} - {batch:.3f}") - - trainer = Engine(train_step) - - @trainer.on(Events.ITERATION_COMPLETED(every=size)) - def restart_iter(): - trainer.state.dataloader = finite_size_data_iter(size) - - data_iter = finite_size_data_iter(size) - trainer.run(data_iter, max_epochs=5) - - assert trainer.state.epoch == 5 - assert trainer.state.iteration == 5 * size - - # Code snippet from FAQ - # import torch - - torch.manual_seed(12) - - size = 11 - - def finite_size_data_iter(size): - for i in range(size): - yield i - - def val_step(evaluator, batch): - # ... - s = evaluator.state - print(f"{s.epoch}/{s.max_epochs} : {s.iteration} - {batch:.3f}") - - evaluator = Engine(val_step) - - data_iter = finite_size_data_iter(size) - evaluator.run(data_iter) - - assert evaluator.state.epoch == 1 - assert evaluator.state.iteration == size - def test_faq_fin_iterator(self): def _test(max_epochs, max_iters): # Code snippet from FAQ @@ -1475,7 +1422,8 @@ def test_restart_training(self): state = engine.run(data, max_epochs=5) with pytest.raises( ValueError, - match=r"Argument max_epochs should be larger than the current epoch defined in the state: 2 vs 5. " + match=r"Argument max_epochs should be greater than or equal to the start epoch" + " defined in the state: 2 vs 5. " r"Please, .+ " r"before calling engine.run\(\) in order to restart the training from the beginning.", ): From 6c5ad29b2fbd08ef2061562964eef2c9f47d834a Mon Sep 17 00:00:00 2001 From: leej3 Date: Fri, 19 Apr 2024 15:26:44 +0100 Subject: [PATCH 6/8] tidy larger/greater --- ignite/engine/engine.py | 6 +++--- tests/ignite/engine/test_engine.py | 2 +- tests/ignite/engine/test_engine_state_dict.py | 10 +++++----- 3 files changed, 9 insertions(+), 9 deletions(-) diff --git a/ignite/engine/engine.py b/ignite/engine/engine.py index da07a8061ed..f38d4bd26be 100644 --- a/ignite/engine/engine.py +++ b/ignite/engine/engine.py @@ -905,7 +905,7 @@ def switch_batch(engine): raise ValueError("Argument max_iters is invalid. Please, set a correct max_iters positive value") if (self.state.max_iters is not None) and max_iters <= self.state.iteration: raise ValueError( - "Argument max_iters should be larger than the current iteration " + "Argument max_iters should be greater than the current iteration " f"defined in the state: {max_iters} vs {self.state.iteration}. " "Please, set engine.state.max_iters = None " "before calling engine.run() in order to restart the training from the beginning." @@ -979,7 +979,7 @@ def _check_and_set_max_epochs(self, max_epochs: Optional[int] = None) -> None: raise ValueError("Argument max_epochs is invalid. Please, set a correct max_epochs positive value") if self.state.max_epochs is not None and max_epochs <= self.state.epoch: raise ValueError( - "Argument max_epochs should be larger than the current epoch " + "Argument max_epochs should be greater than the current epoch " f"defined in the state: {max_epochs} vs {self.state.epoch}. " "Please, set engine.state.max_epochs = None " "before calling engine.run() in order to restart the training from the beginning." @@ -992,7 +992,7 @@ def _check_and_set_max_iters(self, max_iters: Optional[int] = None) -> None: raise ValueError("Argument max_iters is invalid. Please, set a correct max_iters positive value") if (self.state.max_iters is not None) and max_iters <= self.state.iteration: raise ValueError( - "Argument max_iters should be larger than the current iteration " + "Argument max_iters should be greater than the current iteration " f"defined in the state: {max_iters} vs {self.state.iteration}. " "Please, set engine.state.max_iters = None " "before calling engine.run() in order to restart the training from the beginning." diff --git a/tests/ignite/engine/test_engine.py b/tests/ignite/engine/test_engine.py index 91e6d29a7fb..d9a1911adda 100644 --- a/tests/ignite/engine/test_engine.py +++ b/tests/ignite/engine/test_engine.py @@ -536,7 +536,7 @@ def test_run_asserts(self): engine.state.epoch = 5 engine.run([0, 1], max_epochs=3) - with pytest.raises(ValueError, match=r"Argument max_iters should be larger than"): + with pytest.raises(ValueError, match=r"Argument max_iters should be greater than the current"): engine.state.max_iters = 100 engine.state.iteration = 100 engine.run([0, 1], max_iters=50) diff --git a/tests/ignite/engine/test_engine_state_dict.py b/tests/ignite/engine/test_engine_state_dict.py index 06a0451acea..1c4934eb6b0 100644 --- a/tests/ignite/engine/test_engine_state_dict.py +++ b/tests/ignite/engine/test_engine_state_dict.py @@ -125,19 +125,19 @@ def _test(sd): _test({"max_epochs": 100, "epoch_length": 120, "iteration": 123}) _test({"max_epochs": 100, "epoch_length": 120, "epoch": 5}) - with pytest.raises(ValueError, match=r"Argument max_epochs should be larger than"): + with pytest.raises(ValueError, match=r"Argument max_epochs should be greater than the current"): _test({"max_epochs": 10, "epoch_length": 120, "epoch": 50}) - with pytest.raises(ValueError, match=r"Argument max_epochs should be larger than"): + with pytest.raises(ValueError, match=r"Argument max_epochs should be greater than the current"): _test({"max_epochs": 10, "epoch_length": 120, "iteration": 5000}) _test({"max_iters": 500, "epoch_length": 120, "iteration": 123}) _test({"max_iters": 500, "epoch_length": 120, "epoch": 3}) - with pytest.raises(ValueError, match=r"Argument max_iters should be larger than"): + with pytest.raises(ValueError, match=r"Argument max_iters should be greater than"): _test({"max_iters": 500, "epoch_length": 120, "epoch": 5}) - with pytest.raises(ValueError, match=r"Argument max_iters should be larger than"): + with pytest.raises(ValueError, match=r"Argument max_iters should be greater than"): _test({"max_iters": 500, "epoch_length": 120, "iteration": 501}) @@ -184,7 +184,7 @@ def test_load_state_dict_with_params_overriding_integration(): assert state.max_epochs == new_max_epochs assert state.iteration == state_dict["epoch_length"] * new_max_epochs assert state.epoch == new_max_epochs - with pytest.raises(ValueError, match=r"Argument max_epochs should be larger than the current epoch"): + with pytest.raises(ValueError, match=r"Argument max_epochs should be greater than the current epoch"): engine.load_state_dict(state_dict) engine.run(data, max_epochs=3) From 81d002d5f555e859055645d4bf48251f203b62b6 Mon Sep 17 00:00:00 2001 From: leej3 Date: Fri, 19 Apr 2024 16:14:05 +0100 Subject: [PATCH 7/8] if data is none then batches are not run --- tests/ignite/engine/test_engine.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/tests/ignite/engine/test_engine.py b/tests/ignite/engine/test_engine.py index d9a1911adda..004e8573bf5 100644 --- a/tests/ignite/engine/test_engine.py +++ b/tests/ignite/engine/test_engine.py @@ -645,10 +645,16 @@ def _test_check_triggered_events( n_iter_completed = max_epochs * epoch_length if max_epochs is not None else max_iters if n_batch_started is None: - n_batch_started = max_epochs * epoch_length if max_epochs is not None else max_iters + if data is None: + n_batch_started = 0 + else: + n_batch_started = max_epochs * epoch_length if max_epochs is not None else max_iters if n_batch_completed is None: - n_batch_completed = max_epochs * epoch_length if max_epochs is not None else max_iters + if data is None: + n_batch_completed = 0 + else: + n_batch_completed = max_epochs * epoch_length if max_epochs is not None else max_iters if n_terminate is None: n_terminate = int(n_epoch_started != n_epoch_completed) if max_iters is not None else 0 From 21f4a88da0fd188a92aee17516e0d5bd94332bfb Mon Sep 17 00:00:00 2001 From: leej3 Date: Wed, 24 Apr 2024 13:40:19 +0100 Subject: [PATCH 8/8] hitting max_iters is not epoch completion --- ignite/engine/engine.py | 20 +++++++++++--------- 1 file changed, 11 insertions(+), 9 deletions(-) diff --git a/ignite/engine/engine.py b/ignite/engine/engine.py index f38d4bd26be..277498ac4a8 100644 --- a/ignite/engine/engine.py +++ b/ignite/engine/engine.py @@ -1064,17 +1064,19 @@ def _internal_run_as_gen(self) -> Generator: self.state.times[Events.EPOCH_COMPLETED.name] = epoch_time_taken handlers_start_time = time.time() - self._fire_event(Events.EPOCH_COMPLETED) - epoch_time_taken += time.time() - handlers_start_time - # update time wrt handlers - self.state.times[Events.EPOCH_COMPLETED.name] = epoch_time_taken + if self.state.epoch_length is not None and self.state.iteration % self.state.epoch_length == 0: + # max_iters can cause training to complete without an epoch ending + self._fire_event(Events.EPOCH_COMPLETED) + epoch_time_taken += time.time() - handlers_start_time + # update time wrt handlers + self.state.times[Events.EPOCH_COMPLETED.name] = epoch_time_taken + + hours, mins, secs = _to_hours_mins_secs(epoch_time_taken) + self.logger.info( + f"Epoch[{self.state.epoch}] Complete. Time taken: {hours:02d}:{mins:02d}:{secs:06.3f}" + ) yield from self._maybe_terminate_or_interrupt() - hours, mins, secs = _to_hours_mins_secs(epoch_time_taken) - self.logger.info( - f"Epoch[{self.state.epoch}] Complete. Time taken: {hours:02d}:{mins:02d}:{secs:06.3f}" - ) - except _EngineTerminateException: self._fire_event(Events.TERMINATE)