#pragma once

#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <iterator>
#include <tuple>
#include <type_traits>
#include <utility>
#include <vector>

namespace kdbush {

/// Reads coordinate `I` (0 = x, 1 = y) of a point of type `T`.
///
/// The primary template uses `std::get<I>`, so it works for tuple-like points such as
/// `std::pair<double, double>`. `KDBush` reads coordinates only through this struct, so
/// other point types can be supported by specializing `nth` for them.
template <std::uint8_t I, typename T>
struct nth {
    static typename std::tuple_element<I, T>::type get(const T &t) {
        return std::get<I>(t);
    }
};

/// Static 2-D point index: an implicit k-d tree stored in two flat arrays.
///
/// `fill` copies the points and recursively partitions them around the median
/// coordinate, alternating between x and y, until a partition spans at most
/// `nodeSize + 1` points. Queries call a visitor with the *id* of each matching
/// point, where a point's id is its 0-based position in the sequence given to `fill`.
///
/// @tparam TPoint  point type; both coordinates are read through `nth` and must
///                 have the same type, exposed as `TNumber`.
/// @tparam TIndex  integer type used for ids and array positions; it must be able
///                 to represent the number of indexed points.
template <typename TPoint, typename TIndex = std::size_t>
class KDBush {

public:
    using TNumber = decltype(nth<0, TPoint>::get(std::declval<TPoint>()));
    static_assert(
        std::is_same<TNumber, decltype(nth<1, TPoint>::get(std::declval<TPoint>()))>::value,
        "point component types must be identical" );

    static const std::uint8_t defaultNodeSize = 64;

    /// Creates an empty index; populate it with a single call to `fill`.
    ///
    /// A `nodeSize_` of 0 is treated as 1: the tree arithmetic needs every range that
    /// gets split to hold at least three points so that `m - 1` / `m + 1` stay in range.
    KDBush(const std::uint8_t nodeSize_ = defaultNodeSize)
        : nodeSize(sanitizeNodeSize(nodeSize_)) {
    }

    /// Builds the index over every point of `points_` (ids are vector positions).
    KDBush(const std::vector<TPoint> &points_, const std::uint8_t nodeSize_ = defaultNodeSize)
        : KDBush(std::begin(points_), std::end(points_), nodeSize_) {
    }

    /// Builds the index over `[points_begin, points_end)` (ids are positions in the range).
    template <typename TPointIter>
    KDBush(const TPointIter &points_begin,
           const TPointIter &points_end,
           const std::uint8_t nodeSize_ = defaultNodeSize)
        : nodeSize(sanitizeNodeSize(nodeSize_)) {
        fill(points_begin, points_end);
    }

    /// Indexes every point of `points_`; see the iterator overload.
    void fill(const std::vector<TPoint> &points_) {
        fill(std::begin(points_), std::end(points_));
    }

    /// Copies `[points_begin, points_end)` into the index and sorts it into k-d order.
    ///
    /// Precondition (asserted): the index is still empty, i.e. `fill` is called once.
    /// The range is traversed twice (`std::distance`, then the copy loop), so the
    /// iterators must be at least forward iterators. An empty range yields an empty index.
    template <typename TPointIter>
    void fill(const TPointIter &points_begin, const TPointIter &points_end) {
        assert(points.empty());
        const TIndex size = static_cast<TIndex>(std::distance(points_begin, points_end));

        points.reserve(size);
        ids.reserve(size);

        TIndex i = 0;
        for (auto p = points_begin; p != points_end; ++p) {
            points.emplace_back(nth<0, TPoint>::get(*p), nth<1, TPoint>::get(*p));
            ids.push_back(i++);
        }

        // Nothing to sort, and `size - 1` would wrap around for an unsigned TIndex.
        if (size == 0) return;
        sortKD(0, size - 1, 0);
    }

    /// Calls `visitor(id)` for every point with minX <= x <= maxX and minY <= y <= maxY.
    ///
    /// Ids are reported in tree order, which generally differs from input order.
    /// Does nothing on an empty index.
    template <typename TVisitor>
    void range(const TNumber minX,
               const TNumber minY,
               const TNumber maxX,
               const TNumber maxY,
               const TVisitor &visitor) const {
        if (ids.empty()) return;
        range(minX, minY, maxX, maxY, visitor, 0, static_cast<TIndex>(ids.size() - 1), 0);
    }

    /// Calls `visitor(id)` for every point whose squared distance to (qx, qy) is <= r * r.
    ///
    /// Ids are reported in tree order. Does nothing on an empty index.
    template <typename TVisitor>
    void within(const TNumber qx, const TNumber qy, const TNumber r, const TVisitor &visitor) const {
        if (ids.empty()) return;
        within(qx, qy, r, visitor, 0, static_cast<TIndex>(ids.size() - 1), 0);
    }

private:
    std::vector<TIndex> ids;                          // ids[i]: input position of points[i]
    std::vector<std::pair<TNumber, TNumber>> points;  // (x, y) copies, reordered by sortKD
    std::uint8_t nodeSize;                            // ranges with right - left <= nodeSize stay unsorted

    /// Maps the user-supplied node size to a usable one (0 would underflow `m - 1`).
    static std::uint8_t sanitizeNodeSize(const std::uint8_t n) {
        return n == 0 ? std::uint8_t{1} : n;
    }

    /// Splitting axis of the next tree level (0 = x, 1 = y).
    static std::uint8_t nextAxis(const std::uint8_t axis) {
        return static_cast<std::uint8_t>(axis == 0 ? 1 : 0);
    }

    /// Midpoint of [left, right]; same value as (left + right) / 2 but cannot overflow.
    static TIndex middle(const TIndex left, const TIndex right) {
        return static_cast<TIndex>(left + (right - left) / 2);
    }

    /// Reports points of [left, right] inside the box. `axis` is the level's split axis.
    template <typename TVisitor>
    void range(const TNumber minX,
               const TNumber minY,
               const TNumber maxX,
               const TNumber maxY,
               const TVisitor &visitor,
               const TIndex left,
               const TIndex right,
               const std::uint8_t axis) const {

        if (right - left <= nodeSize) {
            // Leaf: sortKD left this slice unsorted, so scan it linearly.
            for (auto i = left; i <= right; i++) {
                const TNumber x = points[i].first;
                const TNumber y = points[i].second;
                if (x >= minX && x <= maxX && y >= minY && y <= maxY) visitor(ids[i]);
            }
            return;
        }

        const TIndex m = middle(left, right);
        const TNumber x = points[m].first;
        const TNumber y = points[m].second;

        if (x >= minX && x <= maxX && y >= minY && y <= maxY) visitor(ids[m]);

        // Entries before m are <= the median on `axis`, entries after it are >=, so a half
        // is visited only if the box reaches it. nodeSize >= 1 guarantees
        // left <= m - 1 and m + 1 <= right here.
        if (axis == 0 ? minX <= x : minY <= y)
            range(minX, minY, maxX, maxY, visitor, left, m - 1, nextAxis(axis));

        if (axis == 0 ? maxX >= x : maxY >= y)
            range(minX, minY, maxX, maxY, visitor, m + 1, right, nextAxis(axis));
    }

    /// Reports points of [left, right] within distance r of (qx, qy).
    template <typename TVisitor>
    void within(const TNumber qx,
                const TNumber qy,
                const TNumber r,
                const TVisitor &visitor,
                const TIndex left,
                const TIndex right,
                const std::uint8_t axis) const {

        const TNumber r2 = r * r;

        if (right - left <= nodeSize) {
            for (auto i = left; i <= right; i++) {
                if (sqDist(points[i].first, points[i].second, qx, qy) <= r2) visitor(ids[i]);
            }
            return;
        }

        const TIndex m = middle(left, right);
        const TNumber x = points[m].first;
        const TNumber y = points[m].second;

        if (sqDist(x, y, qx, qy) <= r2) visitor(ids[m]);

        // Prune with the circle's bounding interval on the split axis.
        if (axis == 0 ? qx - r <= x : qy - r <= y)
            within(qx, qy, r, visitor, left, m - 1, nextAxis(axis));

        if (axis == 0 ? qx + r >= x : qy + r >= y)
            within(qx, qy, r, visitor, m + 1, right, nextAxis(axis));
    }

    /// Recursively puts the median of [left, right] (on `axis`) at the midpoint, then
    /// recurses into both halves with the other axis.
    void sortKD(const TIndex left, const TIndex right, const std::uint8_t axis) {
        if (right - left <= nodeSize) return;
        const TIndex m = middle(left, right);
        if (axis == 0) {
            select<0>(m, left, right);
        } else {
            select<1>(m, left, right);
        }
        sortKD(left, m - 1, nextAxis(axis));
        sortKD(m + 1, right, nextAxis(axis));
    }

    /// Floyd-Rivest selection on coordinate I: rearranges [left, right] so that position
    /// k holds the value it would hold if the range were sorted, with entries before it
    /// <= and entries after it >= (ids move together with their points).
    template <std::uint8_t I>
    void select(const TIndex k, TIndex left, TIndex right) {

        while (right > left) {
            if (right - left > 600) {
                // Narrow to a window around k first so the pivot below is close to the
                // k-th value.
                const double n = right - left + 1;
                const double m = k - left + 1;
                const double z = std::log(n);
                const double s = 0.5 * std::exp(2 * z / 3);
                const double r =
                    k - m * s / n + 0.5 * std::sqrt(z * s * (1 - s / n)) * (2 * m < n ? -1 : 1);
                // Clamp in floating point: r may be negative, and converting a negative
                // double to an unsigned TIndex is undefined behaviour.
                const double windowLeft = std::max(static_cast<double>(left), std::floor(r));
                const double windowRight = std::min(static_cast<double>(right), std::floor(r + s));
                if (windowLeft < windowRight)
                    select<I>(k, static_cast<TIndex>(windowLeft), static_cast<TIndex>(windowRight));
            }

            const TNumber t = std::get<I>(points[k]);
            TIndex i = left;
            TIndex j = right;

            // Move the pivot to `left`, then order the ends so both scans below have a
            // stopping sentinel.
            swapItem(left, k);
            if (std::get<I>(points[right]) > t) swapItem(left, right);

            while (i < j) {
                swapItem(i, j);
                ++i;
                --j;
                while (std::get<I>(points[i]) < t) ++i;
                while (std::get<I>(points[j]) > t) --j;
            }

            if (std::get<I>(points[left]) == t) {
                swapItem(left, j);
            } else {
                ++j;
                swapItem(j, right);
            }

            // Position j now holds its final value. Stop once it is k; otherwise keep
            // partitioning the side that contains k. (j > k >= 0 makes j - 1 safe.)
            if (j == k) return;
            if (j < k) {
                left = j + 1;
            } else {
                right = j - 1;
            }
        }
    }

    /// Swaps entries i and j in both parallel arrays.
    void swapItem(const TIndex i, const TIndex j) {
        std::iter_swap(ids.begin() + i, ids.begin() + j);
        std::iter_swap(points.begin() + i, points.begin() + j);
    }

    /// Squared Euclidean distance between (ax, ay) and (bx, by).
    static TNumber sqDist(const TNumber ax, const TNumber ay, const TNumber bx, const TNumber by) {
        const auto dx = ax - bx;
        const auto dy = ay - by;
        return static_cast<TNumber>(dx * dx + dy * dy);
    }
};

} // namespace kdbush
211 | |