/// A trait for describing vector operations used by vectorized searchers.
///
/// The trait is highly constrained to the low level vector operations needed.
/// In general, it was invented mostly to be generic over x86's __m128i and
/// __m256i types. At the time of writing, it also supports the wasm and
/// aarch64 128-bit vector types.
///
/// # Safety
///
/// All methods are `unsafe` since they are intended to be implemented using
/// vendor intrinsics, which are also `unsafe`. Callers must ensure that the
/// appropriate target features are enabled in the calling function, and that
/// the current CPU supports them. All implementations should avoid marking
/// the routines with `#[target_feature]` and instead mark them as
/// `#[inline(always)]` to ensure they get appropriately inlined.
/// (`inline(always)` cannot be used with `target_feature`.)
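///
/// # Example
///
/// A minimal sketch of how a caller might use these operations on x86-64.
/// The function below is hypothetical (it is not part of this module), and it
/// assumes that at least `BYTES` bytes are readable from `start` and that the
/// `Vector` and `MoveMask` traits are in scope:
///
/// ```ignore
/// use core::arch::x86_64::__m128i;
///
/// // SSE2 is always available on x86-64, so enabling it here is really just
/// // a formality. For __m256i, the caller would also need to verify at
/// // runtime that the current CPU supports AVX2.
/// #[target_feature(enable = "sse2")]
/// unsafe fn find_in_chunk(needle: u8, start: *const u8) -> Option<usize> {
///     let vneedle = __m128i::splat(needle);
///     let chunk = __m128i::load_unaligned(start);
///     let mask = chunk.cmpeq(vneedle).movemask();
///     if mask.has_non_zero() {
///         Some(mask.first_offset())
///     } else {
///         None
///     }
/// }
/// ```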
pub(crate) trait Vector: Copy + core::fmt::Debug {
    /// The number of bytes in the vector. That is, this is the size of the
    /// vector in memory.
    const BYTES: usize;
    /// The bits that must be zero in order for a `*const u8` pointer to be
    /// correctly aligned to read vector values.
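    ///
    /// For example, since every implementation below sets `ALIGN` to
    /// `BYTES - 1`, a caller might check or fix up alignment like this
    /// (a sketch, not code from this crate; `V` stands in for some
    /// implementation of this trait and `start` for a `*const u8`):
    ///
    /// ```ignore
    /// // The low bits of the address must be zero for an aligned load.
    /// let is_aligned = (start as usize) & V::ALIGN == 0;
    /// // Round the address up to the next vector-aligned boundary.
    /// let aligned = ((start as usize) + V::ALIGN) & !V::ALIGN;
    /// ```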
    const ALIGN: usize;

    /// The type of the value returned by `Vector::movemask`.
    ///
    /// This supports abstracting over the specific representation used in
    /// order to accommodate different representations in different ISAs.
    type Mask: MoveMask;

    /// Create a vector with 8-bit lanes with the given byte repeated into each
    /// lane.
    unsafe fn splat(byte: u8) -> Self;

    /// Read a vector-size number of bytes from the given pointer. The pointer
    /// must be aligned to the size of the vector.
    ///
    /// # Safety
    ///
    /// Callers must guarantee that at least `BYTES` bytes are readable from
    /// `data` and that `data` is aligned to a `BYTES` boundary.
    unsafe fn load_aligned(data: *const u8) -> Self;

    /// Read a vector-size number of bytes from the given pointer. The pointer
    /// does not need to be aligned.
    ///
    /// # Safety
    ///
    /// Callers must guarantee that at least `BYTES` bytes are readable from
    /// `data`.
    unsafe fn load_unaligned(data: *const u8) -> Self;

    /// _mm_movemask_epi8 or _mm256_movemask_epi8
    unsafe fn movemask(self) -> Self::Mask;
    /// _mm_cmpeq_epi8 or _mm256_cmpeq_epi8
    unsafe fn cmpeq(self, vector2: Self) -> Self;
    /// _mm_and_si128 or _mm256_and_si256
    unsafe fn and(self, vector2: Self) -> Self;
    /// _mm_or_si128 or _mm256_or_si256
    unsafe fn or(self, vector2: Self) -> Self;
    /// Returns true if and only if `Self::movemask` would return a mask that
    /// contains at least one non-zero bit.
    unsafe fn movemask_will_have_non_zero(self) -> bool {
        self.movemask().has_non_zero()
    }
}

/// A trait that abstracts over a vector-to-scalar operation called
/// "move mask."
///
/// On x86-64, this is `_mm_movemask_epi8` for SSE2 and `_mm256_movemask_epi8`
/// for AVX2. It takes a vector of `u8` lanes and returns a scalar where the
/// `i`th bit is set if and only if the most significant bit in the `i`th lane
/// of the vector is set. The simd128 ISA for wasm32 also supports this
/// exact same operation natively.
///
/// ... But aarch64 doesn't. So we have to fake it with more instructions and
/// a slightly different representation. We could do extra work to unify the
/// representations, but that would require additional costs in the hot path
/// for `memchr` and `packedpair`. So instead, we abstract over the specific
/// representation with this trait and define the operations we actually need.
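///
/// As an illustration of why these particular operations are the ones we
/// need, a searcher typically consumes a mask with a loop roughly like the
/// following sketch (`v`, `vneedle`, `at` and `report` are placeholders for
/// state owned by the surrounding search routine, not items in this module):
///
/// ```ignore
/// let mut mask = v.cmpeq(vneedle).movemask();
/// while mask.has_non_zero() {
///     // `first_offset` reports a lane index, regardless of how many bits
///     // the underlying representation devotes to each lane.
///     report(at + mask.first_offset());
///     mask = mask.clear_least_significant_bit();
/// }
/// ```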
pub(crate) trait MoveMask: Copy + core::fmt::Debug {
    /// Return a mask that is all zeros except for the least significant `n`
    /// lanes in a corresponding vector.
    fn all_zeros_except_least_significant(n: usize) -> Self;

    /// Returns true if and only if this mask has a non-zero bit anywhere.
    fn has_non_zero(self) -> bool;

    /// Returns the number of bits set to 1 in this mask.
    fn count_ones(self) -> usize;

    /// Does a bitwise `and` operation between `self` and `other`.
    fn and(self, other: Self) -> Self;

    /// Does a bitwise `or` operation between `self` and `other`.
    fn or(self, other: Self) -> Self;

    /// Returns a mask that is equivalent to `self` but with the least
    /// significant 1-bit set to 0.
    fn clear_least_significant_bit(self) -> Self;

    /// Returns the offset of the first non-zero lane this mask represents.
    fn first_offset(self) -> usize;

    /// Returns the offset of the last non-zero lane this mask represents.
    fn last_offset(self) -> usize;
}

/// This is a "sensible" movemask implementation where each bit represents
/// whether the most significant bit is set in each corresponding lane of a
/// vector. This is used on x86-64 and wasm, but such a mask is more expensive
/// to get on aarch64, so we use something a little different.
///
/// We call this "sensible" because this is what we get using native sse/avx
/// movemask instructions. But neon has no such native equivalent.
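///
/// For example, for a 16-lane vector in which only lanes 1 and 4 have their
/// most significant bit set, the mask looks like this:
///
/// ```text
/// bit:   ... 5 4 3 2 1 0
/// value: ... 0 1 0 0 1 0    first_offset() == 1, last_offset() == 4
/// ```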
#[derive(Clone, Copy, Debug)]
pub(crate) struct SensibleMoveMask(u32);

impl SensibleMoveMask {
    /// Get the mask in a form suitable for computing offsets.
    ///
    /// Basically, this normalizes to little endian. On big endian, this swaps
    /// the bytes.
    #[inline(always)]
    fn get_for_offset(self) -> u32 {
        #[cfg(target_endian = "big")]
        {
            self.0.swap_bytes()
        }
        #[cfg(target_endian = "little")]
        {
            self.0
        }
    }
}

impl MoveMask for SensibleMoveMask {
    #[inline(always)]
    fn all_zeros_except_least_significant(n: usize) -> SensibleMoveMask {
        debug_assert!(n < 32);
        SensibleMoveMask(!((1 << n) - 1))
    }

    #[inline(always)]
    fn has_non_zero(self) -> bool {
        self.0 != 0
    }

    #[inline(always)]
    fn count_ones(self) -> usize {
        self.0.count_ones() as usize
    }

    #[inline(always)]
    fn and(self, other: SensibleMoveMask) -> SensibleMoveMask {
        SensibleMoveMask(self.0 & other.0)
    }

    #[inline(always)]
    fn or(self, other: SensibleMoveMask) -> SensibleMoveMask {
        SensibleMoveMask(self.0 | other.0)
    }

    #[inline(always)]
    fn clear_least_significant_bit(self) -> SensibleMoveMask {
        SensibleMoveMask(self.0 & (self.0 - 1))
    }

    #[inline(always)]
    fn first_offset(self) -> usize {
        // We are dealing with little endian here (and if we aren't, we swap
        // the bytes so we are in practice), where the most significant byte
        // is at a higher address. That means the least significant bit that
        // is set corresponds to the position of our first matching byte.
        // That position corresponds to the number of zeros after the least
        // significant bit.
        self.get_for_offset().trailing_zeros() as usize
    }

    #[inline(always)]
    fn last_offset(self) -> usize {
        // We are dealing with little endian here (and if we aren't, we swap
        // the bytes so we are in practice), where the most significant byte is
        // at a higher address. That means the most significant bit that is set
        // corresponds to the position of our last matching byte. The position
        // from the end of the mask is therefore the number of leading zeros
        // in a 32 bit integer, and the position from the start of the mask is
        // therefore 32 - (leading zeros) - 1.
        32 - self.get_for_offset().leading_zeros() as usize - 1
    }
}

#[cfg(target_arch = "x86_64")]
mod x86sse2 {
    use core::arch::x86_64::*;

    use super::{SensibleMoveMask, Vector};

    impl Vector for __m128i {
        const BYTES: usize = 16;
        const ALIGN: usize = Self::BYTES - 1;

        type Mask = SensibleMoveMask;

        #[inline(always)]
        unsafe fn splat(byte: u8) -> __m128i {
            _mm_set1_epi8(byte as i8)
        }

        #[inline(always)]
        unsafe fn load_aligned(data: *const u8) -> __m128i {
            _mm_load_si128(data as *const __m128i)
        }

        #[inline(always)]
        unsafe fn load_unaligned(data: *const u8) -> __m128i {
            _mm_loadu_si128(data as *const __m128i)
        }

        #[inline(always)]
        unsafe fn movemask(self) -> SensibleMoveMask {
            SensibleMoveMask(_mm_movemask_epi8(self) as u32)
        }

        #[inline(always)]
        unsafe fn cmpeq(self, vector2: Self) -> __m128i {
            _mm_cmpeq_epi8(self, vector2)
        }

        #[inline(always)]
        unsafe fn and(self, vector2: Self) -> __m128i {
            _mm_and_si128(self, vector2)
        }

        #[inline(always)]
        unsafe fn or(self, vector2: Self) -> __m128i {
            _mm_or_si128(self, vector2)
        }
    }
}

#[cfg(target_arch = "x86_64")]
mod x86avx2 {
    use core::arch::x86_64::*;

    use super::{SensibleMoveMask, Vector};

    impl Vector for __m256i {
        const BYTES: usize = 32;
        const ALIGN: usize = Self::BYTES - 1;

        type Mask = SensibleMoveMask;

        #[inline(always)]
        unsafe fn splat(byte: u8) -> __m256i {
            _mm256_set1_epi8(byte as i8)
        }

        #[inline(always)]
        unsafe fn load_aligned(data: *const u8) -> __m256i {
            _mm256_load_si256(data as *const __m256i)
        }

        #[inline(always)]
        unsafe fn load_unaligned(data: *const u8) -> __m256i {
            _mm256_loadu_si256(data as *const __m256i)
        }

        #[inline(always)]
        unsafe fn movemask(self) -> SensibleMoveMask {
            SensibleMoveMask(_mm256_movemask_epi8(self) as u32)
        }

        #[inline(always)]
        unsafe fn cmpeq(self, vector2: Self) -> __m256i {
            _mm256_cmpeq_epi8(self, vector2)
        }

        #[inline(always)]
        unsafe fn and(self, vector2: Self) -> __m256i {
            _mm256_and_si256(self, vector2)
        }

        #[inline(always)]
        unsafe fn or(self, vector2: Self) -> __m256i {
            _mm256_or_si256(self, vector2)
        }
    }
}

#[cfg(target_arch = "aarch64")]
mod aarch64neon {
    use core::arch::aarch64::*;

    use super::{MoveMask, Vector};

    impl Vector for uint8x16_t {
        const BYTES: usize = 16;
        const ALIGN: usize = Self::BYTES - 1;

        type Mask = NeonMoveMask;

        #[inline(always)]
        unsafe fn splat(byte: u8) -> uint8x16_t {
            vdupq_n_u8(byte)
        }

        #[inline(always)]
        unsafe fn load_aligned(data: *const u8) -> uint8x16_t {
            // I've tried `data.cast::<uint8x16_t>().read()` instead, but
            // couldn't observe any benchmark differences.
            Self::load_unaligned(data)
        }

        #[inline(always)]
        unsafe fn load_unaligned(data: *const u8) -> uint8x16_t {
            vld1q_u8(data)
        }

        #[inline(always)]
        unsafe fn movemask(self) -> NeonMoveMask {
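            // A rough sketch of the trick below (see the `NeonMoveMask` docs
            // and the link there for details): reinterpret the 16 byte lanes
            // as 8 u16 lanes, shift each u16 right by 4 and narrow it back to
            // 8 bits. That packs 4 bits of each original byte lane into a
            // single 64-bit scalar. Masking with 0x888... then keeps exactly
            // one bit per lane, at bit position `4 * lane + 3`.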
            let asu16s = vreinterpretq_u16_u8(self);
            let mask = vshrn_n_u16(asu16s, 4);
            let asu64 = vreinterpret_u64_u8(mask);
            let scalar64 = vget_lane_u64(asu64, 0);
            NeonMoveMask(scalar64 & 0x8888888888888888)
        }

        #[inline(always)]
        unsafe fn cmpeq(self, vector2: Self) -> uint8x16_t {
            vceqq_u8(self, vector2)
        }

        #[inline(always)]
        unsafe fn and(self, vector2: Self) -> uint8x16_t {
            vandq_u8(self, vector2)
        }

        #[inline(always)]
        unsafe fn or(self, vector2: Self) -> uint8x16_t {
            vorrq_u8(self, vector2)
        }

        /// This is the only interesting implementation of this routine.
        /// Basically, instead of doing the "shift right narrow" dance, we use
        /// an adjacent folding max to determine whether there are any non-zero
        /// bytes in our mask. If there are, *then* we'll do the "shift right
        /// narrow" dance. In benchmarks, this does lead to slightly better
        /// throughput, but the win doesn't appear huge.
        #[inline(always)]
        unsafe fn movemask_will_have_non_zero(self) -> bool {
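            // `vpmaxq_u8` computes an adjacent ("folding") pairwise maximum:
            // because both operands are `self`, the low 64 bits of the result
            // hold the maximum of each adjacent pair of our 16 lanes. So if
            // any lane of `self` is non-zero, those low 64 bits are non-zero.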
            let low = vreinterpretq_u64_u8(vpmaxq_u8(self, self));
            vgetq_lane_u64(low, 0) != 0
        }
    }

    /// Neon doesn't have a `movemask` that works like the one in x86-64, so we
    /// wind up using a different method[1]. The different method also produces
    /// a mask, but 4 bits are set in the neon case instead of a single bit set
    /// in the x86-64 case. We do an extra step to zero out 3 of the 4 bits,
    /// but we still wind up with at least 3 zeroes between each set bit. This
    /// generally means that we need to do some division by 4 before extracting
    /// offsets.
    ///
    /// In fact, the existence of this type is the entire reason that we have
    /// the `MoveMask` trait in the first place. This basically lets us keep
    /// the different representations of masks without being forced to unify
    /// them into a single representation, which could result in extra and
    /// unnecessary work.
    ///
    /// [1]: https://community.arm.com/arm-community-blogs/b/infrastructure-solutions-blog/posts/porting-x86-vector-bitmask-optimizations-to-arm-neon
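    ///
    /// Concretely, lane `i` of the vector ends up as bit `4 * i + 3` of this
    /// mask. For example, if only lanes 1 and 4 match, then only bits 7 and
    /// 19 are set, and dividing the trailing/leading zero counts by 4 in
    /// `first_offset`/`last_offset` recovers the lane indices 1 and 4.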
    #[derive(Clone, Copy, Debug)]
    pub(crate) struct NeonMoveMask(u64);

    impl NeonMoveMask {
        /// Get the mask in a form suitable for computing offsets.
        ///
        /// Basically, this normalizes to little endian. On big endian, this
        /// swaps the bytes.
        #[inline(always)]
        fn get_for_offset(self) -> u64 {
            #[cfg(target_endian = "big")]
            {
                self.0.swap_bytes()
            }
            #[cfg(target_endian = "little")]
            {
                self.0
            }
        }
    }

    impl MoveMask for NeonMoveMask {
        #[inline(always)]
        fn all_zeros_except_least_significant(n: usize) -> NeonMoveMask {
            debug_assert!(n < 16);
            NeonMoveMask(!(((1 << n) << 2) - 1))
        }

        #[inline(always)]
        fn has_non_zero(self) -> bool {
            self.0 != 0
        }

        #[inline(always)]
        fn count_ones(self) -> usize {
            self.0.count_ones() as usize
        }

        #[inline(always)]
        fn and(self, other: NeonMoveMask) -> NeonMoveMask {
            NeonMoveMask(self.0 & other.0)
        }

        #[inline(always)]
        fn or(self, other: NeonMoveMask) -> NeonMoveMask {
            NeonMoveMask(self.0 | other.0)
        }

        #[inline(always)]
        fn clear_least_significant_bit(self) -> NeonMoveMask {
            NeonMoveMask(self.0 & (self.0 - 1))
        }

        #[inline(always)]
        fn first_offset(self) -> usize {
            // We are dealing with little endian here (and if we aren't,
            // we swap the bytes so we are in practice), where the most
            // significant byte is at a higher address. That means the least
            // significant bit that is set corresponds to the position of our
            // first matching byte. That position corresponds to the number of
            // zeros after the least significant bit.
            //
            // Note that unlike `SensibleMoveMask`, this mask has its bits
            // spread out over 64 bits instead of 16 bits (for a 128 bit
            // vector). Namely, whereas x86-64 will turn
            //
            // 0x00 0xFF 0x00 0x00 0xFF
            //
            // into 10010, our neon approach will turn it into
            //
            // 10000000000010000000
            //
            // And this happens because neon doesn't have a native `movemask`
            // instruction, so we kind of fake it[1]. Thus, we divide the
            // number of trailing zeros by 4 to get the "real" offset.
            //
            // [1]: https://community.arm.com/arm-community-blogs/b/infrastructure-solutions-blog/posts/porting-x86-vector-bitmask-optimizations-to-arm-neon
            (self.get_for_offset().trailing_zeros() >> 2) as usize
        }

        #[inline(always)]
        fn last_offset(self) -> usize {
            // See comment in `first_offset` above. This is basically the same,
            // but coming from the other direction.
            16 - (self.get_for_offset().leading_zeros() >> 2) as usize - 1
        }
    }
}

#[cfg(all(target_arch = "wasm32", target_feature = "simd128"))]
mod wasm_simd128 {
    use core::arch::wasm32::*;

    use super::{SensibleMoveMask, Vector};

    impl Vector for v128 {
        const BYTES: usize = 16;
        const ALIGN: usize = Self::BYTES - 1;

        type Mask = SensibleMoveMask;

        #[inline(always)]
        unsafe fn splat(byte: u8) -> v128 {
            u8x16_splat(byte)
        }

        #[inline(always)]
        unsafe fn load_aligned(data: *const u8) -> v128 {
            *data.cast()
        }

        #[inline(always)]
        unsafe fn load_unaligned(data: *const u8) -> v128 {
            v128_load(data.cast())
        }

        #[inline(always)]
        unsafe fn movemask(self) -> SensibleMoveMask {
            SensibleMoveMask(u8x16_bitmask(self).into())
        }

        #[inline(always)]
        unsafe fn cmpeq(self, vector2: Self) -> v128 {
            u8x16_eq(self, vector2)
        }

        #[inline(always)]
        unsafe fn and(self, vector2: Self) -> v128 {
            v128_and(self, vector2)
        }

        #[inline(always)]
        unsafe fn or(self, vector2: Self) -> v128 {
            v128_or(self, vector2)
        }
    }
}