1 | /// A trait for describing vector operations used by vectorized searchers. |
2 | /// |
3 | /// The trait is highly constrained to low level vector operations needed. |
4 | /// In general, it was invented mostly to be generic over x86's __m128i and |
5 | /// __m256i types. At time of writing, it also supports wasm and aarch64 |
6 | /// 128-bit vector types as well. |
7 | /// |
8 | /// # Safety |
9 | /// |
10 | /// All methods are not safe since they are intended to be implemented using |
11 | /// vendor intrinsics, which are also not safe. Callers must ensure that the |
12 | /// appropriate target features are enabled in the calling function, and that |
13 | /// the current CPU supports them. All implementations should avoid marking the |
14 | /// routines with #[target_feature] and instead mark them as #[inline(always)] |
15 | /// to ensure they get appropriately inlined. (inline(always) cannot be used |
16 | /// with target_feature.) |
17 | pub(crate) trait Vector: Copy + core::fmt::Debug { |
18 | /// The number of bits in the vector. |
19 | const BITS: usize; |
20 | /// The number of bytes in the vector. That is, this is the size of the |
21 | /// vector in memory. |
22 | const BYTES: usize; |
23 | /// The bits that must be zero in order for a `*const u8` pointer to be |
24 | /// correctly aligned to read vector values. |
25 | const ALIGN: usize; |
26 | |
27 | /// The type of the value returned by `Vector::movemask`. |
28 | /// |
29 | /// This supports abstracting over the specific representation used in |
30 | /// order to accommodate different representations in different ISAs. |
31 | type Mask: MoveMask; |
32 | |
33 | /// Create a vector with 8-bit lanes with the given byte repeated into each |
34 | /// lane. |
35 | unsafe fn splat(byte: u8) -> Self; |
36 | |
37 | /// Read a vector-size number of bytes from the given pointer. The pointer |
38 | /// must be aligned to the size of the vector. |
39 | /// |
40 | /// # Safety |
41 | /// |
42 | /// Callers must guarantee that at least `BYTES` bytes are readable from |
43 | /// `data` and that `data` is aligned to a `BYTES` boundary. |
44 | unsafe fn load_aligned(data: *const u8) -> Self; |
45 | |
46 | /// Read a vector-size number of bytes from the given pointer. The pointer |
47 | /// does not need to be aligned. |
48 | /// |
49 | /// # Safety |
50 | /// |
51 | /// Callers must guarantee that at least `BYTES` bytes are readable from |
52 | /// `data`. |
53 | unsafe fn load_unaligned(data: *const u8) -> Self; |
54 | |
55 | /// _mm_movemask_epi8 or _mm256_movemask_epi8 |
56 | unsafe fn movemask(self) -> Self::Mask; |
57 | /// _mm_cmpeq_epi8 or _mm256_cmpeq_epi8 |
58 | unsafe fn cmpeq(self, vector2: Self) -> Self; |
59 | /// _mm_and_si128 or _mm256_and_si256 |
60 | unsafe fn and(self, vector2: Self) -> Self; |
61 | /// _mm_or or _mm256_or_si256 |
62 | unsafe fn or(self, vector2: Self) -> Self; |
63 | /// Returns true if and only if `Self::movemask` would return a mask that |
64 | /// contains at least one non-zero bit. |
65 | unsafe fn movemask_will_have_non_zero(self) -> bool { |
66 | self.movemask().has_non_zero() |
67 | } |
68 | } |
69 | |
/// A trait that abstracts over a vector-to-scalar operation called
/// "move mask."
///
/// On x86-64, this is `_mm_movemask_epi8` for SSE2 and `_mm256_movemask_epi8`
/// for AVX2. It takes a vector of `u8` lanes and returns a scalar where the
/// `i`th bit is set if and only if the most significant bit in the `i`th lane
/// of the vector is set. The simd128 ISA for wasm32 also supports this exact
/// same operation natively.
///
/// ... But aarch64 doesn't. So we have to fake it with more instructions and
/// a slightly different representation. We could do extra work to unify the
/// representations, but that would add costs to the hot paths of `memchr`
/// and `packedpair`. So instead, we abstract over the specific representation
/// with this trait and define only the operations we actually need.
pub(crate) trait MoveMask: Copy + core::fmt::Debug {
    /// Returns a mask in which the `n` least significant lanes are cleared
    /// and every other lane is set.
    ///
    /// Note that, despite the name, it is the low `n` lanes that are zero.
    /// ANDing the result with another mask therefore discards whatever that
    /// mask had set in its `n` least significant lanes. Implementations
    /// debug-assert that `n` is smaller than the number of lanes their
    /// representation can hold.
    fn all_zeros_except_least_significant(n: usize) -> Self;

    /// Returns true if and only if this mask has a non-zero bit anywhere.
    fn has_non_zero(self) -> bool;

    /// Returns the number of bits set to 1 in this mask.
    fn count_ones(self) -> usize;

    /// Does a bitwise `and` operation between `self` and `other`.
    fn and(self, other: Self) -> Self;

    /// Does a bitwise `or` operation between `self` and `other`.
    fn or(self, other: Self) -> Self;

    /// Returns a mask that is equivalent to `self` but with the least
    /// significant 1-bit set to 0.
    fn clear_least_significant_bit(self) -> Self;

    /// Returns the offset of the first non-zero lane this mask represents.
    ///
    /// Only meaningful for a non-zero mask.
    fn first_offset(self) -> usize;

    /// Returns the offset of the last non-zero lane this mask represents.
    ///
    /// Only meaningful for a non-zero mask.
    fn last_offset(self) -> usize;
}
111 | |
/// A move mask with one bit per vector lane: bit `i` is set exactly when the
/// most significant bit of lane `i` is set.
///
/// This is the layout that the native SSE2/AVX2 movemask instructions and
/// wasm's simd128 bitmask produce, which is why it is used on x86-64 and
/// wasm32. We call it "sensible" for that reason. Neon has no native
/// equivalent, so obtaining this layout on aarch64 is more expensive, and a
/// different representation is used there instead.
#[derive(Debug, Clone, Copy)]
pub(crate) struct SensibleMoveMask(u32);

impl SensibleMoveMask {
    /// Returns the raw mask in the byte order used for computing offsets.
    ///
    /// On little-endian targets the mask is returned unchanged. On big-endian
    /// targets its bytes are swapped.
    #[inline(always)]
    fn get_for_offset(self) -> u32 {
        // `from_le` is the identity on little-endian targets and a byte swap
        // on big-endian ones, which is exactly the normalization needed.
        u32::from_le(self.0)
    }
}
139 | |
140 | impl MoveMask for SensibleMoveMask { |
141 | #[inline (always)] |
142 | fn all_zeros_except_least_significant(n: usize) -> SensibleMoveMask { |
143 | debug_assert!(n < 32); |
144 | SensibleMoveMask(!((1 << n) - 1)) |
145 | } |
146 | |
147 | #[inline (always)] |
148 | fn has_non_zero(self) -> bool { |
149 | self.0 != 0 |
150 | } |
151 | |
152 | #[inline (always)] |
153 | fn count_ones(self) -> usize { |
154 | self.0.count_ones() as usize |
155 | } |
156 | |
157 | #[inline (always)] |
158 | fn and(self, other: SensibleMoveMask) -> SensibleMoveMask { |
159 | SensibleMoveMask(self.0 & other.0) |
160 | } |
161 | |
162 | #[inline (always)] |
163 | fn or(self, other: SensibleMoveMask) -> SensibleMoveMask { |
164 | SensibleMoveMask(self.0 | other.0) |
165 | } |
166 | |
167 | #[inline (always)] |
168 | fn clear_least_significant_bit(self) -> SensibleMoveMask { |
169 | SensibleMoveMask(self.0 & (self.0 - 1)) |
170 | } |
171 | |
172 | #[inline (always)] |
173 | fn first_offset(self) -> usize { |
174 | // We are dealing with little endian here (and if we aren't, we swap |
175 | // the bytes so we are in practice), where the most significant byte |
176 | // is at a higher address. That means the least significant bit that |
177 | // is set corresponds to the position of our first matching byte. |
178 | // That position corresponds to the number of zeros after the least |
179 | // significant bit. |
180 | self.get_for_offset().trailing_zeros() as usize |
181 | } |
182 | |
183 | #[inline (always)] |
184 | fn last_offset(self) -> usize { |
185 | // We are dealing with little endian here (and if we aren't, we swap |
186 | // the bytes so we are in practice), where the most significant byte is |
187 | // at a higher address. That means the most significant bit that is set |
188 | // corresponds to the position of our last matching byte. The position |
189 | // from the end of the mask is therefore the number of leading zeros |
190 | // in a 32 bit integer, and the position from the start of the mask is |
191 | // therefore 32 - (leading zeros) - 1. |
192 | 32 - self.get_for_offset().leading_zeros() as usize - 1 |
193 | } |
194 | } |
195 | |
196 | #[cfg (target_arch = "x86_64" )] |
197 | mod x86sse2 { |
198 | use core::arch::x86_64::*; |
199 | |
200 | use super::{SensibleMoveMask, Vector}; |
201 | |
202 | impl Vector for __m128i { |
203 | const BITS: usize = 128; |
204 | const BYTES: usize = 16; |
205 | const ALIGN: usize = Self::BYTES - 1; |
206 | |
207 | type Mask = SensibleMoveMask; |
208 | |
209 | #[inline (always)] |
210 | unsafe fn splat(byte: u8) -> __m128i { |
211 | _mm_set1_epi8(byte as i8) |
212 | } |
213 | |
214 | #[inline (always)] |
215 | unsafe fn load_aligned(data: *const u8) -> __m128i { |
216 | _mm_load_si128(data as *const __m128i) |
217 | } |
218 | |
219 | #[inline (always)] |
220 | unsafe fn load_unaligned(data: *const u8) -> __m128i { |
221 | _mm_loadu_si128(data as *const __m128i) |
222 | } |
223 | |
224 | #[inline (always)] |
225 | unsafe fn movemask(self) -> SensibleMoveMask { |
226 | SensibleMoveMask(_mm_movemask_epi8(self) as u32) |
227 | } |
228 | |
229 | #[inline (always)] |
230 | unsafe fn cmpeq(self, vector2: Self) -> __m128i { |
231 | _mm_cmpeq_epi8(self, vector2) |
232 | } |
233 | |
234 | #[inline (always)] |
235 | unsafe fn and(self, vector2: Self) -> __m128i { |
236 | _mm_and_si128(self, vector2) |
237 | } |
238 | |
239 | #[inline (always)] |
240 | unsafe fn or(self, vector2: Self) -> __m128i { |
241 | _mm_or_si128(self, vector2) |
242 | } |
243 | } |
244 | } |
245 | |
246 | #[cfg (target_arch = "x86_64" )] |
247 | mod x86avx2 { |
248 | use core::arch::x86_64::*; |
249 | |
250 | use super::{SensibleMoveMask, Vector}; |
251 | |
252 | impl Vector for __m256i { |
253 | const BITS: usize = 256; |
254 | const BYTES: usize = 32; |
255 | const ALIGN: usize = Self::BYTES - 1; |
256 | |
257 | type Mask = SensibleMoveMask; |
258 | |
259 | #[inline (always)] |
260 | unsafe fn splat(byte: u8) -> __m256i { |
261 | _mm256_set1_epi8(byte as i8) |
262 | } |
263 | |
264 | #[inline (always)] |
265 | unsafe fn load_aligned(data: *const u8) -> __m256i { |
266 | _mm256_load_si256(data as *const __m256i) |
267 | } |
268 | |
269 | #[inline (always)] |
270 | unsafe fn load_unaligned(data: *const u8) -> __m256i { |
271 | _mm256_loadu_si256(data as *const __m256i) |
272 | } |
273 | |
274 | #[inline (always)] |
275 | unsafe fn movemask(self) -> SensibleMoveMask { |
276 | SensibleMoveMask(_mm256_movemask_epi8(self) as u32) |
277 | } |
278 | |
279 | #[inline (always)] |
280 | unsafe fn cmpeq(self, vector2: Self) -> __m256i { |
281 | _mm256_cmpeq_epi8(self, vector2) |
282 | } |
283 | |
284 | #[inline (always)] |
285 | unsafe fn and(self, vector2: Self) -> __m256i { |
286 | _mm256_and_si256(self, vector2) |
287 | } |
288 | |
289 | #[inline (always)] |
290 | unsafe fn or(self, vector2: Self) -> __m256i { |
291 | _mm256_or_si256(self, vector2) |
292 | } |
293 | } |
294 | } |
295 | |
#[cfg(target_arch = "aarch64")]
mod aarch64neon {
    use core::arch::aarch64::*;

    use super::{MoveMask, Vector};

    impl Vector for uint8x16_t {
        const BITS: usize = 128;
        const BYTES: usize = 16;
        const ALIGN: usize = Self::BYTES - 1;

        type Mask = NeonMoveMask;

        #[inline(always)]
        unsafe fn splat(byte: u8) -> uint8x16_t {
            vdupq_n_u8(byte)
        }

        #[inline(always)]
        unsafe fn load_aligned(data: *const u8) -> uint8x16_t {
            // I've tried `data.cast::<uint8x16_t>().read()` instead, but
            // couldn't observe any benchmark differences.
            Self::load_unaligned(data)
        }

        #[inline(always)]
        unsafe fn load_unaligned(data: *const u8) -> uint8x16_t {
            vld1q_u8(data)
        }

        /// Emulates x86's `movemask` with a "shift right and narrow".
        ///
        /// Each 16-bit pair of byte lanes is shifted right by 4 and narrowed
        /// to 8 bits. As a result, lane `i` contributes one nibble at bits
        /// `4*i..4*i+4` of the 64-bit scalar. Only the top bit of each nibble
        /// is kept (`& 0x8888...`), leaving at most one bit per lane.
        ///
        /// On little-endian targets, even lanes contribute their most
        /// significant bit and odd lanes contribute bit 3. The result
        /// therefore matches x86 `movemask` semantics only when every lane is
        /// `0x00` or `0xFF`, as `cmpeq` produces (and `and`/`or` of such
        /// vectors preserve).
        #[inline(always)]
        unsafe fn movemask(self) -> NeonMoveMask {
            let asu16s = vreinterpretq_u16_u8(self);
            let mask = vshrn_n_u16(asu16s, 4);
            let asu64 = vreinterpret_u64_u8(mask);
            let scalar64 = vget_lane_u64(asu64, 0);
            NeonMoveMask(scalar64 & 0x8888888888888888)
        }

        #[inline(always)]
        unsafe fn cmpeq(self, vector2: Self) -> uint8x16_t {
            vceqq_u8(self, vector2)
        }

        #[inline(always)]
        unsafe fn and(self, vector2: Self) -> uint8x16_t {
            vandq_u8(self, vector2)
        }

        #[inline(always)]
        unsafe fn or(self, vector2: Self) -> uint8x16_t {
            vorrq_u8(self, vector2)
        }

        /// This is the only interesting implementation of this routine.
        ///
        /// Instead of doing the "shift right narrow" dance, we use an
        /// adjacent pairwise max (`vpmaxq_u8(self, self)`). The low 64 bits
        /// of its result hold the max of every adjacent pair of bytes in
        /// `self`, so they are non-zero if and only if some byte of `self` is
        /// non-zero. If they are, *then* the "shift right narrow" dance is
        /// done via `movemask`. In benchmarks, this leads to slightly better
        /// throughput, but the win doesn't appear huge.
        ///
        /// Like `movemask`, this agrees with `movemask().has_non_zero()` when
        /// every lane is `0x00` or `0xFF`.
        #[inline(always)]
        unsafe fn movemask_will_have_non_zero(self) -> bool {
            let low = vreinterpretq_u64_u8(vpmaxq_u8(self, self));
            vgetq_lane_u64(low, 0) != 0
        }
    }

    /// Neon doesn't have a `movemask` that works like the one in x86-64, so
    /// we wind up using a different method[1]. The different method also
    /// produces a mask, but 4 bits are set per matching lane in the neon
    /// case instead of a single bit in the x86-64 case. We do an extra step
    /// to zero out 3 of the 4 bits, but we still wind up with at least 3
    /// zeroes between each set bit. This generally means that we need to do
    /// some division by 4 before extracting offsets.
    ///
    /// In fact, the existence of this type is the entire reason that we have
    /// the `MoveMask` trait in the first place. It lets us keep the
    /// different mask representations without being forced to unify them
    /// into a single representation, which could result in extra and
    /// unnecessary work.
    ///
    /// [1]: https://community.arm.com/arm-community-blogs/b/infrastructure-solutions-blog/posts/porting-x86-vector-bitmask-optimizations-to-arm-neon
    #[derive(Clone, Copy, Debug)]
    pub(crate) struct NeonMoveMask(u64);

    impl NeonMoveMask {
        /// Get the mask in a form suitable for computing offsets.
        ///
        /// Basically, this normalizes to little endian. On big endian, this
        /// swaps the bytes.
        #[inline(always)]
        fn get_for_offset(self) -> u64 {
            #[cfg(target_endian = "big")]
            {
                self.0.swap_bytes()
            }
            #[cfg(target_endian = "little")]
            {
                self.0
            }
        }
    }

    impl MoveMask for NeonMoveMask {
        /// Returns a mask with the `n` least significant lanes cleared and
        /// every other bit set.
        #[inline(always)]
        fn all_zeros_except_least_significant(n: usize) -> NeonMoveMask {
            debug_assert!(n < 16);
            // Each lane occupies 4 bits of this mask (its bit is `4*i + 3`),
            // so clearing `n` lanes means clearing the low `4 * n` bits.
            // With `n < 16`, the shift is at most 60 and cannot overflow.
            NeonMoveMask(!((1u64 << (n << 2)) - 1))
        }

        #[inline(always)]
        fn has_non_zero(self) -> bool {
            self.0 != 0
        }

        #[inline(always)]
        fn count_ones(self) -> usize {
            self.0.count_ones() as usize
        }

        #[inline(always)]
        fn and(self, other: NeonMoveMask) -> NeonMoveMask {
            NeonMoveMask(self.0 & other.0)
        }

        #[inline(always)]
        fn or(self, other: NeonMoveMask) -> NeonMoveMask {
            NeonMoveMask(self.0 | other.0)
        }

        /// Clears the lowest set bit. An empty mask stays empty.
        #[inline(always)]
        fn clear_least_significant_bit(self) -> NeonMoveMask {
            // `wrapping_sub` keeps the empty mask from overflowing (a panic
            // in debug builds). For non-zero masks it is identical to `- 1`.
            NeonMoveMask(self.0 & self.0.wrapping_sub(1))
        }

        #[inline(always)]
        fn first_offset(self) -> usize {
            // We are dealing with little endian here (and if we aren't,
            // we swap the bytes so we are in practice), where the most
            // significant byte is at a higher address. That means the least
            // significant bit that is set corresponds to the position of our
            // first matching byte. That position corresponds to the number of
            // zeros after the least significant bit.
            //
            // Note that unlike `SensibleMoveMask`, this mask has its bits
            // spread out over 64 bits instead of 16 bits (for a 128 bit
            // vector). Namely, whereas x86-64 will turn
            //
            //   0x00 0xFF 0x00 0x00 0xFF
            //
            // into 10010, our neon approach will turn it into
            //
            //   10000000000010000000
            //
            // And this happens because neon doesn't have a native `movemask`
            // instruction, so we kind of fake it[1]. Thus, we divide the
            // number of trailing zeros by 4 to get the "real" offset.
            //
            // [1]: https://community.arm.com/arm-community-blogs/b/infrastructure-solutions-blog/posts/porting-x86-vector-bitmask-optimizations-to-arm-neon
            (self.get_for_offset().trailing_zeros() >> 2) as usize
        }

        /// Returns the lane index of the highest set bit.
        ///
        /// The mask must be non-zero. For an empty mask the subtraction below
        /// underflows.
        #[inline(always)]
        fn last_offset(self) -> usize {
            // See comment in `first_offset` above. This is basically the same,
            // but coming from the other direction.
            16 - (self.get_for_offset().leading_zeros() >> 2) as usize - 1
        }
    }
}
466 | |
/// `Vector` implementation for wasm32's 128-bit simd128 vectors (`v128`).
#[cfg(target_arch = "wasm32")]
mod wasm_simd128 {
    use core::arch::wasm32::*;

    use super::{SensibleMoveMask, Vector};

    impl Vector for v128 {
        const BITS: usize = 128;
        const BYTES: usize = 16;
        const ALIGN: usize = Self::BYTES - 1;

        type Mask = SensibleMoveMask;

        #[inline(always)]
        unsafe fn splat(byte: u8) -> Self {
            u8x16_splat(byte)
        }

        /// Reads through a `*const v128`. Like a plain dereference, this
        /// requires `data` to be suitably aligned for `v128`.
        #[inline(always)]
        unsafe fn load_aligned(data: *const u8) -> Self {
            data.cast::<v128>().read()
        }

        #[inline(always)]
        unsafe fn load_unaligned(data: *const u8) -> Self {
            v128_load(data.cast::<v128>())
        }

        /// simd128's `bitmask` already yields one bit per lane, so it only
        /// needs widening to the `u32` that `SensibleMoveMask` stores.
        #[inline(always)]
        unsafe fn movemask(self) -> SensibleMoveMask {
            SensibleMoveMask(u32::from(u8x16_bitmask(self)))
        }

        #[inline(always)]
        unsafe fn cmpeq(self, vector2: Self) -> Self {
            u8x16_eq(self, vector2)
        }

        #[inline(always)]
        unsafe fn and(self, vector2: Self) -> Self {
            v128_and(self, vector2)
        }

        #[inline(always)]
        unsafe fn or(self, vector2: Self) -> Self {
            v128_or(self, vector2)
        }
    }
}
516 | |