run.py
import time
import tqdm
import torch
import numpy as np
from lib.utils.net_utils import load_network
from lib.utils.data_utils import to_cuda
from lib.evaluators import make_evaluator
from lib.visualizers import make_visualizer
from lib.networks.renderer import make_renderer
from lib.datasets import make_data_loader
from lib.networks import make_network
from lib.config import cfg, args
from lib.utils.log_utils import log
import cv2
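
# Limit OpenCV to a single internal thread to avoid CPU oversubscription
# alongside PyTorch and the data loader.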
cv2.setNumThreads(1)
cfg.train.num_workers = 0 # no multi-process dataloading needed when visualizing
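
# Iterate over the test split once without running the network; handy for
# timing or debugging the data loading pipeline in isolation.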
@torch.no_grad()
def run_dataset():
    data_loader = make_data_loader(cfg, is_train=False)
    for batch in tqdm.tqdm(data_loader):
        pass
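
# Render every test batch and report the average wall-clock render time per
# batch; torch.cuda.synchronize() brackets the timed region so asynchronous
# CUDA kernel launches do not skew the measurement.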
@torch.no_grad()
def run_network():
    network = make_network(cfg).cuda()
    load_network(network, cfg.trained_model_dir, epoch=cfg.test.epoch)
    network.eval()
    renderer = make_renderer(cfg, network)
    data_loader = make_data_loader(cfg, is_train=False)
    total_time = 0
    for batch in tqdm.tqdm(data_loader):
        batch = to_cuda(batch)
        torch.cuda.synchronize()
        start = time.time()
        output = renderer.render(batch)
        torch.cuda.synchronize()
        total_time += time.time() - start
    log(total_time / len(data_loader))
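
# Render every test batch and accumulate metrics with the configured
# evaluator, then print the summary.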
@torch.no_grad()
def run_evaluate():
    network = make_network(cfg).cuda()
    load_network(network, cfg.trained_model_dir, epoch=cfg.test.epoch)
    network.eval()
    data_loader = make_data_loader(cfg, is_train=False)
    renderer = make_renderer(cfg, network)
    evaluator = make_evaluator(cfg)
    for batch in tqdm.tqdm(data_loader):
        batch = to_cuda(batch)
        output = renderer.render(batch)
        evaluator.evaluate(output, batch)
    evaluator.summarize()
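
# Render and visualize every test batch; if the renderer reports a per-batch
# timing under the 'diff' key, collect it and log the mean network rendering
# time at the end.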
@torch.no_grad()
def run_visualize():
    network = make_network(cfg).cuda()
    load_network(network, cfg.trained_model_dir, epoch=cfg.test.epoch)
    network.eval()
    diffs = []
    data_loader = make_data_loader(cfg, is_train=False)
    renderer = make_renderer(cfg, network)
    visualizer = make_visualizer(cfg)
    for batch in tqdm.tqdm(data_loader):
        batch = to_cuda(batch, 'cuda')
        output = renderer.render(batch)
        if 'diff' in output:
            diffs.append(output.diff)
            del output.diff
        visualizer.visualize(output, batch)
    visualizer.summarize()
    if len(diffs):
        log(f'###################{cfg.exp_name}###################', 'green')
        log(f'Network rendering time: {np.mean(diffs)}', 'green')
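
# Dispatch on the `type` argument (dataset, network, evaluate, or visualize);
# on any error, drop into a pdbr post-mortem debugging session.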
if __name__ == '__main__':
    try:
        globals()['run_' + args.type]()
    except Exception:
        import pdbr
        pdbr.post_mortem()
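
# A typical invocation might look like the following; `--type` matches
# args.type above, while the config flag and file are assumptions about how
# lib.config parses the command line and may differ in this repo:
#   python run.py --type evaluate --cfg_file configs/my_config.yaml
#   python run.py --type visualize --cfg_file configs/my_config.yaml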