demo.cpp
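
// demo.cpp -- VitTrack object tracking demo built on OpenCV's TrackerVit (cv::dnn).
// Typical invocation (the binary name is illustrative; the flags come from the
// CommandLineParser below):
//   ./demo --input=/path/to/video.mp4 --model_path=object_tracking_vittrack_2023sep.onnx --backend_target=0
// Omitting --input uses the default camera.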
#include <iostream>
#include <string>
#include <vector>

#include <opencv2/opencv.hpp>
using namespace std;
using namespace cv;
using namespace dnn;
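
// Outcome of a single tracking update: whether the target was located,
// its bounding box, and the tracker's confidence score.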
struct TrackingResult
{
    bool isLocated;
    Rect bbox;
    float score;
};
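
// Thin wrapper around cv::TrackerVit: creates the tracker from the ONNX model
// and the selected DNN backend/target, and exposes init/update as init/infer.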
class VitTrack
{
public:
    VitTrack(const string& model_path, int backend_id = 0, int target_id = 0)
    {
        params.net = model_path;
        params.backend = backend_id;
        params.target = target_id;
        model = TrackerVit::create(params);
    }

    void init(const Mat& image, const Rect& roi)
    {
        model->init(image, roi);
    }

    TrackingResult infer(const Mat& image)
    {
        TrackingResult result;
        result.isLocated = model->update(image, result.bbox);
        result.score = model->getTrackingScore();
        return result;
    }

private:
    TrackerVit::Params params;
    Ptr<TrackerVit> model;
};
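
// Draw the tracking result onto a copy of the frame: an FPS readout (when fps >= 0),
// the bounding box and score when the target is located with score >= 0.3,
// otherwise a centered "Target lost!" message in red.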
Mat visualize(const Mat& image, const Rect& bbox, float score, bool isLocated, double fps = -1.0,
              const Scalar& box_color = Scalar(0, 255, 0), const Scalar& text_color = Scalar(0, 255, 0),
              double fontScale = 1.0, int fontSize = 1)
{
    Mat output = image.clone();
    int h = output.rows;
    int w = output.cols;

    if (fps >= 0)
    {
        putText(output, "FPS: " + to_string(fps), Point(0, 30), FONT_HERSHEY_DUPLEX, fontScale, text_color, fontSize);
    }

    if (isLocated && score >= 0.3)
    {
        rectangle(output, bbox, box_color, 2);
        putText(output, format("%.2f", score), Point(bbox.x, bbox.y + 25),
                FONT_HERSHEY_DUPLEX, fontScale, text_color, fontSize);
    }
    else
    {
        Size text_size = getTextSize("Target lost!", FONT_HERSHEY_DUPLEX, fontScale, fontSize, nullptr);
        int text_x = (w - text_size.width) / 2;
        int text_y = (h - text_size.height) / 2;
        putText(output, "Target lost!", Point(text_x, text_y), FONT_HERSHEY_DUPLEX, fontScale, Scalar(0, 0, 255), fontSize);
    }

    return output;
}
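
// Parse command-line options, pick the backend/target pair, let the user select
// an ROI on the first frame, then track it frame by frame, optionally saving
// and/or displaying the annotated output.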
int main(int argc, char** argv)
{
    CommandLineParser parser(argc, argv,
        "{help h | | Print help message. }"
        "{input i | |Set path to the input video. Omit for using default camera.}"
        "{model_path |object_tracking_vittrack_2023sep.onnx |Set model path}"
        "{backend_target bt |0 |Choose backend-target pair: 0 - OpenCV implementation + CPU, 1 - CUDA + GPU (CUDA), 2 - CUDA + GPU (CUDA FP16), 3 - TIM-VX + NPU, 4 - CANN + NPU}"
        "{save s |false |Specify to save a file with results.}"
        "{vis v |true |Specify to open a new window to show results.}");

    if (parser.has("help"))
    {
        parser.printMessage();
        return 0;
    }

    string input = parser.get<string>("input");
    string model_path = parser.get<string>("model_path");
    int backend_target = parser.get<int>("backend_target");
    bool save = parser.get<bool>("save");
    bool vis = parser.get<bool>("vis");
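
    // Backend/target pairs indexed by the backend_target option above.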
    vector<vector<int>> backend_target_pairs =
    {
        {DNN_BACKEND_OPENCV, DNN_TARGET_CPU},
        {DNN_BACKEND_CUDA, DNN_TARGET_CUDA},
        {DNN_BACKEND_CUDA, DNN_TARGET_CUDA_FP16},
        {DNN_BACKEND_TIMVX, DNN_TARGET_NPU},
        {DNN_BACKEND_CANN, DNN_TARGET_NPU}
    };
    int backend_id = backend_target_pairs[backend_target][0];
    int target_id = backend_target_pairs[backend_target][1];
    // Create VitTrack tracker
    VitTrack tracker(model_path, backend_id, target_id);

    // Open video capture
    VideoCapture video;
    if (input.empty())
    {
        video.open(0); // Default camera
    }
    else
    {
        video.open(input);
    }

    if (!video.isOpened())
    {
        cerr << "Error: Could not open video source" << endl;
        return -1;
    }
    // Select an object
    Mat first_frame;
    video >> first_frame;

    if (first_frame.empty())
    {
        cerr << "No frames grabbed!" << endl;
        return -1;
    }

    Mat first_frame_copy = first_frame.clone();
    putText(first_frame_copy, "1. Drag a bounding box to track.", Point(0, 25), FONT_HERSHEY_SIMPLEX, 1, Scalar(0, 255, 0));
    putText(first_frame_copy, "2. Press ENTER to confirm", Point(0, 50), FONT_HERSHEY_SIMPLEX, 1, Scalar(0, 255, 0));
    Rect roi = selectROI("VitTrack Demo", first_frame_copy);

    if (roi.area() == 0)
    {
        cerr << "No ROI is selected! Exiting..." << endl;
        return -1;
    }
    else
    {
        cout << "Selected ROI: " << roi << endl;
    }
    // Create VideoWriter if save option is specified
    VideoWriter output_video;
    if (save)
    {
        Size frame_size = first_frame.size();
        output_video.open("output.mp4", VideoWriter::fourcc('m', 'p', '4', 'v'), video.get(CAP_PROP_FPS), frame_size);
        if (!output_video.isOpened())
        {
            cerr << "Error: Could not create output video stream" << endl;
            return -1;
        }
    }
    // Initialize tracker with ROI
    tracker.init(first_frame, roi);

    // Track frame by frame
    TickMeter tm;
    while (waitKey(1) < 0)
    {
        video >> first_frame;
        if (first_frame.empty())
        {
            cout << "End of video" << endl;
            break;
        }

        // Inference
        tm.start();
        TrackingResult result = tracker.infer(first_frame);
        tm.stop();

        // Visualize
        Mat frame = first_frame.clone();
        frame = visualize(frame, result.bbox, result.score, result.isLocated, tm.getFPS());
        if (save)
        {
            output_video.write(frame);
        }
        if (vis)
        {
            imshow("VitTrack Demo", frame);
        }
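        // Reset the timer so getFPS() reflects only the upcoming frame on the next iteration.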
        tm.reset();
    }

    if (save)
    {
        output_video.release();
    }

    video.release();
    destroyAllWindows();

    return 0;
}