/// A trait for describing vector operations used by vectorized searchers.
///
/// The trait is highly constrained to the low level vector operations needed.
/// In general, it was invented mostly to be generic over x86's __m128i and
/// __m256i types. At time of writing, it also supports wasm and aarch64
/// 128-bit vector types.
///
/// # Safety
///
/// All methods are unsafe since they are intended to be implemented using
/// vendor intrinsics, which are also unsafe. Callers must ensure that the
/// appropriate target features are enabled in the calling function, and that
/// the current CPU supports them. All implementations should avoid marking the
/// routines with #[target_feature] and instead mark them as #[inline(always)]
/// to ensure they get appropriately inlined. (inline(always) cannot be used
/// with target_feature.)
pub(crate) trait Vector: Copy + core::fmt::Debug {
    /// The number of bits in the vector.
    const BITS: usize;
    /// The number of bytes in the vector. That is, this is the size of the
    /// vector in memory.
    const BYTES: usize;
    /// The bits that must be zero in order for a `*const u8` pointer to be
    /// correctly aligned to read vector values.
    const ALIGN: usize;

    /// The type of the value returned by `Vector::movemask`.
    ///
    /// This supports abstracting over the specific representation used in
    /// order to accommodate different representations in different ISAs.
    type Mask: MoveMask;

    /// Create a vector with 8-bit lanes with the given byte repeated into each
    /// lane.
    unsafe fn splat(byte: u8) -> Self;

    /// Read a vector-size number of bytes from the given pointer. The pointer
    /// must be aligned to the size of the vector.
    ///
    /// # Safety
    ///
    /// Callers must guarantee that at least `BYTES` bytes are readable from
    /// `data` and that `data` is aligned to a `BYTES` boundary.
    unsafe fn load_aligned(data: *const u8) -> Self;

    /// Read a vector-size number of bytes from the given pointer. The pointer
    /// does not need to be aligned.
    ///
    /// # Safety
    ///
    /// Callers must guarantee that at least `BYTES` bytes are readable from
    /// `data`.
    unsafe fn load_unaligned(data: *const u8) -> Self;

    /// _mm_movemask_epi8 or _mm256_movemask_epi8
    unsafe fn movemask(self) -> Self::Mask;
    /// _mm_cmpeq_epi8 or _mm256_cmpeq_epi8
    unsafe fn cmpeq(self, vector2: Self) -> Self;
    /// _mm_and_si128 or _mm256_and_si256
    unsafe fn and(self, vector2: Self) -> Self;
    /// _mm_or_si128 or _mm256_or_si256
    unsafe fn or(self, vector2: Self) -> Self;
    /// Returns true if and only if `Self::movemask` would return a mask that
    /// contains at least one non-zero bit.
    unsafe fn movemask_will_have_non_zero(self) -> bool {
        self.movemask().has_non_zero()
    }
}
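
// As a sketch of how the operations above compose (the helper below is
// illustrative only and not part of this crate's API), a single-chunk search
// looks roughly like this. The caller must uphold the safety contract above:
// the required target features are enabled and `chunk` points to at least
// `V::BYTES` readable bytes.
#[allow(dead_code)]
unsafe fn example_find_in_chunk<V: Vector>(
    chunk: *const u8,
    needle: u8,
) -> Option<usize> {
    // Compare every lane against the needle and collapse the per-lane
    // comparison results into a scalar mask.
    let mask = V::load_unaligned(chunk).cmpeq(V::splat(needle)).movemask();
    if mask.has_non_zero() {
        // The first set lane gives the offset of the first matching byte.
        Some(mask.first_offset())
    } else {
        None
    }
}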

/// A trait that abstracts over a vector-to-scalar operation called
/// "move mask."
///
/// On x86-64, this is `_mm_movemask_epi8` for SSE2 and `_mm256_movemask_epi8`
/// for AVX2. It takes a vector of `u8` lanes and returns a scalar where the
/// `i`th bit is set if and only if the most significant bit in the `i`th lane
/// of the vector is set. The simd128 ISA for wasm32 also supports this
/// exact same operation natively.
///
/// ... But aarch64 doesn't. So we have to fake it with more instructions and
/// a slightly different representation. We could do extra work to unify the
/// representations, but that would add extra costs to the hot path for
/// `memchr` and `packedpair`. So instead, we abstract over the specific
/// representation with this trait and define the operations we actually need.
pub(crate) trait MoveMask: Copy + core::fmt::Debug {
    /// Return a mask that is all zeros except for the least significant `n`
    /// lanes in a corresponding vector.
    fn all_zeros_except_least_significant(n: usize) -> Self;

    /// Returns true if and only if this mask has a non-zero bit anywhere.
    fn has_non_zero(self) -> bool;

    /// Returns the number of bits set to 1 in this mask.
    fn count_ones(self) -> usize;

    /// Does a bitwise `and` operation between `self` and `other`.
    fn and(self, other: Self) -> Self;

    /// Does a bitwise `or` operation between `self` and `other`.
    fn or(self, other: Self) -> Self;

    /// Returns a mask that is equivalent to `self` but with the least
    /// significant 1-bit set to 0.
    fn clear_least_significant_bit(self) -> Self;

    /// Returns the offset of the first non-zero lane this mask represents.
    fn first_offset(self) -> usize;

    /// Returns the offset of the last non-zero lane this mask represents.
    fn last_offset(self) -> usize;
}
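
// A sketch of how a caller might visit every matching lane in a mask (the
// helper name is illustrative only and not part of this crate's API). For
// masks produced by `Vector::movemask`, each matching lane contributes one
// set bit, so `first_offset` plus `clear_least_significant_bit` walks the
// matches in order for both representations below.
#[allow(dead_code)]
fn example_for_each_offset<M: MoveMask>(mut mask: M, mut f: impl FnMut(usize)) {
    // Peel off one matching lane at a time, from least to most significant.
    while mask.has_non_zero() {
        f(mask.first_offset());
        mask = mask.clear_least_significant_bit();
    }
}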

/// This is a "sensible" movemask implementation where each bit represents
/// whether the most significant bit is set in each corresponding lane of a
/// vector. This is used on x86-64 and wasm, but such a mask is more expensive
/// to get on aarch64 so we use something a little different.
///
/// We call this "sensible" because this is what we get using native sse/avx
/// movemask instructions. But neon has no such native equivalent.
#[derive(Clone, Copy, Debug)]
pub(crate) struct SensibleMoveMask(u32);

impl SensibleMoveMask {
    /// Get the mask in a form suitable for computing offsets.
    ///
    /// Basically, this normalizes to little endian. On big endian, this swaps
    /// the bytes.
    #[inline(always)]
    fn get_for_offset(self) -> u32 {
        #[cfg(target_endian = "big")]
        {
            self.0.swap_bytes()
        }
        #[cfg(target_endian = "little")]
        {
            self.0
        }
    }
}

impl MoveMask for SensibleMoveMask {
    #[inline(always)]
    fn all_zeros_except_least_significant(n: usize) -> SensibleMoveMask {
        debug_assert!(n < 32);
        SensibleMoveMask(!((1 << n) - 1))
    }

    #[inline(always)]
    fn has_non_zero(self) -> bool {
        self.0 != 0
    }

    #[inline(always)]
    fn count_ones(self) -> usize {
        self.0.count_ones() as usize
    }

    #[inline(always)]
    fn and(self, other: SensibleMoveMask) -> SensibleMoveMask {
        SensibleMoveMask(self.0 & other.0)
    }

    #[inline(always)]
    fn or(self, other: SensibleMoveMask) -> SensibleMoveMask {
        SensibleMoveMask(self.0 | other.0)
    }

    #[inline(always)]
    fn clear_least_significant_bit(self) -> SensibleMoveMask {
        SensibleMoveMask(self.0 & (self.0 - 1))
    }

    #[inline(always)]
    fn first_offset(self) -> usize {
        // We are dealing with little endian here (and if we aren't, we swap
        // the bytes so we are in practice), where the most significant byte
        // is at a higher address. That means the least significant bit that
        // is set corresponds to the position of our first matching byte.
        // That position corresponds to the number of zeros after the least
        // significant bit.
        self.get_for_offset().trailing_zeros() as usize
    }

    #[inline(always)]
    fn last_offset(self) -> usize {
        // We are dealing with little endian here (and if we aren't, we swap
        // the bytes so we are in practice), where the most significant byte is
        // at a higher address. That means the most significant bit that is set
        // corresponds to the position of our last matching byte. The position
        // from the end of the mask is therefore the number of leading zeros
        // in a 32 bit integer, and the position from the start of the mask is
        // therefore 32 - (leading zeros) - 1.
        32 - self.get_for_offset().leading_zeros() as usize - 1
    }
}
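
// A quick worked example of the offset math above, written as an illustrative
// test (the test name is ours, not part of the crate's test suite): for the
// mask `0b10010`, lanes 1 and 4 matched, so `first_offset` is
// `trailing_zeros(0b10010) == 1` and `last_offset` is
// `32 - leading_zeros(0b10010) - 1 == 32 - 27 - 1 == 4`.
#[cfg(all(test, target_endian = "little"))]
#[test]
fn example_sensible_offsets() {
    let mask = SensibleMoveMask(0b10010);
    assert_eq!(1, mask.first_offset());
    assert_eq!(4, mask.last_offset());
}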

#[cfg(target_arch = "x86_64")]
mod x86sse2 {
    use core::arch::x86_64::*;

    use super::{SensibleMoveMask, Vector};

    impl Vector for __m128i {
        const BITS: usize = 128;
        const BYTES: usize = 16;
        const ALIGN: usize = Self::BYTES - 1;

        type Mask = SensibleMoveMask;

        #[inline(always)]
        unsafe fn splat(byte: u8) -> __m128i {
            _mm_set1_epi8(byte as i8)
        }

        #[inline(always)]
        unsafe fn load_aligned(data: *const u8) -> __m128i {
            _mm_load_si128(data as *const __m128i)
        }

        #[inline(always)]
        unsafe fn load_unaligned(data: *const u8) -> __m128i {
            _mm_loadu_si128(data as *const __m128i)
        }

        #[inline(always)]
        unsafe fn movemask(self) -> SensibleMoveMask {
            SensibleMoveMask(_mm_movemask_epi8(self) as u32)
        }

        #[inline(always)]
        unsafe fn cmpeq(self, vector2: Self) -> __m128i {
            _mm_cmpeq_epi8(self, vector2)
        }

        #[inline(always)]
        unsafe fn and(self, vector2: Self) -> __m128i {
            _mm_and_si128(self, vector2)
        }

        #[inline(always)]
        unsafe fn or(self, vector2: Self) -> __m128i {
            _mm_or_si128(self, vector2)
        }
    }
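
    // A sketch of how a caller is expected to drive this implementation, per
    // the safety contract on `Vector`: the `#[target_feature]` attribute goes
    // on the caller, while the methods above stay `#[inline(always)]`. The
    // function below is illustrative only and not part of this crate's API.
    #[allow(dead_code)]
    #[target_feature(enable = "sse2")]
    unsafe fn example_find_first_sse2(
        haystack: &[u8],
        needle: u8,
    ) -> Option<usize> {
        // Bring the mask operations into scope.
        use super::MoveMask;

        // Only inspect a single chunk for brevity; real searchers handle
        // unaligned tails and haystacks shorter than one vector.
        if haystack.len() < __m128i::BYTES {
            return None;
        }
        let chunk = __m128i::load_unaligned(haystack.as_ptr());
        let mask = chunk.cmpeq(__m128i::splat(needle)).movemask();
        if mask.has_non_zero() {
            Some(mask.first_offset())
        } else {
            None
        }
    }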
}

#[cfg(target_arch = "x86_64")]
mod x86avx2 {
    use core::arch::x86_64::*;

    use super::{SensibleMoveMask, Vector};

    impl Vector for __m256i {
        const BITS: usize = 256;
        const BYTES: usize = 32;
        const ALIGN: usize = Self::BYTES - 1;

        type Mask = SensibleMoveMask;

        #[inline(always)]
        unsafe fn splat(byte: u8) -> __m256i {
            _mm256_set1_epi8(byte as i8)
        }

        #[inline(always)]
        unsafe fn load_aligned(data: *const u8) -> __m256i {
            _mm256_load_si256(data as *const __m256i)
        }

        #[inline(always)]
        unsafe fn load_unaligned(data: *const u8) -> __m256i {
            _mm256_loadu_si256(data as *const __m256i)
        }

        #[inline(always)]
        unsafe fn movemask(self) -> SensibleMoveMask {
            SensibleMoveMask(_mm256_movemask_epi8(self) as u32)
        }

        #[inline(always)]
        unsafe fn cmpeq(self, vector2: Self) -> __m256i {
            _mm256_cmpeq_epi8(self, vector2)
        }

        #[inline(always)]
        unsafe fn and(self, vector2: Self) -> __m256i {
            _mm256_and_si256(self, vector2)
        }

        #[inline(always)]
        unsafe fn or(self, vector2: Self) -> __m256i {
            _mm256_or_si256(self, vector2)
        }
    }
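
    // Unlike SSE2, AVX2 is not guaranteed to be present on x86-64, so per the
    // safety contract on `Vector`, callers must confirm support at runtime
    // before calling into an AVX2 routine. A sketch of that dance in a `std`
    // environment (the function names are illustrative only, and this crate's
    // real dispatch may differ):
    //
    //     #[target_feature(enable = "avx2")]
    //     unsafe fn find_avx2(haystack: &[u8], needle: u8) -> Option<usize> {
    //         // ... same shape as the SSE2 example, but with __m256i ...
    //     }
    //
    //     fn find(haystack: &[u8], needle: u8) -> Option<usize> {
    //         if is_x86_feature_detected!("avx2") {
    //             // SAFETY: we just confirmed that the CPU supports AVX2.
    //             unsafe { find_avx2(haystack, needle) }
    //         } else {
    //             // ... fall back to the SSE2 or scalar version ...
    //             None
    //         }
    //     }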
}

#[cfg(target_arch = "aarch64")]
mod aarch64neon {
    use core::arch::aarch64::*;

    use super::{MoveMask, Vector};

    impl Vector for uint8x16_t {
        const BITS: usize = 128;
        const BYTES: usize = 16;
        const ALIGN: usize = Self::BYTES - 1;

        type Mask = NeonMoveMask;

        #[inline(always)]
        unsafe fn splat(byte: u8) -> uint8x16_t {
            vdupq_n_u8(byte)
        }

        #[inline(always)]
        unsafe fn load_aligned(data: *const u8) -> uint8x16_t {
            // I've tried `data.cast::<uint8x16_t>().read()` instead, but
            // couldn't observe any benchmark differences.
            Self::load_unaligned(data)
        }

        #[inline(always)]
        unsafe fn load_unaligned(data: *const u8) -> uint8x16_t {
            vld1q_u8(data)
        }

        #[inline(always)]
        unsafe fn movemask(self) -> NeonMoveMask {
            let asu16s = vreinterpretq_u16_u8(self);
            let mask = vshrn_n_u16(asu16s, 4);
            let asu64 = vreinterpret_u64_u8(mask);
            let scalar64 = vget_lane_u64(asu64, 0);
            NeonMoveMask(scalar64 & 0x8888888888888888)
        }

        #[inline(always)]
        unsafe fn cmpeq(self, vector2: Self) -> uint8x16_t {
            vceqq_u8(self, vector2)
        }

        #[inline(always)]
        unsafe fn and(self, vector2: Self) -> uint8x16_t {
            vandq_u8(self, vector2)
        }

        #[inline(always)]
        unsafe fn or(self, vector2: Self) -> uint8x16_t {
            vorrq_u8(self, vector2)
        }

        /// This is the only interesting implementation of this routine.
        /// Basically, instead of doing the "shift right narrow" dance, we use
        /// adjacent folding max to determine whether there are any non-zero
        /// bytes in our mask. If there are, *then* we'll do the "shift right
        /// narrow" dance. In benchmarks, this does lead to slightly better
        /// throughput, but the win doesn't appear huge.
        #[inline(always)]
        unsafe fn movemask_will_have_non_zero(self) -> bool {
            let low = vreinterpretq_u64_u8(vpmaxq_u8(self, self));
            vgetq_lane_u64(low, 0) != 0
        }
    }

    /// Neon doesn't have a `movemask` that works like the one in x86-64, so we
    /// wind up using a different method[1]. The different method also produces
    /// a mask, but 4 bits are set in the neon case instead of a single bit set
    /// in the x86-64 case. We do an extra step to zero out 3 of the 4 bits,
    /// but we still wind up with at least 3 zeroes between each set bit. This
    /// generally means that we need to do some division by 4 before extracting
    /// offsets.
    ///
    /// In fact, the existence of this type is the entire reason that we have
    /// the `MoveMask` trait in the first place. This basically lets us keep
    /// the different representations of masks without being forced to unify
    /// them into a single representation, which could result in extra and
    /// unnecessary work.
    ///
    /// [1]: https://community.arm.com/arm-community-blogs/b/infrastructure-solutions-blog/posts/porting-x86-vector-bitmask-optimizations-to-arm-neon
    #[derive(Clone, Copy, Debug)]
    pub(crate) struct NeonMoveMask(u64);
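
    // A worked example of the representation described above, written as an
    // illustrative test (the test name is ours, not part of the crate's test
    // suite): if lanes 1 and 4 of a 16-lane vector matched, then after the
    // `& 0x8888888888888888` step only bits 7 and 19 remain set, and dividing
    // the bit positions by 4 recovers the lane offsets 1 and 4.
    #[cfg(all(test, target_endian = "little"))]
    #[test]
    fn example_neon_offsets() {
        let mask = NeonMoveMask(0b1000_0000_0000_1000_0000);
        assert_eq!(1, mask.first_offset());
        assert_eq!(4, mask.last_offset());
    }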

    impl NeonMoveMask {
        /// Get the mask in a form suitable for computing offsets.
        ///
        /// Basically, this normalizes to little endian. On big endian, this
        /// swaps the bytes.
        #[inline(always)]
        fn get_for_offset(self) -> u64 {
            #[cfg(target_endian = "big")]
            {
                self.0.swap_bytes()
            }
            #[cfg(target_endian = "little")]
            {
                self.0
            }
        }
    }

    impl MoveMask for NeonMoveMask {
        #[inline(always)]
        fn all_zeros_except_least_significant(n: usize) -> NeonMoveMask {
            debug_assert!(n < 16);
            NeonMoveMask(!(((1 << n) << 2) - 1))
        }

        #[inline(always)]
        fn has_non_zero(self) -> bool {
            self.0 != 0
        }

        #[inline(always)]
        fn count_ones(self) -> usize {
            self.0.count_ones() as usize
        }

        #[inline(always)]
        fn and(self, other: NeonMoveMask) -> NeonMoveMask {
            NeonMoveMask(self.0 & other.0)
        }

        #[inline(always)]
        fn or(self, other: NeonMoveMask) -> NeonMoveMask {
            NeonMoveMask(self.0 | other.0)
        }

        #[inline(always)]
        fn clear_least_significant_bit(self) -> NeonMoveMask {
            NeonMoveMask(self.0 & (self.0 - 1))
        }

        #[inline(always)]
        fn first_offset(self) -> usize {
            // We are dealing with little endian here (and if we aren't,
            // we swap the bytes so we are in practice), where the most
            // significant byte is at a higher address. That means the least
            // significant bit that is set corresponds to the position of our
            // first matching byte. That position corresponds to the number of
            // zeros after the least significant bit.
            //
            // Note that unlike `SensibleMoveMask`, this mask has its bits
            // spread out over 64 bits instead of 16 bits (for a 128 bit
            // vector). Namely, whereas x86-64 will turn
            //
            // 0x00 0xFF 0x00 0x00 0xFF
            //
            // into 10010, our neon approach will turn it into
            //
            // 10000000000010000000
            //
            // And this happens because neon doesn't have a native `movemask`
            // instruction, so we kind of fake it[1]. Thus, we divide the
            // number of trailing zeros by 4 to get the "real" offset.
            //
            // [1]: https://community.arm.com/arm-community-blogs/b/infrastructure-solutions-blog/posts/porting-x86-vector-bitmask-optimizations-to-arm-neon
            (self.get_for_offset().trailing_zeros() >> 2) as usize
        }

        #[inline(always)]
        fn last_offset(self) -> usize {
            // See comment in `first_offset` above. This is basically the same,
            // but coming from the other direction.
            16 - (self.get_for_offset().leading_zeros() >> 2) as usize - 1
        }
    }
}

#[cfg(target_arch = "wasm32")]
mod wasm_simd128 {
    use core::arch::wasm32::*;

    use super::{SensibleMoveMask, Vector};

    impl Vector for v128 {
        const BITS: usize = 128;
        const BYTES: usize = 16;
        const ALIGN: usize = Self::BYTES - 1;

        type Mask = SensibleMoveMask;

        #[inline(always)]
        unsafe fn splat(byte: u8) -> v128 {
            u8x16_splat(byte)
        }

        #[inline(always)]
        unsafe fn load_aligned(data: *const u8) -> v128 {
            *data.cast()
        }

        #[inline(always)]
        unsafe fn load_unaligned(data: *const u8) -> v128 {
            v128_load(data.cast())
        }

        #[inline(always)]
        unsafe fn movemask(self) -> SensibleMoveMask {
            SensibleMoveMask(u8x16_bitmask(self).into())
        }

        #[inline(always)]
        unsafe fn cmpeq(self, vector2: Self) -> v128 {
            u8x16_eq(self, vector2)
        }

        #[inline(always)]
        unsafe fn and(self, vector2: Self) -> v128 {
            v128_and(self, vector2)
        }

        #[inline(always)]
        unsafe fn or(self, vector2: Self) -> v128 {
            v128_or(self, vector2)
        }
    }
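
    // Note that the intrinsics above are only useful when the `simd128`
    // target feature is enabled at compile time. An illustrative build
    // invocation (not a requirement imposed by this module, and the exact
    // flags may vary by setup):
    //
    //     RUSTFLAGS="-C target-feature=+simd128" \
    //         cargo build --target wasm32-unknown-unknown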
}