-
Notifications
You must be signed in to change notification settings - Fork 11
/
Copy patheval.py
151 lines (125 loc) · 4.42 KB
/
eval.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
from typing import List, Tuple
import hydra
import pyrootutils
from omegaconf import DictConfig
from pytorch_lightning import (
LightningDataModule,
LightningModule,
Trainer,
seed_everything,
)
from pytorch_lightning.loggers import LightningLoggerBase
from src import utils
# --------------------------------------------------------------------------- #
# `pyrootutils.setup_root(...)` below is an optional call that makes the
# environment more convenient; it should be placed at the top of each entry file
#
# main advantages:
# - allows you to keep all entry files in "src/" without installing project as
# a package
# - launching a python file works no matter what your current working dir is
# - automatically loads environment variables from ".env" if exists
#
# how it works:
# - `setup_root()` below recursively searches for either ".git" or
#   "pyproject.toml" in the present and parent dirs to determine the project
#   root dir
# - adds root dir to the PYTHONPATH (if `pythonpath=True`), so this file can
# be run from any place without installing project as a package
# - sets PROJECT_ROOT environment variable which is used in
# "configs/paths/default.yaml" to make all paths always relative to project
# root
# - loads environment variables from ".env" in root dir (if `dotenv=True`)
#
# you can remove `pyrootutils.setup_root(...)` if you:
# 1. either install project as a package or move each entry file to the project
# root dir
# 2. remove PROJECT_ROOT variable from paths in "configs/paths/default.yaml"
#
# https://github.com/ashleve/pyrootutils
# --------------------------------------------------------------------------- #
# Locate the project root (first dir upward containing ".git" or
# "pyproject.toml"), add it to PYTHONPATH and load ".env" from it.
root = pyrootutils.setup_root(
    search_from=__file__,
    indicator=[".git", "pyproject.toml"],
    dotenv=True,
    pythonpath=True,
)

# Hydra settings, unpacked into both `utils.register_custom_resolvers` and
# `hydra.main` on the entry point so the two stay in sync.
_HYDRA_PARAMS = dict(
    version_base="1.3",
    config_path=str(root / "configs"),
    config_name="eval.yaml",
)

# Module-level logger obtained from the project's logging helper.
log = utils.get_pylogger(__name__)
@utils.task_wrapper
def evaluate(cfg: DictConfig) -> Tuple[dict, dict]:
    """Evaluate a checkpoint on the datamodule by testing or predicting.

    If ``cfg.predict`` is truthy, runs ``trainer.predict`` and saves the
    predictions with ``utils.save_predictions`` into ``cfg.paths.output_dir``.
    Otherwise runs ``trainer.test``. In both cases ``cfg.ckpt_path`` is
    passed to the trainer as the checkpoint to load.

    This method is wrapped in the optional @task_wrapper decorator, which
    applies extra utilities before and after the call.

    Args:
        cfg (DictConfig): Configuration composed by Hydra. Must provide a
            non-empty ``ckpt_path``.

    Returns:
        Tuple[dict, dict]: Dict with metrics (``trainer.callback_metrics``)
        and dict with all instantiated objects.

    Raises:
        ValueError: If ``cfg.ckpt_path`` is missing or empty.
    """
    # Explicit check rather than `assert`: asserts are stripped under
    # `python -O`, which would let evaluation run without a checkpoint.
    if not cfg.get("ckpt_path"):
        raise ValueError("`ckpt_path` must be set to evaluate a checkpoint.")

    utils.log_gpu_memory_metadata()

    # Set seed for random number generators in pytorch, numpy and
    # python.random. `is not None` so that an explicit seed of 0 (falsy)
    # is still honoured.
    seed = cfg.get("seed")
    if seed is not None:
        log.info(f"Seed everything with <{seed}>")
        seed_everything(seed, workers=True)

    # Init lightning datamodule
    log.info(f"Instantiating datamodule <{cfg.datamodule._target_}>")
    datamodule: LightningDataModule = hydra.utils.instantiate(
        cfg.datamodule, _recursive_=False
    )

    # Init lightning model
    log.info(f"Instantiating lightning model <{cfg.module._target_}>")
    model: LightningModule = hydra.utils.instantiate(
        cfg.module, _recursive_=False
    )

    log.info("Instantiating loggers...")
    logger: List[LightningLoggerBase] = utils.instantiate_loggers(
        cfg.get("logger")
    )

    log.info(f"Instantiating trainer <{cfg.trainer._target_}>")
    trainer: Trainer = hydra.utils.instantiate(cfg.trainer, logger=logger)

    object_dict = {
        "cfg": cfg,
        "datamodule": datamodule,
        "model": model,
        "logger": logger,
        "trainer": trainer,
    }

    if logger:
        log.info("Logging hyperparameters!")
        utils.log_hyperparameters(object_dict)

    # Log metadata
    log.info("Logging metadata!")
    utils.log_metadata(cfg)

    if cfg.get("predict"):
        log.info("Starting predicting!")
        predictions = trainer.predict(
            model=model,
            datamodule=datamodule,
            ckpt_path=cfg.ckpt_path,
        )
        utils.save_predictions(
            predictions=predictions,
            dirname=cfg.paths.output_dir,
            **cfg.extras.predictions_saving_params,
        )
    else:
        log.info("Starting testing!")
        trainer.test(
            model=model,
            datamodule=datamodule,
            ckpt_path=cfg.ckpt_path,
        )

    metric_dict = trainer.callback_metrics

    return metric_dict, object_dict
@utils.register_custom_resolvers(**_HYDRA_PARAMS)
@hydra.main(**_HYDRA_PARAMS)
def main(cfg: DictConfig) -> None:
    """Hydra entry point that runs `evaluate` on the composed config.

    The same `_HYDRA_PARAMS` are passed to both decorators so the custom
    resolvers are registered for the config that `hydra.main` composes.

    Args:
        cfg (DictConfig): Configuration composed by Hydra.
    """
    # The (metrics, objects) result of `evaluate` is intentionally unused.
    evaluate(cfg)
    return None


if __name__ == "__main__":
    main()