-
Notifications
You must be signed in to change notification settings - Fork 26
/
Copy patheval_segmentation.py
82 lines (58 loc) · 2.22 KB
/
eval_segmentation.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
#!/usr/bin/env python3
# Python standard library
import os
# Public libraries
import torch
from torchvision.utils import save_image
# Local imports
import colors
from arguments import SegmentationEvaluationArguments
from harness import Harness
class SegmentationEvaluator(Harness):
    """Harness subclass that runs segmentation evaluation only.

    ``evaluate`` prints mean IoU / mean accuracy for every validation domain
    and writes one qualitative PNG per logged sample to
    ``<log_path>/eval_images/<domain>/img_<i>.png``.
    """

    # Column widths shared by the results-table header and its rows, so the
    # two always line up: eval_name, domain, miou, accuracy.
    _COLUMN_WIDTHS = (12, 20, 8, 8)

    def _init_resampler(self, opt):
        # Presumably overrides a Harness setup hook; evaluation needs no
        # resampler, so this is intentionally a no-op.
        pass

    def _init_validation(self, opt):
        # Number of validation images to log per domain, and the label shown
        # in the first column of the results table.
        self.val_num_log_images = opt.eval_num_images
        self.eval_name = opt.model_name

    def evaluate(self):
        """Run segmentation validation, print per-domain metrics, save images.

        Returns:
            The per-domain score objects returned by
            ``self._run_segmentation_validation`` (the first element of its
            result tuple).
        """
        print('Evaluate segmentation predictions:', flush=True)

        scores, images = self._run_segmentation_validation(
            self.val_num_log_images
        )

        self._print_scores(scores)
        self._save_images(images)

        self._log_gpu_memory()

        return scores

    def _print_scores(self, scores):
        """Print a results table with one row per domain in *scores*.

        Each ``scores[domain].get_scores()`` is expected to be a mapping with
        ``'meaniou'`` and ``'meanacc'`` entries. The header is printed once,
        before the first row; nothing is printed when *scores* is empty.
        """
        w_name, w_domain, w_miou, w_acc = self._COLUMN_WIDTHS
        # Numeric columns are right-aligned by the ``.3f`` row format, so the
        # matching header labels are right-aligned too.
        header = (
            f'{"eval_name":{w_name}} | {"domain":{w_domain}} | '
            f'{"miou":>{w_miou}} | {"accuracy":>{w_acc}}'
        )

        for row_idx, domain in enumerate(scores):
            if row_idx == 0:
                print(header)

            metrics = scores[domain].get_scores()
            miou = metrics['meaniou']
            acc = metrics['meanacc']

            print(
                f'{self.eval_name:{w_name}} | {domain:{w_domain}} | '
                f'{miou:{w_miou}.3f} | {acc:{w_acc}.3f}',
                flush=True,
            )

    def _save_images(self, images):
        """Write one PNG per logged sample, grouped into per-domain folders.

        Each entry of ``images[domain]`` is unpacked as
        ``(color_gt, seg_gt, seg_pred)``. The saved image concatenates, along
        tensor dim 2, the colour image, the colourised prediction and the
        colourised ground truth, clamped to [0, 1]. Dim 2 is presumably the
        width axis, assuming C x H x W tensors; TODO confirm.
        """
        for domain in images:
            domain_dir = os.path.join(self.log_path, 'eval_images', domain)
            os.makedirs(domain_dir, exist_ok=True)

            for i, (color_gt, seg_gt, seg_pred) in enumerate(images[domain]):
                image_path = os.path.join(domain_dir, f'img_{i}.png')

                # Panel order: colour image | prediction | ground truth.
                logged_images = (
                    color_gt,
                    colors.seg_idx_image(seg_pred),
                    colors.seg_idx_image(seg_gt),
                )

                save_image(
                    torch.cat(logged_images, 2).clamp(0, 1),
                    image_path
                )
def seed_everything(seed=1):
    """Seed Python, NumPy and torch RNGs and force deterministic cuDNN.

    This is best effort only: it disables cuDNN autotuning and requests
    deterministic cuDNN kernels, but does not enable any other torch
    determinism switches.

    Args:
        seed: Integer seed applied to every RNG (defaults to 1, the value the
            script has always used).
    """
    # Imported lazily so these modules are only needed when determinism is
    # actually requested.
    import random
    import numpy as np

    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)  # silently ignored when CUDA is unavailable
    random.seed(seed)


if __name__ == "__main__":
    opt = SegmentationEvaluationArguments().parse()

    if opt.model_load is None:
        # ValueError (a subclass of Exception) signals a bad CLI configuration
        # while staying compatible with any ``except Exception`` handler.
        raise ValueError('You must use --model-load to select a model state directory to run evaluation on')

    if opt.sys_best_effort_determinism:
        seed_everything(1)

    evaluator = SegmentationEvaluator(opt)
    evaluator.evaluate()