#pragma once

#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <iterator>
#include <tuple>
#include <type_traits>
#include <utility>
#include <vector>

namespace kdbush {

/// Customization point that reads coordinate I (0 = x, 1 = y) of a point of type T.
///
/// The primary template handles tuple-like types supported by std::tuple_element and
/// std::get (for example std::pair or std::tuple). Other point types can be used by
/// specializing nth<0, T> and nth<1, T> with a static get(const T&).
template <std::uint8_t I, typename T>
struct nth {
    inline static typename std::tuple_element<I, T>::type get(const T &t) {
        return std::get<I>(t);
    }
};

/// Static 2D point index: points are copied in once and reordered into a k-d tree
/// layout (alternating x/y splits) stored in flat arrays, then queried by box or radius.
///
/// Query results are reported by calling `visitor(id)`, where `id` is the zero-based
/// position of the matching point in the sequence passed to fill(). Results are not
/// reported in any particular order.
///
/// @tparam TPoint point type; coordinates are read through nth<0, TPoint> and
///                nth<1, TPoint>, which must return the same type (TNumber).
/// @tparam TIndex integer type used for ids and array positions.
template <typename TPoint, typename TIndex = std::size_t>
class KDBush {

public:
    using TNumber = decltype(nth<0, TPoint>::get(std::declval<TPoint>()));
    static_assert(
        std::is_same<TNumber, decltype(nth<1, TPoint>::get(std::declval<TPoint>()))>::value,
        "point component types must be identical" );

    /// Default leaf threshold: ranges with right - left <= nodeSize are not split further.
    static const std::uint8_t defaultNodeSize = 64;

    /// Creates an empty index; populate it with fill().
    /// @param nodeSize_ leaf threshold. 0 is treated as 1: with a zero threshold an internal
    ///                  node could split at its first element, and the child range
    ///                  [left, m - 1] would end before it starts (wrapping for unsigned TIndex).
    KDBush(const std::uint8_t nodeSize_ = defaultNodeSize)
        : nodeSize(nodeSize_ == 0 ? std::uint8_t(1) : nodeSize_) {
    }

    /// Builds an index over all points in `points_`.
    KDBush(const std::vector<TPoint> &points_, const std::uint8_t nodeSize_ = defaultNodeSize)
        : KDBush(std::begin(points_), std::end(points_), nodeSize_) {
    }

    /// Builds an index over the points in [points_begin, points_end).
    /// A nodeSize_ of 0 is treated as 1 (see the single-argument constructor).
    template <typename TPointIter>
    KDBush(const TPointIter &points_begin,
           const TPointIter &points_end,
           const std::uint8_t nodeSize_ = defaultNodeSize)
        : nodeSize(nodeSize_ == 0 ? std::uint8_t(1) : nodeSize_) {
        fill(points_begin, points_end);
    }

    /// Copies all points of `points_` into the index and sorts them into k-d order.
    void fill(const std::vector<TPoint> &points_) {
        fill(std::begin(points_), std::end(points_));
    }

    /// Copies the points in [points_begin, points_end) into the index and sorts them into
    /// k-d order. Expects the index to be empty (checked by assert).
    template <typename TPointIter>
    void fill(const TPointIter &points_begin, const TPointIter &points_end) {
        assert(points.empty());
        const TIndex size = static_cast<TIndex>(std::distance(points_begin, points_end));

        points.reserve(size);
        ids.reserve(size);

        TIndex i = 0;
        for (auto p = points_begin; p != points_end; ++p) {
            points.emplace_back(nth<0, TPoint>::get(*p), nth<1, TPoint>::get(*p));
            ids.push_back(i++);
        }

        // `size - 1` would wrap around for an unsigned TIndex, so an empty
        // input must not reach sortKD.
        if (size == 0) return;
        sortKD(0, size - 1, 0);
    }

    /// Calls `visitor(id)` for every point with minX <= x <= maxX and minY <= y <= maxY.
    template <typename TVisitor>
    void range(const TNumber minX,
               const TNumber minY,
               const TNumber maxX,
               const TNumber maxY,
               const TVisitor &visitor) const {
        if (ids.empty()) return;  // avoid `ids.size() - 1` wrapping around
        range(minX, minY, maxX, maxY, visitor, 0, static_cast<TIndex>(ids.size() - 1), 0);
    }

    /// Calls `visitor(id)` for every point whose squared distance to (qx, qy) is <= r * r.
    template <typename TVisitor>
    void within(const TNumber qx, const TNumber qy, const TNumber r, const TVisitor &visitor) const {
        if (ids.empty()) return;  // avoid `ids.size() - 1` wrapping around
        within(qx, qy, r, visitor, 0, static_cast<TIndex>(ids.size() - 1), 0);
    }

private:
    std::vector<TIndex> ids;                          // ids[i]: input position of points[i]
    std::vector<std::pair<TNumber, TNumber>> points;  // coordinates, reordered by sortKD
    std::uint8_t nodeSize;                            // leaf threshold, always >= 1

    /// Split position of [left, right]. sortKD and the queries must agree on it because it
    /// defines the tree shape; written as left + (right - left) / 2 so it cannot overflow.
    static TIndex middle(const TIndex left, const TIndex right) {
        return static_cast<TIndex>(left + (right - left) / 2);
    }

    /// Box-query step over points[left..right], whose split coordinate is x when axis == 0
    /// and y otherwise. The split point sits at middle(left, right); elements before it
    /// have split coordinate <= its value and elements after it >= its value.
    template <typename TVisitor>
    void range(const TNumber minX,
               const TNumber minY,
               const TNumber maxX,
               const TNumber maxY,
               const TVisitor &visitor,
               const TIndex left,
               const TIndex right,
               const std::uint8_t axis) const {

        if (right - left <= nodeSize) {
            // Leaf: scan every point linearly.
            for (auto i = left; i <= right; i++) {
                const TNumber x = std::get<0>(points[i]);
                const TNumber y = std::get<1>(points[i]);
                if (x >= minX && x <= maxX && y >= minY && y <= maxY) visitor(ids[i]);
            }
            return;
        }

        const TIndex m = middle(left, right);
        const TNumber x = std::get<0>(points[m]);
        const TNumber y = std::get<1>(points[m]);

        if (x >= minX && x <= maxX && y >= minY && y <= maxY) visitor(ids[m]);

        // Descend only into halves that can intersect the box on the split axis.
        if (axis == 0 ? minX <= x : minY <= y)
            range(minX, minY, maxX, maxY, visitor, left, m - 1, (axis + 1) % 2);

        if (axis == 0 ? maxX >= x : maxY >= y)
            range(minX, minY, maxX, maxY, visitor, m + 1, right, (axis + 1) % 2);
    }

    /// Radius-query step over points[left..right]; same tree layout as the box query.
    template <typename TVisitor>
    void within(const TNumber qx,
                const TNumber qy,
                const TNumber r,
                const TVisitor &visitor,
                const TIndex left,
                const TIndex right,
                const std::uint8_t axis) const {

        const TNumber r2 = r * r;

        if (right - left <= nodeSize) {
            // Leaf: scan every point linearly.
            for (auto i = left; i <= right; i++) {
                const TNumber x = std::get<0>(points[i]);
                const TNumber y = std::get<1>(points[i]);
                if (sqDist(x, y, qx, qy) <= r2) visitor(ids[i]);
            }
            return;
        }

        const TIndex m = middle(left, right);
        const TNumber x = std::get<0>(points[m]);
        const TNumber y = std::get<1>(points[m]);

        if (sqDist(x, y, qx, qy) <= r2) visitor(ids[m]);

        // Descend only into halves that overlap the circle's bounding box on the split axis.
        if (axis == 0 ? qx - r <= x : qy - r <= y)
            within(qx, qy, r, visitor, left, m - 1, (axis + 1) % 2);

        if (axis == 0 ? qx + r >= x : qy + r >= y)
            within(qx, qy, r, visitor, m + 1, right, (axis + 1) % 2);
    }

    /// Recursively arranges points[left..right] (and ids in lockstep) into k-d order,
    /// splitting on x when axis == 0 and on y otherwise, alternating at each level.
    void sortKD(const TIndex left, const TIndex right, const std::uint8_t axis) {
        if (right - left <= nodeSize) return;
        const TIndex m = middle(left, right);
        if (axis == 0) {
            select<0>(m, left, right);
        } else {
            select<1>(m, left, right);
        }
        sortKD(left, m - 1, (axis + 1) % 2);
        sortKD(m + 1, right, (axis + 1) % 2);
    }

    /// Floyd-Rivest selection on coordinate I: reorders points[left..right] (and ids in
    /// lockstep) so that position k holds the element a full sort would place there, with
    /// no greater coordinate before it and no smaller one after it. Requires left <= k <= right.
    template <std::uint8_t I>
    void select(const TIndex k, TIndex left, TIndex right) {

        while (right > left) {
            if (right - left > 600) {
                // On large ranges, first select k inside a sample window around its expected
                // position, so the pivot taken below is already close to the answer.
                const double n = right - left + 1;
                const double m = k - left + 1;
                const double z = std::log(n);
                const double s = 0.5 * std::exp(2 * z / 3);
                const double r =
                    k - m * s / n + 0.5 * std::sqrt(z * s * (1 - s / n)) * (2 * m < n ? -1 : 1);
                // Clamp in floating point before converting: casting a negative or
                // out-of-range double to TIndex is undefined behavior.
                const double windowLeft = std::max(static_cast<double>(left), std::floor(r));
                const double windowRight = std::min(static_cast<double>(right), std::floor(r + s));
                select<I>(k, static_cast<TIndex>(windowLeft), static_cast<TIndex>(windowRight));
            }

            const TNumber t = std::get<I>(points[k]);
            TIndex i = left;
            TIndex j = right;

            // Put the pivot value at one end and a value >= it at the other so the inner
            // scans below are bounded without explicit index checks.
            swapItem(left, k);
            if (std::get<I>(points[right]) > t) swapItem(left, right);

            while (i < j) {
                swapItem(i++, j--);
                while (std::get<I>(points[i]) < t) i++;
                while (std::get<I>(points[j]) > t) j--;
            }

            if (std::get<I>(points[left]) == t) {
                swapItem(left, j);
            } else {
                swapItem(++j, right);
            }

            // Continue in the side that contains k. Handling j == k explicitly avoids
            // computing j - 1 when j == 0, which would wrap around for an unsigned TIndex.
            if (j < k) {
                left = j + 1;
            } else if (k < j) {
                right = j - 1;
            } else {
                break;
            }
        }
    }

    /// Swaps entries i and j of both parallel arrays.
    void swapItem(const TIndex i, const TIndex j) {
        std::iter_swap(ids.begin() + i, ids.begin() + j);
        std::iter_swap(points.begin() + i, points.begin() + j);
    }

    /// Squared Euclidean distance between (ax, ay) and (bx, by).
    static TNumber sqDist(const TNumber ax, const TNumber ay, const TNumber bx, const TNumber by) {
        // Plain multiplication rather than the general-purpose std::pow for a square.
        const auto dx = ax - bx;
        const auto dy = ay - by;
        return static_cast<TNumber>(dx * dx + dy * dy);
    }
};

} // namespace kdbush