1// This file is part of OpenCV project.
2// It is subject to the license terms in the LICENSE file found in the top-level directory
3// of this distribution and at http://opencv.org/license.html.
4
5// Author, PengyuLiu, 1872918507@qq.com
6
7#include "../precomp.hpp"
8#ifdef HAVE_OPENCV_DNN
9#include "opencv2/dnn.hpp"
10#endif
11
12namespace cv {
13
14TrackerVit::TrackerVit()
15{
16 // nothing
17}
18
19TrackerVit::~TrackerVit()
20{
21 // nothing
22}
23
24TrackerVit::Params::Params()
25{
26 net = "vitTracker.onnx";
27 meanvalue = Scalar{0.485, 0.456, 0.406}; // normalized mean (already divided by 255)
28 stdvalue = Scalar{0.229, 0.224, 0.225}; // normalized std (already divided by 255)
29#ifdef HAVE_OPENCV_DNN
30 backend = dnn::DNN_BACKEND_DEFAULT;
31 target = dnn::DNN_TARGET_CPU;
32#else
33 backend = -1; // invalid value
34 target = -1; // invalid value
35#endif
36 tracking_score_threshold = 0.20f; // safe threshold to filter out black frames
37}
38
39#ifdef HAVE_OPENCV_DNN
40
41class TrackerVitImpl : public TrackerVit
42{
43public:
44 TrackerVitImpl(const TrackerVit::Params& parameters)
45 {
46 net = dnn::readNet(model: parameters.net);
47 CV_Assert(!net.empty());
48
49 net.setPreferableBackend(parameters.backend);
50 net.setPreferableTarget(parameters.target);
51
52 i2bp.mean = parameters.meanvalue * 255.0;
53 i2bp.scalefactor = (1.0 / parameters.stdvalue) * (1 / 255.0);
54 tracking_score_threshold = parameters.tracking_score_threshold;
55 }
56
57 TrackerVitImpl(const dnn::Net& model, Scalar meanvalue, Scalar stdvalue, float _tracking_score_threshold)
58 {
59 CV_Assert(!model.empty());
60
61 net = model;
62 i2bp.mean = meanvalue * 255.0;
63 i2bp.scalefactor = (1.0 / stdvalue) * (1 / 255.0);
64 tracking_score_threshold = _tracking_score_threshold;
65 }
66
67 void init(InputArray image, const Rect& boundingBox) CV_OVERRIDE;
68 bool update(InputArray image, Rect& boundingBox) CV_OVERRIDE;
69 float getTrackingScore() CV_OVERRIDE;
70
71 Rect rect_last;
72 float tracking_score;
73
74 float tracking_score_threshold;
75 dnn::Image2BlobParams i2bp;
76
77
78protected:
79 void preprocess(const Mat& src, Mat& dst, Size size);
80
81 const Size searchSize{256, 256};
82 const Size templateSize{128, 128};
83
84 Mat hanningWindow;
85
86 dnn::Net net;
87};
88
89static int crop_image(const Mat& src, Mat& dst, Rect box, int factor)
90{
91 int x = box.x, y = box.y, w = box.width, h = box.height;
92 int crop_sz = cvCeil(value: sqrt(x: w * h) * factor);
93
94 int x1 = x + (w - crop_sz) / 2;
95 int x2 = x1 + crop_sz;
96 int y1 = y + (h - crop_sz) / 2;
97 int y2 = y1 + crop_sz;
98
99 int x1_pad = std::max(a: 0, b: -x1);
100 int y1_pad = std::max(a: 0, b: -y1);
101 int x2_pad = std::max(a: x2 - src.size[1] + 1, b: 0);
102 int y2_pad = std::max(a: y2 - src.size[0] + 1, b: 0);
103
104 Rect roi(x1 + x1_pad, y1 + y1_pad, x2 - x2_pad - x1 - x1_pad, y2 - y2_pad - y1 - y1_pad);
105 Mat im_crop = src(roi);
106 copyMakeBorder(src: im_crop, dst, top: y1_pad, bottom: y2_pad, left: x1_pad, right: x2_pad, borderType: BORDER_CONSTANT);
107
108 return crop_sz;
109}
110
111void TrackerVitImpl::preprocess(const Mat& src, Mat& dst, Size size)
112{
113 Mat img;
114 resize(src, dst: img, dsize: size);
115
116 dst = dnn::blobFromImageWithParams(image: img, param: i2bp);
117}
118
119static Mat hann1d(int sz, bool centered = true) {
120 Mat hanningWindow(sz, 1, CV_32FC1);
121 float* data = hanningWindow.ptr<float>(y: 0);
122
123 if(centered) {
124 for(int i = 0; i < sz; i++) {
125 float val = 0.5f * (1.f - std::cos(x: static_cast<float>(2 * M_PI / (sz + 1)) * (i + 1)));
126 data[i] = val;
127 }
128 }
129 else {
130 int half_sz = sz / 2;
131 for(int i = 0; i <= half_sz; i++) {
132 float val = 0.5f * (1.f + std::cos(x: static_cast<float>(2 * M_PI / (sz + 2)) * i));
133 data[i] = val;
134 data[sz - 1 - i] = val;
135 }
136 }
137
138 return hanningWindow;
139}
140
141static Mat hann2d(Size size, bool centered = true) {
142 int rows = size.height;
143 int cols = size.width;
144
145 Mat hanningWindowRows = hann1d(sz: rows, centered);
146 Mat hanningWindowCols = hann1d(sz: cols, centered);
147
148 Mat hanningWindow = hanningWindowRows * hanningWindowCols.t();
149
150 return hanningWindow;
151}
152
153static void updateLastRect(float cx, float cy, float w, float h, int crop_size, Rect &rect_last)
154{
155 int x0 = rect_last.x + (rect_last.width - crop_size) / 2;
156 int y0 = rect_last.y + (rect_last.height - crop_size) / 2;
157
158 float x1 = cx - w / 2, y1 = cy - h / 2;
159 rect_last.x = cvFloor(value: x1 * crop_size + x0);
160 rect_last.y = cvFloor(value: y1 * crop_size + y0);
161 rect_last.width = cvFloor(value: w * crop_size);
162 rect_last.height = cvFloor(value: h * crop_size);
163}
164
165void TrackerVitImpl::init(InputArray image_, const Rect &boundingBox_)
166{
167 Mat image = image_.getMat();
168 Mat crop;
169 crop_image(src: image, dst&: crop, box: boundingBox_, factor: 2);
170 Mat blob;
171 preprocess(src: crop, dst&: blob, size: templateSize);
172 net.setInput(blob, name: "template");
173 Size size(16, 16);
174 hanningWindow = hann2d(size, centered: true);
175 rect_last = boundingBox_;
176}
177
178bool TrackerVitImpl::update(InputArray image_, Rect &boundingBoxRes)
179{
180 Mat image = image_.getMat();
181 Mat crop;
182 int crop_size = crop_image(src: image, dst&: crop, box: rect_last, factor: 4); // crop: [crop_size, crop_size]
183 Mat blob;
184 preprocess(src: crop, dst&: blob, size: searchSize);
185 net.setInput(blob, name: "search");
186 std::vector<String> outputName = {"output1", "output2", "output3"};
187 std::vector<Mat> outs;
188 net.forward(outputBlobs: outs, outBlobNames: outputName);
189 CV_Assert(outs.size() == 3);
190
191 Mat conf_map = outs[0].reshape(cn: 0, newshape: {16, 16});
192 Mat size_map = outs[1].reshape(cn: 0, newshape: {2, 16, 16});
193 Mat offset_map = outs[2].reshape(cn: 0, newshape: {2, 16, 16});
194
195 multiply(src1: conf_map, src2: hanningWindow, dst: conf_map);
196
197 double maxVal;
198 Point maxLoc;
199 minMaxLoc(src: conf_map, minVal: nullptr, maxVal: &maxVal, minLoc: nullptr, maxLoc: &maxLoc);
200 tracking_score = static_cast<float>(maxVal);
201
202 if (tracking_score >= tracking_score_threshold) {
203 float cx = (maxLoc.x + offset_map.at<float>(i0: 0, i1: maxLoc.y, i2: maxLoc.x)) / 16;
204 float cy = (maxLoc.y + offset_map.at<float>(i0: 1, i1: maxLoc.y, i2: maxLoc.x)) / 16;
205 float w = size_map.at<float>(i0: 0, i1: maxLoc.y, i2: maxLoc.x);
206 float h = size_map.at<float>(i0: 1, i1: maxLoc.y, i2: maxLoc.x);
207
208 updateLastRect(cx, cy, w, h, crop_size, rect_last);
209 boundingBoxRes = rect_last;
210 return true;
211 } else {
212 return false;
213 }
214}
215
216float TrackerVitImpl::getTrackingScore()
217{
218 return tracking_score;
219}
220
221Ptr<TrackerVit> TrackerVit::create(const TrackerVit::Params& parameters)
222{
223 return makePtr<TrackerVitImpl>(a1: parameters);
224}
225
226Ptr<TrackerVit> TrackerVit::create(const dnn::Net& model, Scalar meanvalue, Scalar stdvalue, float tracking_score_threshold)
227{
228 return makePtr<TrackerVitImpl>(a1: model, a1: meanvalue, a1: stdvalue, a1: tracking_score_threshold);
229}
230
#else // !HAVE_OPENCV_DNN
232Ptr<TrackerVit> TrackerVit::create(const TrackerVit::Params& parameters)
233{
234 CV_UNUSED(parameters);
235 CV_Error(Error::StsNotImplemented, "to use vittrack, the tracking module needs to be built with opencv_dnn !");
236}
#endif // HAVE_OPENCV_DNN
238}
239

// Source: opencv/modules/video/src/tracking/tracker_vit.cpp