1//! Types and traits associated with masking elements of vectors.
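//! Types representing masks are backed by a per-target implementation:
//! `masks/bitmask.rs` on x86-64 with AVX-512F, `masks/full_masks.rs` otherwise.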
3#![allow(non_camel_case_types)]
4
5#[cfg_attr(
6 not(all(target_arch = "x86_64", target_feature = "avx512f")),
7 path = "masks/full_masks.rs"
8)]
9#[cfg_attr(
10 all(target_arch = "x86_64", target_feature = "avx512f"),
11 path = "masks/bitmask.rs"
12)]
13mod mask_impl;
14
15use crate::simd::{
16 cmp::SimdPartialEq, intrinsics, LaneCount, Simd, SimdCast, SimdElement, SupportedLaneCount,
17};
18use core::cmp::Ordering;
19use core::{fmt, mem};
20
21mod sealed {
22 use super::*;
23
24 /// Not only does this seal the `MaskElement` trait, but these functions prevent other traits
25 /// from bleeding into the parent bounds.
26 ///
27 /// For example, `eq` could be provided by requiring `MaskElement: PartialEq`, but that would
28 /// prevent us from ever removing that bound, or from implementing `MaskElement` on
29 /// non-`PartialEq` types in the future.
30 pub trait Sealed {
31 fn valid<const N: usize>(values: Simd<Self, N>) -> bool
32 where
33 LaneCount<N>: SupportedLaneCount,
34 Self: SimdElement;
35
36 fn eq(self, other: Self) -> bool;
37
38 fn as_usize(self) -> usize;
39
40 type Unsigned: SimdElement;
41
42 const TRUE: Self;
43
44 const FALSE: Self;
45 }
46}
47use sealed::Sealed;
48
49/// Marker trait for types that may be used as SIMD mask elements.
50///
51/// # Safety
52/// Type must be a signed integer.
53pub unsafe trait MaskElement: SimdElement<Mask = Self> + SimdCast + Sealed {}
54
55macro_rules! impl_element {
56 { $ty:ty, $unsigned:ty } => {
57 impl Sealed for $ty {
58 #[inline]
59 fn valid<const N: usize>(value: Simd<Self, N>) -> bool
60 where
61 LaneCount<N>: SupportedLaneCount,
62 {
63 (value.simd_eq(Simd::splat(0 as _)) | value.simd_eq(Simd::splat(-1 as _))).all()
64 }
65
66 #[inline]
67 fn eq(self, other: Self) -> bool { self == other }
68
69 #[inline]
70 fn as_usize(self) -> usize {
71 self as usize
72 }
73
74 type Unsigned = $unsigned;
75
76 const TRUE: Self = -1;
77 const FALSE: Self = 0;
78 }
79
80 // Safety: this is a valid mask element type
81 unsafe impl MaskElement for $ty {}
82 }
83}
84
85impl_element! { i8, u8 }
86impl_element! { i16, u16 }
87impl_element! { i32, u32 }
88impl_element! { i64, u64 }
89impl_element! { isize, usize }
90
91/// A SIMD vector mask for `N` elements of width specified by `Element`.
92///
93/// Masks represent boolean inclusion/exclusion on a per-element basis.
94///
95/// The layout of this type is unspecified, and may change between platforms
96/// and/or Rust versions, and code should not assume that it is equivalent to
97/// `[T; N]`.
98#[repr(transparent)]
99pub struct Mask<T, const N: usize>(mask_impl::Mask<T, N>)
100where
101 T: MaskElement,
102 LaneCount<N>: SupportedLaneCount;
103
104impl<T, const N: usize> Copy for Mask<T, N>
105where
106 T: MaskElement,
107 LaneCount<N>: SupportedLaneCount,
108{
109}
110
111impl<T, const N: usize> Clone for Mask<T, N>
112where
113 T: MaskElement,
114 LaneCount<N>: SupportedLaneCount,
115{
116 #[inline]
117 fn clone(&self) -> Self {
118 *self
119 }
120}
121
122impl<T, const N: usize> Mask<T, N>
123where
124 T: MaskElement,
125 LaneCount<N>: SupportedLaneCount,
126{
127 /// Construct a mask by setting all elements to the given value.
128 #[inline]
129 pub fn splat(value: bool) -> Self {
130 Self(mask_impl::Mask::splat(value))
131 }
132
133 /// Converts an array of bools to a SIMD mask.
134 #[inline]
135 pub fn from_array(array: [bool; N]) -> Self {
136 // SAFETY: Rust's bool has a layout of 1 byte (u8) with a value of
137 // true: 0b_0000_0001
138 // false: 0b_0000_0000
139 // Thus, an array of bools is also a valid array of bytes: [u8; N]
140 // This would be hypothetically valid as an "in-place" transmute,
141 // but these are "dependently-sized" types, so copy elision it is!
142 unsafe {
143 let bytes: [u8; N] = mem::transmute_copy(&array);
144 let bools: Simd<i8, N> = intrinsics::simd_ne(Simd::from_array(bytes), Simd::splat(0u8));
145 Mask::from_int_unchecked(intrinsics::simd_cast(bools))
146 }
147 }
148
149 /// Converts a SIMD mask to an array of bools.
150 #[inline]
151 pub fn to_array(self) -> [bool; N] {
152 // This follows mostly the same logic as from_array.
153 // SAFETY: Rust's bool has a layout of 1 byte (u8) with a value of
154 // true: 0b_0000_0001
155 // false: 0b_0000_0000
156 // Thus, an array of bools is also a valid array of bytes: [u8; N]
157 // Since our masks are equal to integers where all bits are set,
158 // we can simply convert them to i8s, and then bitand them by the
159 // bitpattern for Rust's "true" bool.
160 // This would be hypothetically valid as an "in-place" transmute,
161 // but these are "dependently-sized" types, so copy elision it is!
162 unsafe {
163 let mut bytes: Simd<i8, N> = intrinsics::simd_cast(self.to_int());
164 bytes &= Simd::splat(1i8);
165 mem::transmute_copy(&bytes)
166 }
167 }
168
169 /// Converts a vector of integers to a mask, where 0 represents `false` and -1
170 /// represents `true`.
171 ///
172 /// # Safety
173 /// All elements must be either 0 or -1.
174 #[inline]
175 #[must_use = "method returns a new mask and does not mutate the original value"]
176 pub unsafe fn from_int_unchecked(value: Simd<T, N>) -> Self {
177 // Safety: the caller must confirm this invariant
178 unsafe { Self(mask_impl::Mask::from_int_unchecked(value)) }
179 }
180
181 /// Converts a vector of integers to a mask, where 0 represents `false` and -1
182 /// represents `true`.
183 ///
184 /// # Panics
185 /// Panics if any element is not 0 or -1.
186 #[inline]
187 #[must_use = "method returns a new mask and does not mutate the original value"]
188 #[track_caller]
189 pub fn from_int(value: Simd<T, N>) -> Self {
190 assert!(T::valid(value), "all values must be either 0 or -1",);
191 // Safety: the validity has been checked
192 unsafe { Self::from_int_unchecked(value) }
193 }
194
195 /// Converts the mask to a vector of integers, where 0 represents `false` and -1
196 /// represents `true`.
197 #[inline]
198 #[must_use = "method returns a new vector and does not mutate the original value"]
199 pub fn to_int(self) -> Simd<T, N> {
200 self.0.to_int()
201 }
202
203 /// Converts the mask to a mask of any other element size.
204 #[inline]
205 #[must_use = "method returns a new mask and does not mutate the original value"]
206 pub fn cast<U: MaskElement>(self) -> Mask<U, N> {
207 Mask(self.0.convert())
208 }
209
210 /// Tests the value of the specified element.
211 ///
212 /// # Safety
213 /// `index` must be less than `self.len()`.
214 #[inline]
215 #[must_use = "method returns a new bool and does not mutate the original value"]
216 pub unsafe fn test_unchecked(&self, index: usize) -> bool {
217 // Safety: the caller must confirm this invariant
218 unsafe { self.0.test_unchecked(index) }
219 }
220
221 /// Tests the value of the specified element.
222 ///
223 /// # Panics
224 /// Panics if `index` is greater than or equal to the number of elements in the vector.
225 #[inline]
226 #[must_use = "method returns a new bool and does not mutate the original value"]
227 #[track_caller]
228 pub fn test(&self, index: usize) -> bool {
229 assert!(index < N, "element index out of range");
230 // Safety: the element index has been checked
231 unsafe { self.test_unchecked(index) }
232 }
233
234 /// Sets the value of the specified element.
235 ///
236 /// # Safety
237 /// `index` must be less than `self.len()`.
238 #[inline]
239 pub unsafe fn set_unchecked(&mut self, index: usize, value: bool) {
240 // Safety: the caller must confirm this invariant
241 unsafe {
242 self.0.set_unchecked(index, value);
243 }
244 }
245
246 /// Sets the value of the specified element.
247 ///
248 /// # Panics
249 /// Panics if `index` is greater than or equal to the number of elements in the vector.
250 #[inline]
251 #[track_caller]
252 pub fn set(&mut self, index: usize, value: bool) {
253 assert!(index < N, "element index out of range");
254 // Safety: the element index has been checked
255 unsafe {
256 self.set_unchecked(index, value);
257 }
258 }
259
260 /// Returns true if any element is set, or false otherwise.
261 #[inline]
262 #[must_use = "method returns a new bool and does not mutate the original value"]
263 pub fn any(self) -> bool {
264 self.0.any()
265 }
266
267 /// Returns true if all elements are set, or false otherwise.
268 #[inline]
269 #[must_use = "method returns a new bool and does not mutate the original value"]
270 pub fn all(self) -> bool {
271 self.0.all()
272 }
273
274 /// Create a bitmask from a mask.
275 ///
276 /// Each bit is set if the corresponding element in the mask is `true`.
277 /// If the mask contains more than 64 elements, the bitmask is truncated to the first 64.
278 #[inline]
279 #[must_use = "method returns a new integer and does not mutate the original value"]
280 pub fn to_bitmask(self) -> u64 {
281 self.0.to_bitmask_integer()
282 }
283
284 /// Create a mask from a bitmask.
285 ///
286 /// For each bit, if it is set, the corresponding element in the mask is set to `true`.
287 /// If the mask contains more than 64 elements, the remainder are set to `false`.
288 #[inline]
289 #[must_use = "method returns a new mask and does not mutate the original value"]
290 pub fn from_bitmask(bitmask: u64) -> Self {
291 Self(mask_impl::Mask::from_bitmask_integer(bitmask))
292 }
293
294 /// Create a bitmask vector from a mask.
295 ///
296 /// Each bit is set if the corresponding element in the mask is `true`.
297 /// The remaining bits are unset.
298 ///
299 /// The bits are packed into the first N bits of the vector:
300 /// ```
301 /// # #![feature(portable_simd)]
302 /// # #[cfg(feature = "as_crate")] use core_simd::simd;
303 /// # #[cfg(not(feature = "as_crate"))] use core::simd;
304 /// # use simd::mask32x8;
305 /// let mask = mask32x8::from_array([true, false, true, false, false, false, true, false]);
306 /// assert_eq!(mask.to_bitmask_vector()[0], 0b01000101);
307 /// ```
308 #[inline]
309 #[must_use = "method returns a new integer and does not mutate the original value"]
310 pub fn to_bitmask_vector(self) -> Simd<u8, N> {
311 self.0.to_bitmask_vector()
312 }
313
314 /// Create a mask from a bitmask vector.
315 ///
316 /// For each bit, if it is set, the corresponding element in the mask is set to `true`.
317 ///
318 /// The bits are packed into the first N bits of the vector:
319 /// ```
320 /// # #![feature(portable_simd)]
321 /// # #[cfg(feature = "as_crate")] use core_simd::simd;
322 /// # #[cfg(not(feature = "as_crate"))] use core::simd;
323 /// # use simd::{mask32x8, u8x8};
324 /// let bitmask = u8x8::from_array([0b01000101, 0, 0, 0, 0, 0, 0, 0]);
325 /// assert_eq!(
326 /// mask32x8::from_bitmask_vector(bitmask),
327 /// mask32x8::from_array([true, false, true, false, false, false, true, false]),
328 /// );
329 /// ```
330 #[inline]
331 #[must_use = "method returns a new mask and does not mutate the original value"]
332 pub fn from_bitmask_vector(bitmask: Simd<u8, N>) -> Self {
333 Self(mask_impl::Mask::from_bitmask_vector(bitmask))
334 }
335
336 /// Find the index of the first set element.
337 ///
338 /// ```
339 /// # #![feature(portable_simd)]
340 /// # #[cfg(feature = "as_crate")] use core_simd::simd;
341 /// # #[cfg(not(feature = "as_crate"))] use core::simd;
342 /// # use simd::mask32x8;
343 /// assert_eq!(mask32x8::splat(false).first_set(), None);
344 /// assert_eq!(mask32x8::splat(true).first_set(), Some(0));
345 ///
346 /// let mask = mask32x8::from_array([false, true, false, false, true, false, false, true]);
347 /// assert_eq!(mask.first_set(), Some(1));
348 /// ```
349 #[inline]
350 #[must_use = "method returns the index and does not mutate the original value"]
351 pub fn first_set(self) -> Option<usize> {
352 // If bitmasks are efficient, using them is better
353 if cfg!(target_feature = "sse") && N <= 64 {
354 let tz = self.to_bitmask().trailing_zeros();
355 return if tz == 64 { None } else { Some(tz as usize) };
356 }
357
358 // To find the first set index:
359 // * create a vector 0..N
360 // * replace unset mask elements in that vector with -1
361 // * perform _unsigned_ reduce-min
362 // * check if the result is -1 or an index
363
364 let index = Simd::from_array(
365 const {
366 let mut index = [0; N];
367 let mut i = 0;
368 while i < N {
369 index[i] = i;
370 i += 1;
371 }
372 index
373 },
374 );
375
376 // Safety: the input and output are integer vectors
377 let index: Simd<T, N> = unsafe { intrinsics::simd_cast(index) };
378
379 let masked_index = self.select(index, Self::splat(true).to_int());
380
381 // Safety: the input and output are integer vectors
382 let masked_index: Simd<T::Unsigned, N> = unsafe { intrinsics::simd_cast(masked_index) };
383
384 // Safety: the input is an integer vector
385 let min_index: T::Unsigned = unsafe { intrinsics::simd_reduce_min(masked_index) };
386
387 // Safety: the return value is the unsigned version of T
388 let min_index: T = unsafe { core::mem::transmute_copy(&min_index) };
389
390 if min_index.eq(T::TRUE) {
391 None
392 } else {
393 Some(min_index.as_usize())
394 }
395 }
396}
397
398// vector/array conversion
399impl<T, const N: usize> From<[bool; N]> for Mask<T, N>
400where
401 T: MaskElement,
402 LaneCount<N>: SupportedLaneCount,
403{
404 #[inline]
405 fn from(array: [bool; N]) -> Self {
406 Self::from_array(array)
407 }
408}
409
410impl<T, const N: usize> From<Mask<T, N>> for [bool; N]
411where
412 T: MaskElement,
413 LaneCount<N>: SupportedLaneCount,
414{
415 #[inline]
416 fn from(vector: Mask<T, N>) -> Self {
417 vector.to_array()
418 }
419}
420
421impl<T, const N: usize> Default for Mask<T, N>
422where
423 T: MaskElement,
424 LaneCount<N>: SupportedLaneCount,
425{
426 #[inline]
427 #[must_use = "method returns a defaulted mask with all elements set to false (0)"]
428 fn default() -> Self {
429 Self::splat(false)
430 }
431}
432
433impl<T, const N: usize> PartialEq for Mask<T, N>
434where
435 T: MaskElement + PartialEq,
436 LaneCount<N>: SupportedLaneCount,
437{
438 #[inline]
439 #[must_use = "method returns a new bool and does not mutate the original value"]
440 fn eq(&self, other: &Self) -> bool {
441 self.0 == other.0
442 }
443}
444
445impl<T, const N: usize> PartialOrd for Mask<T, N>
446where
447 T: MaskElement + PartialOrd,
448 LaneCount<N>: SupportedLaneCount,
449{
450 #[inline]
451 #[must_use = "method returns a new Ordering and does not mutate the original value"]
452 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
453 self.0.partial_cmp(&other.0)
454 }
455}
456
457impl<T, const N: usize> fmt::Debug for Mask<T, N>
458where
459 T: MaskElement + fmt::Debug,
460 LaneCount<N>: SupportedLaneCount,
461{
462 #[inline]
463 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
464 f&mut DebugList<'_, '_>.debug_list()
465 .entries((0..N).map(|i: usize| self.test(index:i)))
466 .finish()
467 }
468}
469
470impl<T, const N: usize> core::ops::BitAnd for Mask<T, N>
471where
472 T: MaskElement,
473 LaneCount<N>: SupportedLaneCount,
474{
475 type Output = Self;
476 #[inline]
477 #[must_use = "method returns a new mask and does not mutate the original value"]
478 fn bitand(self, rhs: Self) -> Self {
479 Self(self.0 & rhs.0)
480 }
481}
482
483impl<T, const N: usize> core::ops::BitAnd<bool> for Mask<T, N>
484where
485 T: MaskElement,
486 LaneCount<N>: SupportedLaneCount,
487{
488 type Output = Self;
489 #[inline]
490 #[must_use = "method returns a new mask and does not mutate the original value"]
491 fn bitand(self, rhs: bool) -> Self {
492 self & Self::splat(rhs)
493 }
494}
495
496impl<T, const N: usize> core::ops::BitAnd<Mask<T, N>> for bool
497where
498 T: MaskElement,
499 LaneCount<N>: SupportedLaneCount,
500{
501 type Output = Mask<T, N>;
502 #[inline]
503 #[must_use = "method returns a new mask and does not mutate the original value"]
504 fn bitand(self, rhs: Mask<T, N>) -> Mask<T, N> {
505 Mask::splat(self) & rhs
506 }
507}
508
509impl<T, const N: usize> core::ops::BitOr for Mask<T, N>
510where
511 T: MaskElement,
512 LaneCount<N>: SupportedLaneCount,
513{
514 type Output = Self;
515 #[inline]
516 #[must_use = "method returns a new mask and does not mutate the original value"]
517 fn bitor(self, rhs: Self) -> Self {
518 Self(self.0 | rhs.0)
519 }
520}
521
522impl<T, const N: usize> core::ops::BitOr<bool> for Mask<T, N>
523where
524 T: MaskElement,
525 LaneCount<N>: SupportedLaneCount,
526{
527 type Output = Self;
528 #[inline]
529 #[must_use = "method returns a new mask and does not mutate the original value"]
530 fn bitor(self, rhs: bool) -> Self {
531 self | Self::splat(rhs)
532 }
533}
534
535impl<T, const N: usize> core::ops::BitOr<Mask<T, N>> for bool
536where
537 T: MaskElement,
538 LaneCount<N>: SupportedLaneCount,
539{
540 type Output = Mask<T, N>;
541 #[inline]
542 #[must_use = "method returns a new mask and does not mutate the original value"]
543 fn bitor(self, rhs: Mask<T, N>) -> Mask<T, N> {
544 Mask::splat(self) | rhs
545 }
546}
547
548impl<T, const N: usize> core::ops::BitXor for Mask<T, N>
549where
550 T: MaskElement,
551 LaneCount<N>: SupportedLaneCount,
552{
553 type Output = Self;
554 #[inline]
555 #[must_use = "method returns a new mask and does not mutate the original value"]
556 fn bitxor(self, rhs: Self) -> Self::Output {
557 Self(self.0 ^ rhs.0)
558 }
559}
560
561impl<T, const N: usize> core::ops::BitXor<bool> for Mask<T, N>
562where
563 T: MaskElement,
564 LaneCount<N>: SupportedLaneCount,
565{
566 type Output = Self;
567 #[inline]
568 #[must_use = "method returns a new mask and does not mutate the original value"]
569 fn bitxor(self, rhs: bool) -> Self::Output {
570 self ^ Self::splat(rhs)
571 }
572}
573
574impl<T, const N: usize> core::ops::BitXor<Mask<T, N>> for bool
575where
576 T: MaskElement,
577 LaneCount<N>: SupportedLaneCount,
578{
579 type Output = Mask<T, N>;
580 #[inline]
581 #[must_use = "method returns a new mask and does not mutate the original value"]
582 fn bitxor(self, rhs: Mask<T, N>) -> Self::Output {
583 Mask::splat(self) ^ rhs
584 }
585}
586
587impl<T, const N: usize> core::ops::Not for Mask<T, N>
588where
589 T: MaskElement,
590 LaneCount<N>: SupportedLaneCount,
591{
592 type Output = Mask<T, N>;
593 #[inline]
594 #[must_use = "method returns a new mask and does not mutate the original value"]
595 fn not(self) -> Self::Output {
596 Self(!self.0)
597 }
598}
599
600impl<T, const N: usize> core::ops::BitAndAssign for Mask<T, N>
601where
602 T: MaskElement,
603 LaneCount<N>: SupportedLaneCount,
604{
605 #[inline]
606 fn bitand_assign(&mut self, rhs: Self) {
607 self.0 = self.0 & rhs.0;
608 }
609}
610
611impl<T, const N: usize> core::ops::BitAndAssign<bool> for Mask<T, N>
612where
613 T: MaskElement,
614 LaneCount<N>: SupportedLaneCount,
615{
616 #[inline]
617 fn bitand_assign(&mut self, rhs: bool) {
618 *self &= Self::splat(rhs);
619 }
620}
621
622impl<T, const N: usize> core::ops::BitOrAssign for Mask<T, N>
623where
624 T: MaskElement,
625 LaneCount<N>: SupportedLaneCount,
626{
627 #[inline]
628 fn bitor_assign(&mut self, rhs: Self) {
629 self.0 = self.0 | rhs.0;
630 }
631}
632
633impl<T, const N: usize> core::ops::BitOrAssign<bool> for Mask<T, N>
634where
635 T: MaskElement,
636 LaneCount<N>: SupportedLaneCount,
637{
638 #[inline]
639 fn bitor_assign(&mut self, rhs: bool) {
640 *self |= Self::splat(rhs);
641 }
642}
643
644impl<T, const N: usize> core::ops::BitXorAssign for Mask<T, N>
645where
646 T: MaskElement,
647 LaneCount<N>: SupportedLaneCount,
648{
649 #[inline]
650 fn bitxor_assign(&mut self, rhs: Self) {
651 self.0 = self.0 ^ rhs.0;
652 }
653}
654
655impl<T, const N: usize> core::ops::BitXorAssign<bool> for Mask<T, N>
656where
657 T: MaskElement,
658 LaneCount<N>: SupportedLaneCount,
659{
660 #[inline]
661 fn bitxor_assign(&mut self, rhs: bool) {
662 *self ^= Self::splat(rhs);
663 }
664}
665
666macro_rules! impl_from {
667 { $from:ty => $($to:ty),* } => {
668 $(
669 impl<const N: usize> From<Mask<$from, N>> for Mask<$to, N>
670 where
671 LaneCount<N>: SupportedLaneCount,
672 {
673 #[inline]
674 fn from(value: Mask<$from, N>) -> Self {
675 value.cast()
676 }
677 }
678 )*
679 }
680}
681impl_from! { i8 => i16, i32, i64, isize }
682impl_from! { i16 => i32, i64, isize, i8 }
683impl_from! { i32 => i64, isize, i8, i16 }
684impl_from! { i64 => isize, i8, i16, i32 }
685impl_from! { isize => i8, i16, i32, i64 }
686