-
Notifications
You must be signed in to change notification settings - Fork 6
/
Copy pathvisualize_data.py
135 lines (118 loc) · 4.35 KB
/
visualize_data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
#!/usr/bin/env python
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import argparse
import os
from itertools import chain
from pathlib import Path
import cv2
import numpy as np
import tqdm
from detectron2.config import get_cfg
from detectron2.data import (
DatasetCatalog,
MetadataCatalog,
build_detection_train_loader,
)
from detectron2.data import detection_utils as utils
from detectron2.data.build import filter_images_with_few_keypoints
from detectron2.data.datasets import register_coco_instances
from detectron2.utils.logger import setup_logger
from detectron2.utils.visualizer import Visualizer
from PIL import Image
# Register datasets
from detectron2_1.datasets import BenignMapper
def setup(args):
    """Construct and freeze a detectron2 config from parsed CLI args.

    The optional --config-file is merged in first, then any key/value
    overrides supplied through the positional ``opts`` remainder.
    """
    config = get_cfg()
    if args.config_file:
        config.merge_from_file(args.config_file)
    config.merge_from_list(args.opts)
    config.freeze()
    return config
def parse_args(in_args=None):
    """Parse command-line arguments for the visualization script.

    Args:
        in_args: optional list of argument strings; ``None`` means sys.argv.

    Returns:
        argparse.Namespace with ``source``, ``config_file``, ``output_dir``,
        ``show``, ``num_imgs`` and ``opts`` attributes.
    """
    ap = argparse.ArgumentParser(description="Visualize ground-truth data")
    ap.add_argument(
        "--source",
        required=True,
        choices=["annotation", "dataloader"],
        help="visualize the annotations or the data loader (with pre-processing)",
    )
    ap.add_argument(
        "--config-file", metavar="FILE", default="", help="path to config file"
    )
    ap.add_argument("--output-dir", default="./", help="path to output directory")
    ap.add_argument("--show", action="store_true", help="show output in a window")
    ap.add_argument(
        "--num-imgs", type=int, help="number of images to visualize and save"
    )
    # REMAINDER collects everything after the known flags, so arbitrary
    # config overrides can be forwarded to cfg.merge_from_list().
    ap.add_argument(
        "opts",
        nargs=argparse.REMAINDER,
        default=None,
        help="Modify config options using the command-line",
    )
    return ap.parse_args(in_args)
if __name__ == "__main__":
    args = parse_args()
    logger = setup_logger()
    logger.info("arguments: " + str(args))
    cfg = setup(args)

    dirname = args.output_dir
    os.makedirs(dirname, exist_ok=True)
    metadata = MetadataCatalog.get(cfg.DATASETS.TRAIN[0])

    def output(vis, fname):
        """Show the visualization in a window (--show) or save it to disk."""
        if args.show:
            print(fname)
            # Visualizer images are RGB; OpenCV wants BGR, hence the ::-1.
            cv2.imshow("window", vis.get_image()[:, :, ::-1])
            cv2.waitKey()
        else:
            filepath = os.path.join(dirname, fname)
            print("Saving to {} ...".format(filepath))
            vis.save(filepath)

    scale = 2.0 if args.show else 1.0
    if args.source == "dataloader":
        # Visualize images exactly as the training pipeline sees them,
        # i.e. after the mapper's pre-processing/augmentation.
        train_data_loader = build_detection_train_loader(
            cfg, mapper=BenignMapper(cfg, is_train=True)
        )
        i = 0
        for batch in train_data_loader:
            for per_image in batch:
                # Pytorch tensor is in (C, H, W) format; Visualizer and
                # PIL need an (H, W, C) numpy array, so convert explicitly.
                # (Image.fromarray on a torch tensor would raise.)
                img = per_image["image"].permute(1, 2, 0).cpu().numpy()
                if cfg.INPUT.FORMAT == "BGR":
                    img = img[:, :, [2, 1, 0]]
                else:
                    img = np.asarray(
                        Image.fromarray(img, mode=cfg.INPUT.FORMAT).convert("RGB")
                    )

                visualizer = Visualizer(img, metadata=metadata, scale=scale)
                target_fields = per_image["instances"].get_fields()
                labels = [
                    metadata.thing_classes[i] for i in target_fields["gt_classes"]
                ]
                vis = visualizer.overlay_instances(
                    labels=labels,
                    boxes=target_fields.get("gt_boxes", None),
                    masks=target_fields.get("gt_masks", None),
                    keypoints=target_fields.get("gt_keypoints", None),
                )
                output(vis, str(per_image["image_id"]) + ".jpg")

                i += 1
                if args.num_imgs == i:
                    break
            # NOTE(review): the detectron2 train loader is infinite, so
            # without --num-imgs this branch never terminates.
            if args.num_imgs == i:
                break
    else:
        # Visualize the raw ground-truth annotations, without pre-processing.
        dicts = list(
            chain.from_iterable([DatasetCatalog.get(k) for k in cfg.DATASETS.TRAIN])
        )
        if cfg.MODEL.KEYPOINT_ON:
            dicts = filter_images_with_few_keypoints(dicts, 1)
        i = 0
        for dic in tqdm.tqdm(dicts):
            img = utils.read_image(dic["file_name"], "RGB")
            visualizer = Visualizer(img, metadata=metadata, scale=scale)
            vis = visualizer.draw_dataset_dict(dic)
            output(vis, os.path.basename(dic["file_name"]))
            i += 1
            if args.num_imgs == i:
                break