-
Notifications
You must be signed in to change notification settings - Fork 11
/
Copy pathtrain.py
194 lines (159 loc) · 5.81 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
from typing import Any, List, Optional, Tuple
import hydra
import pyrootutils
from omegaconf import DictConfig
from pytorch_lightning import (
Callback,
LightningDataModule,
LightningModule,
Trainer,
seed_everything,
)
from pytorch_lightning.loggers import LightningLoggerBase
from src import utils
# --------------------------------------------------------------------------- #
# The `pyrootutils.setup_root(...)` call below is an optional line that makes
# the environment more convenient; it should be placed at the top of each
# entry file.
#
# Main advantages:
# - lets you keep all entry files in "src/" without installing the project as
#   a package
# - launching the Python file works no matter what your current working
#   directory is
# - automatically loads environment variables from ".env" if it exists
#
# How it works:
# - `setup_root()` recursively searches for either ".git" or "pyproject.toml"
#   in the present and parent dirs to determine the project root dir
# - adds the root dir to PYTHONPATH (if `pythonpath=True`), so this file can
#   be run from any place without installing the project as a package
# - sets the PROJECT_ROOT environment variable, which is used in
#   "configs/paths/default.yaml" to make all paths relative to the project
#   root
# - loads environment variables from ".env" in the root dir (if `dotenv=True`)
#
# You can remove `pyrootutils.setup_root(...)` if you:
# 1. either install the project as a package or move each entry file to the
#    project root dir, and
# 2. remove the PROJECT_ROOT variable from paths in
#    "configs/paths/default.yaml"
#
# https://github.com/ashleve/pyrootutils
# --------------------------------------------------------------------------- #
# Resolve the project root by searching this file's dir and its parents for
# ".git" or "pyproject.toml"; also put the root on PYTHONPATH and load ".env".
root = pyrootutils.setup_root(
    dotenv=True,
    pythonpath=True,
    indicator=[".git", "pyproject.toml"],
    search_from=__file__,
)

# Keyword arguments handed to both `hydra.main` and
# `utils.register_custom_resolvers` on `main`, so they agree on the config
# location, entry config name and Hydra version base.
_HYDRA_PARAMS = dict(
    version_base="1.3",
    config_path=str(root / "configs"),
    config_name="train.yaml",
)

# Module logger created through the project's logging helper.
log = utils.get_pylogger(__name__)
@utils.task_wrapper
def train(cfg: DictConfig) -> Tuple[dict, dict]:
    """Train the model and optionally evaluate it on the test set.

    Testing uses the best checkpoint tracked by the trainer's checkpoint
    callback when one is available, and the current weights otherwise.

    This function is wrapped in the optional ``@utils.task_wrapper``
    decorator, which applies extra utilities before and after the call.

    Args:
        cfg (DictConfig): Configuration composed by Hydra.

    Returns:
        Tuple[dict, dict]: Dict with the merged train and test metrics, and
        dict with all instantiated objects.
    """
    utils.log_gpu_memory_metadata()

    # Seed the RNGs of pytorch, numpy and python.random. Test against None
    # instead of truthiness so that a configured `seed: 0` is honoured
    # rather than silently skipped.
    seed = cfg.get("seed")
    if seed is not None:
        log.info(f"Seed everything with <{seed}>")
        seed_everything(seed, workers=True)

    # Init lightning datamodule. `_recursive_=False` passes nested configs
    # through un-instantiated.
    log.info(f"Instantiating datamodule <{cfg.datamodule._target_}>")
    datamodule: LightningDataModule = hydra.utils.instantiate(
        cfg.datamodule, _recursive_=False
    )

    # Init lightning model (nested configs likewise left un-instantiated).
    log.info(f"Instantiating lightning model <{cfg.module._target_}>")
    model: LightningModule = hydra.utils.instantiate(
        cfg.module, _recursive_=False
    )

    # Init callbacks
    log.info("Instantiating callbacks...")
    callbacks: List[Callback] = utils.instantiate_callbacks(
        cfg.get("callbacks")
    )

    # Init loggers
    log.info("Instantiating loggers...")
    logger: List[LightningLoggerBase] = utils.instantiate_loggers(
        cfg.get("logger")
    )

    # Init lightning ddp plugins
    log.info("Instantiating plugins...")
    plugins: Optional[List[Any]] = utils.instantiate_plugins(cfg)

    # Init lightning trainer
    log.info(f"Instantiating trainer <{cfg.trainer._target_}>")
    trainer: Trainer = hydra.utils.instantiate(
        cfg.trainer, callbacks=callbacks, logger=logger, plugins=plugins
    )

    object_dict = {
        "cfg": cfg,
        "datamodule": datamodule,
        "model": model,
        "callbacks": callbacks,
        "logger": logger,
        "trainer": trainer,
    }

    # Send parameters from cfg to all lightning loggers
    if logger:
        log.info("Logging hyperparameters!")
        utils.log_hyperparameters(object_dict)

    # Log metadata
    log.info("Logging metadata!")
    utils.log_metadata(cfg)

    # Train the model
    if cfg.get("train"):
        log.info("Starting training!")
        trainer.fit(
            model=model,
            datamodule=datamodule,
            ckpt_path=cfg.get("ckpt_path"),
        )
    # Snapshot into a new dict: `callback_metrics` may be updated in place by
    # the later test run, which would otherwise leak into the train metrics.
    train_metrics = dict(trainer.callback_metrics)

    # Test the model
    if cfg.get("test"):
        log.info("Starting testing!")
        # `trainer.checkpoint_callback` is None when no ModelCheckpoint is
        # configured; treat that the same as "no best checkpoint recorded".
        checkpoint_callback = trainer.checkpoint_callback
        ckpt_path = (
            checkpoint_callback.best_model_path
            if checkpoint_callback is not None
            else ""
        )
        if not ckpt_path:
            log.warning(
                "Best ckpt not found! Using current weights for testing..."
            )
            ckpt_path = None
        trainer.test(model=model, datamodule=datamodule, ckpt_path=ckpt_path)
        log.info(f"Best ckpt path: {ckpt_path}")
    test_metrics = trainer.callback_metrics

    # Save state dicts for best and last checkpoints
    if cfg.get("save_state_dict"):
        log.info("Starting saving state dicts!")
        utils.save_state_dicts(
            trainer=trainer,
            model=model,
            dirname=cfg.paths.output_dir,
            **cfg.extras.state_dict_saving_params,
        )

    # Merge train and test metrics; test values win on key collisions.
    metric_dict = {**train_metrics, **test_metrics}

    return metric_dict, object_dict
@utils.register_custom_resolvers(**_HYDRA_PARAMS)
@hydra.main(**_HYDRA_PARAMS)
def main(cfg: DictConfig) -> Optional[float]:
    """Hydra entry point: run `train` and report the metric to optimize.

    Args:
        cfg (DictConfig): Configuration composed by Hydra.

    Returns:
        Optional[float]: Whatever ``utils.get_metric_value`` returns for the
        metric named by ``cfg.optimized_metric``, which is used for
        Hydra-based hyperparameter optimization.
    """
    metrics, _objects = train(cfg)
    # Look the metric up safely; the instantiated objects are not needed here.
    return utils.get_metric_value(
        metric_dict=metrics,
        metric_name=cfg.get("optimized_metric"),
    )


if __name__ == "__main__":
    # `hydra.main` supplies `cfg` from the command line / config files.
    main()