| 1 | // This file is part of OpenCV project. |
| 2 | // It is subject to the license terms in the LICENSE file found in the top-level directory |
| 3 | // of this distribution and at http://opencv.org/license.html. |
| 4 | |
| 5 | // Author, PengyuLiu, 1872918507@qq.com |
| 6 | |
| 7 | #include "../precomp.hpp" |
| 8 | #ifdef HAVE_OPENCV_DNN |
| 9 | #include "opencv2/dnn.hpp" |
| 10 | #endif |
| 11 | |
| 12 | namespace cv { |
| 13 | |
| 14 | TrackerVit::TrackerVit() |
| 15 | { |
| 16 | // nothing |
| 17 | } |
| 18 | |
| 19 | TrackerVit::~TrackerVit() |
| 20 | { |
| 21 | // nothing |
| 22 | } |
| 23 | |
| 24 | TrackerVit::Params::Params() |
| 25 | { |
| 26 | net = "vitTracker.onnx" ; |
| 27 | meanvalue = Scalar{0.485, 0.456, 0.406}; // normalized mean (already divided by 255) |
| 28 | stdvalue = Scalar{0.229, 0.224, 0.225}; // normalized std (already divided by 255) |
| 29 | #ifdef HAVE_OPENCV_DNN |
| 30 | backend = dnn::DNN_BACKEND_DEFAULT; |
| 31 | target = dnn::DNN_TARGET_CPU; |
| 32 | #else |
| 33 | backend = -1; // invalid value |
| 34 | target = -1; // invalid value |
| 35 | #endif |
| 36 | tracking_score_threshold = 0.20f; // safe threshold to filter out black frames |
| 37 | } |
| 38 | |
| 39 | #ifdef HAVE_OPENCV_DNN |
| 40 | |
| 41 | class TrackerVitImpl : public TrackerVit |
| 42 | { |
| 43 | public: |
| 44 | TrackerVitImpl(const TrackerVit::Params& parameters) |
| 45 | { |
| 46 | net = dnn::readNet(model: parameters.net); |
| 47 | CV_Assert(!net.empty()); |
| 48 | |
| 49 | net.setPreferableBackend(parameters.backend); |
| 50 | net.setPreferableTarget(parameters.target); |
| 51 | |
| 52 | i2bp.mean = parameters.meanvalue * 255.0; |
| 53 | i2bp.scalefactor = (1.0 / parameters.stdvalue) * (1 / 255.0); |
| 54 | tracking_score_threshold = parameters.tracking_score_threshold; |
| 55 | } |
| 56 | |
| 57 | TrackerVitImpl(const dnn::Net& model, Scalar meanvalue, Scalar stdvalue, float _tracking_score_threshold) |
| 58 | { |
| 59 | CV_Assert(!model.empty()); |
| 60 | |
| 61 | net = model; |
| 62 | i2bp.mean = meanvalue * 255.0; |
| 63 | i2bp.scalefactor = (1.0 / stdvalue) * (1 / 255.0); |
| 64 | tracking_score_threshold = _tracking_score_threshold; |
| 65 | } |
| 66 | |
| 67 | void init(InputArray image, const Rect& boundingBox) CV_OVERRIDE; |
| 68 | bool update(InputArray image, Rect& boundingBox) CV_OVERRIDE; |
| 69 | float getTrackingScore() CV_OVERRIDE; |
| 70 | |
| 71 | Rect rect_last; |
| 72 | float tracking_score; |
| 73 | |
| 74 | float tracking_score_threshold; |
| 75 | dnn::Image2BlobParams i2bp; |
| 76 | |
| 77 | |
| 78 | protected: |
| 79 | void preprocess(const Mat& src, Mat& dst, Size size); |
| 80 | |
| 81 | const Size searchSize{256, 256}; |
| 82 | const Size templateSize{128, 128}; |
| 83 | |
| 84 | Mat hanningWindow; |
| 85 | |
| 86 | dnn::Net net; |
| 87 | }; |
| 88 | |
| 89 | static int crop_image(const Mat& src, Mat& dst, Rect box, int factor) |
| 90 | { |
| 91 | int x = box.x, y = box.y, w = box.width, h = box.height; |
| 92 | int crop_sz = cvCeil(value: sqrt(x: w * h) * factor); |
| 93 | |
| 94 | int x1 = x + (w - crop_sz) / 2; |
| 95 | int x2 = x1 + crop_sz; |
| 96 | int y1 = y + (h - crop_sz) / 2; |
| 97 | int y2 = y1 + crop_sz; |
| 98 | |
| 99 | int x1_pad = std::max(a: 0, b: -x1); |
| 100 | int y1_pad = std::max(a: 0, b: -y1); |
| 101 | int x2_pad = std::max(a: x2 - src.size[1] + 1, b: 0); |
| 102 | int y2_pad = std::max(a: y2 - src.size[0] + 1, b: 0); |
| 103 | |
| 104 | Rect roi(x1 + x1_pad, y1 + y1_pad, x2 - x2_pad - x1 - x1_pad, y2 - y2_pad - y1 - y1_pad); |
| 105 | Mat im_crop = src(roi); |
| 106 | copyMakeBorder(src: im_crop, dst, top: y1_pad, bottom: y2_pad, left: x1_pad, right: x2_pad, borderType: BORDER_CONSTANT); |
| 107 | |
| 108 | return crop_sz; |
| 109 | } |
| 110 | |
| 111 | void TrackerVitImpl::preprocess(const Mat& src, Mat& dst, Size size) |
| 112 | { |
| 113 | Mat img; |
| 114 | resize(src, dst: img, dsize: size); |
| 115 | |
| 116 | dst = dnn::blobFromImageWithParams(image: img, param: i2bp); |
| 117 | } |
| 118 | |
| 119 | static Mat hann1d(int sz, bool centered = true) { |
| 120 | Mat hanningWindow(sz, 1, CV_32FC1); |
| 121 | float* data = hanningWindow.ptr<float>(y: 0); |
| 122 | |
| 123 | if(centered) { |
| 124 | for(int i = 0; i < sz; i++) { |
| 125 | float val = 0.5f * (1.f - std::cos(x: static_cast<float>(2 * M_PI / (sz + 1)) * (i + 1))); |
| 126 | data[i] = val; |
| 127 | } |
| 128 | } |
| 129 | else { |
| 130 | int half_sz = sz / 2; |
| 131 | for(int i = 0; i <= half_sz; i++) { |
| 132 | float val = 0.5f * (1.f + std::cos(x: static_cast<float>(2 * M_PI / (sz + 2)) * i)); |
| 133 | data[i] = val; |
| 134 | data[sz - 1 - i] = val; |
| 135 | } |
| 136 | } |
| 137 | |
| 138 | return hanningWindow; |
| 139 | } |
| 140 | |
| 141 | static Mat hann2d(Size size, bool centered = true) { |
| 142 | int rows = size.height; |
| 143 | int cols = size.width; |
| 144 | |
| 145 | Mat hanningWindowRows = hann1d(sz: rows, centered); |
| 146 | Mat hanningWindowCols = hann1d(sz: cols, centered); |
| 147 | |
| 148 | Mat hanningWindow = hanningWindowRows * hanningWindowCols.t(); |
| 149 | |
| 150 | return hanningWindow; |
| 151 | } |
| 152 | |
| 153 | static void updateLastRect(float cx, float cy, float w, float h, int crop_size, Rect &rect_last) |
| 154 | { |
| 155 | int x0 = rect_last.x + (rect_last.width - crop_size) / 2; |
| 156 | int y0 = rect_last.y + (rect_last.height - crop_size) / 2; |
| 157 | |
| 158 | float x1 = cx - w / 2, y1 = cy - h / 2; |
| 159 | rect_last.x = cvFloor(value: x1 * crop_size + x0); |
| 160 | rect_last.y = cvFloor(value: y1 * crop_size + y0); |
| 161 | rect_last.width = cvFloor(value: w * crop_size); |
| 162 | rect_last.height = cvFloor(value: h * crop_size); |
| 163 | } |
| 164 | |
| 165 | void TrackerVitImpl::init(InputArray image_, const Rect &boundingBox_) |
| 166 | { |
| 167 | Mat image = image_.getMat(); |
| 168 | Mat crop; |
| 169 | crop_image(src: image, dst&: crop, box: boundingBox_, factor: 2); |
| 170 | Mat blob; |
| 171 | preprocess(src: crop, dst&: blob, size: templateSize); |
| 172 | net.setInput(blob, name: "template" ); |
| 173 | Size size(16, 16); |
| 174 | hanningWindow = hann2d(size, centered: true); |
| 175 | rect_last = boundingBox_; |
| 176 | } |
| 177 | |
| 178 | bool TrackerVitImpl::update(InputArray image_, Rect &boundingBoxRes) |
| 179 | { |
| 180 | Mat image = image_.getMat(); |
| 181 | Mat crop; |
| 182 | int crop_size = crop_image(src: image, dst&: crop, box: rect_last, factor: 4); // crop: [crop_size, crop_size] |
| 183 | Mat blob; |
| 184 | preprocess(src: crop, dst&: blob, size: searchSize); |
| 185 | net.setInput(blob, name: "search" ); |
| 186 | std::vector<String> outputName = {"output1" , "output2" , "output3" }; |
| 187 | std::vector<Mat> outs; |
| 188 | net.forward(outputBlobs: outs, outBlobNames: outputName); |
| 189 | CV_Assert(outs.size() == 3); |
| 190 | |
| 191 | Mat conf_map = outs[0].reshape(cn: 0, newshape: {16, 16}); |
| 192 | Mat size_map = outs[1].reshape(cn: 0, newshape: {2, 16, 16}); |
| 193 | Mat offset_map = outs[2].reshape(cn: 0, newshape: {2, 16, 16}); |
| 194 | |
| 195 | multiply(src1: conf_map, src2: hanningWindow, dst: conf_map); |
| 196 | |
| 197 | double maxVal; |
| 198 | Point maxLoc; |
| 199 | minMaxLoc(src: conf_map, minVal: nullptr, maxVal: &maxVal, minLoc: nullptr, maxLoc: &maxLoc); |
| 200 | tracking_score = static_cast<float>(maxVal); |
| 201 | |
| 202 | if (tracking_score >= tracking_score_threshold) { |
| 203 | float cx = (maxLoc.x + offset_map.at<float>(i0: 0, i1: maxLoc.y, i2: maxLoc.x)) / 16; |
| 204 | float cy = (maxLoc.y + offset_map.at<float>(i0: 1, i1: maxLoc.y, i2: maxLoc.x)) / 16; |
| 205 | float w = size_map.at<float>(i0: 0, i1: maxLoc.y, i2: maxLoc.x); |
| 206 | float h = size_map.at<float>(i0: 1, i1: maxLoc.y, i2: maxLoc.x); |
| 207 | |
| 208 | updateLastRect(cx, cy, w, h, crop_size, rect_last); |
| 209 | boundingBoxRes = rect_last; |
| 210 | return true; |
| 211 | } else { |
| 212 | return false; |
| 213 | } |
| 214 | } |
| 215 | |
| 216 | float TrackerVitImpl::getTrackingScore() |
| 217 | { |
| 218 | return tracking_score; |
| 219 | } |
| 220 | |
| 221 | Ptr<TrackerVit> TrackerVit::create(const TrackerVit::Params& parameters) |
| 222 | { |
| 223 | return makePtr<TrackerVitImpl>(a1: parameters); |
| 224 | } |
| 225 | |
| 226 | Ptr<TrackerVit> TrackerVit::create(const dnn::Net& model, Scalar meanvalue, Scalar stdvalue, float tracking_score_threshold) |
| 227 | { |
| 228 | return makePtr<TrackerVitImpl>(a1: model, a1: meanvalue, a1: stdvalue, a1: tracking_score_threshold); |
| 229 | } |
| 230 | |
| 231 | #else // OPENCV_HAVE_DNN |
| 232 | Ptr<TrackerVit> TrackerVit::create(const TrackerVit::Params& parameters) |
| 233 | { |
| 234 | CV_UNUSED(parameters); |
| 235 | CV_Error(Error::StsNotImplemented, "to use vittrack, the tracking module needs to be built with opencv_dnn !" ); |
| 236 | } |
| 237 | #endif // OPENCV_HAVE_DNN |
| 238 | } |
| 239 | |