-
Notifications
You must be signed in to change notification settings - Fork 198
/
demo.py
125 lines (105 loc) · 4.83 KB
/
demo.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
# This file is part of OpenCV Zoo project.
# It is subject to the license terms in the LICENSE file found in the same directory.
import argparse
import numpy as np
import cv2 as cv
# Check OpenCV version
opencv_python_version = lambda str_version: tuple(map(int, (str_version.split("."))))
assert opencv_python_version(cv.__version__) >= opencv_python_version("4.10.0"), \
"Please install latest opencv-python for benchmark: python3 -m pip install --upgrade opencv-python"
from vittrack import VitTrack
# Valid combinations of backends and targets
backend_target_pairs = [
[cv.dnn.DNN_BACKEND_OPENCV, cv.dnn.DNN_TARGET_CPU],
[cv.dnn.DNN_BACKEND_CUDA, cv.dnn.DNN_TARGET_CUDA],
[cv.dnn.DNN_BACKEND_CUDA, cv.dnn.DNN_TARGET_CUDA_FP16],
[cv.dnn.DNN_BACKEND_TIMVX, cv.dnn.DNN_TARGET_NPU],
[cv.dnn.DNN_BACKEND_CANN, cv.dnn.DNN_TARGET_NPU]
]
parser = argparse.ArgumentParser(
description="VIT track opencv API")
parser.add_argument('--input', '-i', type=str,
help='Usage: Set path to the input video. Omit for using default camera.')
parser.add_argument('--model_path', type=str, default='object_tracking_vittrack_2023sep.onnx',
help='Usage: Set model path')
parser.add_argument('--backend_target', '-bt', type=int, default=0,
help='''Choose one of the backend-target pair to run this demo:
{:d}: (default) OpenCV implementation + CPU,
{:d}: CUDA + GPU (CUDA),
{:d}: CUDA + GPU (CUDA FP16),
{:d}: TIM-VX + NPU,
{:d}: CANN + NPU
'''.format(*[x for x in range(len(backend_target_pairs))]))
parser.add_argument('--save', '-s', action='store_true', default=False,
help='Usage: Specify to save a file with results.')
parser.add_argument('--vis', '-v', action='store_true', default=True,
help='Usage: Specify to open a new window to show results.')
args = parser.parse_args()
def visualize(image, bbox, score, isLocated, fps=None, box_color=(0, 255, 0),text_color=(0, 255, 0), fontScale = 1, fontSize = 1):
output = image.copy()
h, w, _ = output.shape
if fps is not None:
cv.putText(output, 'FPS: {:.2f}'.format(fps), (0, 30), cv.FONT_HERSHEY_DUPLEX, fontScale, text_color, fontSize)
if isLocated and score >= 0.3:
# bbox: Tuple of length 4
x, y, w, h = bbox
cv.rectangle(output, (x, y), (x+w, y+h), box_color, 2)
cv.putText(output, '{:.2f}'.format(score), (x, y+25), cv.FONT_HERSHEY_DUPLEX, fontScale, text_color, fontSize)
else:
text_size, baseline = cv.getTextSize('Target lost!', cv.FONT_HERSHEY_DUPLEX, fontScale, fontSize)
text_x = int((w - text_size[0]) / 2)
text_y = int((h - text_size[1]) / 2)
cv.putText(output, 'Target lost!', (text_x, text_y), cv.FONT_HERSHEY_DUPLEX, fontScale, (0, 0, 255), fontSize)
return output
if __name__ == '__main__':
backend_id = backend_target_pairs[args.backend_target][0]
target_id = backend_target_pairs[args.backend_target][1]
model = VitTrack(
model_path=args.model_path,
backend_id=backend_id,
target_id=target_id)
# Read from args.input
_input = 0 if args.input is None else args.input
video = cv.VideoCapture(_input)
# Select an object
has_frame, first_frame = video.read()
if not has_frame:
print('No frames grabbed!')
exit()
first_frame_copy = first_frame.copy()
cv.putText(first_frame_copy, "1. Drag a bounding box to track.", (0, 25), cv.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0))
cv.putText(first_frame_copy, "2. Press ENTER to confirm", (0, 50), cv.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0))
roi = cv.selectROI('VitTrack Demo', first_frame_copy)
if np.all(np.array(roi) == 0):
print("No ROI is selected! Exiting ...")
exit()
else:
print("Selected ROI: {}".format(roi))
if args.save:
fps = video.get(cv.CAP_PROP_FPS)
frame_size = (first_frame.shape[1], first_frame.shape[0])
output_video = cv.VideoWriter('output.mp4', cv.VideoWriter_fourcc(*'mp4v'), fps, frame_size)
# Init tracker with ROI
model.init(first_frame, roi)
# Track frame by frame
tm = cv.TickMeter()
while cv.waitKey(1) < 0:
has_frame, frame = video.read()
if not has_frame:
print('End of video')
break
# Inference
tm.start()
isLocated, bbox, score = model.infer(frame)
tm.stop()
# Visualize
frame = visualize(frame, bbox, score, isLocated, fps=tm.getFPS())
if args.save:
output_video.write(frame)
if args.vis:
cv.imshow('VitTrack Demo', frame)
tm.reset()
if args.save:
output_video.release()
video.release()
cv.destroyAllWindows()