//! Masks that take up full SIMD vector registers.

use crate::simd::intrinsics;
use crate::simd::{LaneCount, MaskElement, Simd, SupportedLaneCount};

#[repr(transparent)]
pub struct Mask<T, const N: usize>(Simd<T, N>)
where
    T: MaskElement,
    LaneCount<N>: SupportedLaneCount;
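
// Each lane of the wrapped vector holds either `T::TRUE` (all bits set, i.e. -1)
// or `T::FALSE` (0), so the mask occupies a full SIMD register. Illustrative
// sketch of that invariant (assumes a supported 4-lane `i32` configuration):
//
//     Mask::<i32, 4>::splat(true).to_int() == Simd::splat(-1)
//     Mask::<i32, 4>::splat(false).to_int() == Simd::splat(0)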

impl<T, const N: usize> Copy for Mask<T, N>
where
    T: MaskElement,
    LaneCount<N>: SupportedLaneCount,
{
}

impl<T, const N: usize> Clone for Mask<T, N>
where
    T: MaskElement,
    LaneCount<N>: SupportedLaneCount,
{
    #[inline]
    #[must_use = "method returns a new mask and does not mutate the original value"]
    fn clone(&self) -> Self {
        *self
    }
}

impl<T, const N: usize> PartialEq for Mask<T, N>
where
    T: MaskElement + PartialEq,
    LaneCount<N>: SupportedLaneCount,
{
    #[inline]
    fn eq(&self, other: &Self) -> bool {
        self.0.eq(&other.0)
    }
}

impl<T, const N: usize> PartialOrd for Mask<T, N>
where
    T: MaskElement + PartialOrd,
    LaneCount<N>: SupportedLaneCount,
{
    #[inline]
    fn partial_cmp(&self, other: &Self) -> Option<core::cmp::Ordering> {
        self.0.partial_cmp(&other.0)
    }
}

impl<T, const N: usize> Eq for Mask<T, N>
where
    T: MaskElement + Eq,
    LaneCount<N>: SupportedLaneCount,
{
}

impl<T, const N: usize> Ord for Mask<T, N>
where
    T: MaskElement + Ord,
    LaneCount<N>: SupportedLaneCount,
{
    #[inline]
    fn cmp(&self, other: &Self) -> core::cmp::Ordering {
        self.0.cmp(&other.0)
    }
}

// Used for bitmask bit order workaround
pub(crate) trait ReverseBits {
    // Reverse the least significant `n` bits of `self`.
    // (Remaining bits must be 0.)
    fn reverse_bits(self, n: usize) -> Self;
}

macro_rules! impl_reverse_bits {
    { $($int:ty),* } => {
        $(
            impl ReverseBits for $int {
                #[inline(always)]
                fn reverse_bits(self, n: usize) -> Self {
                    let rev = <$int>::reverse_bits(self);
                    let bitsize = core::mem::size_of::<$int>() * 8;
                    if n < bitsize {
                        // Shift things back to the right
                        rev >> (bitsize - n)
                    } else {
                        rev
                    }
                }
            }
        )*
    }
}

impl_reverse_bits! { u8, u16, u32, u64 }
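
// Worked example of the workaround trait above (illustrative): reversing only
// the low 4 bits of a `u8` first reverses all 8 bits, then shifts the result
// back down by `8 - 4`:
//
//     ReverseBits::reverse_bits(0b0000_0101u8, 4) == 0b0000_1010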

impl<T, const N: usize> Mask<T, N>
where
    T: MaskElement,
    LaneCount<N>: SupportedLaneCount,
{
    #[inline]
    #[must_use = "method returns a new mask and does not mutate the original value"]
    pub fn splat(value: bool) -> Self {
        Self(Simd::splat(if value { T::TRUE } else { T::FALSE }))
    }

    #[inline]
    #[must_use = "method returns a new bool and does not mutate the original value"]
    pub unsafe fn test_unchecked(&self, lane: usize) -> bool {
        T::eq(self.0[lane], T::TRUE)
    }

    #[inline]
    pub unsafe fn set_unchecked(&mut self, lane: usize, value: bool) {
        self.0[lane] = if value { T::TRUE } else { T::FALSE }
    }

    #[inline]
    #[must_use = "method returns a new vector and does not mutate the original value"]
    pub fn to_int(self) -> Simd<T, N> {
        self.0
    }

    #[inline]
    #[must_use = "method returns a new mask and does not mutate the original value"]
    pub unsafe fn from_int_unchecked(value: Simd<T, N>) -> Self {
        Self(value)
    }

    #[inline]
    #[must_use = "method returns a new mask and does not mutate the original value"]
    pub fn convert<U>(self) -> Mask<U, N>
    where
        U: MaskElement,
    {
        // Safety: masks are simply integer vectors of 0 and -1, and we can cast the element type.
        unsafe { Mask(intrinsics::simd_cast(self.0)) }
    }
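
    // Element-type conversion keeps the lane values (0 or -1) and only changes
    // their width, so truth values are preserved. Illustrative sketch (assumes
    // supported 4-lane `i8` and `i64` mask elements):
    //
    //     Mask::<i8, 4>::splat(true).convert::<i64>() == Mask::<i64, 4>::splat(true)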

    #[inline]
    #[must_use = "method returns a new vector and does not mutate the original value"]
    pub fn to_bitmask_vector(self) -> Simd<u8, N> {
        let mut bitmask = Simd::splat(0);

        // Safety: Bytes is the right size array
        unsafe {
            // Compute the bitmask
            let mut bytes: <LaneCount<N> as SupportedLaneCount>::BitMask =
                intrinsics::simd_bitmask(self.0);

            // LLVM assumes bit order should match endianness
            if cfg!(target_endian = "big") {
                for x in bytes.as_mut() {
                    *x = x.reverse_bits()
                }
            }

            bitmask.as_mut_array()[..bytes.as_ref().len()].copy_from_slice(bytes.as_ref());
        }

        bitmask
    }

    #[inline]
    #[must_use = "method returns a new mask and does not mutate the original value"]
    pub fn from_bitmask_vector(bitmask: Simd<u8, N>) -> Self {
        let mut bytes = <LaneCount<N> as SupportedLaneCount>::BitMask::default();

        // Safety: Bytes is the right size array
        unsafe {
            let len = bytes.as_ref().len();
            bytes.as_mut().copy_from_slice(&bitmask.as_array()[..len]);

            // LLVM assumes bit order should match endianness
            if cfg!(target_endian = "big") {
                for x in bytes.as_mut() {
                    *x = x.reverse_bits();
                }
            }

            // Compute the regular mask
            Self::from_int_unchecked(intrinsics::simd_select_bitmask(
                bytes,
                Self::splat(true).to_int(),
                Self::splat(false).to_int(),
            ))
        }
    }
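
    // Illustrative sketch of the byte layout used above: lane 0 ends up in the
    // least significant bit of byte 0 once the big-endian workaround has been
    // applied, and the two methods round-trip (assumes a 4-lane `i32` mask):
    //
    //     let m = Mask::<i32, 4>::splat(true);
    //     let bytes = m.to_bitmask_vector();
    //     bytes[0] == 0b1111 && Mask::from_bitmask_vector(bytes) == m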

    #[inline]
    unsafe fn to_bitmask_impl<U: ReverseBits, const M: usize>(self) -> U
    where
        LaneCount<M>: SupportedLaneCount,
    {
        let resized = self.to_int().resize::<M>(T::FALSE);

        // Safety: `resized` is an integer vector with length M, which must match T
        let bitmask: U = unsafe { intrinsics::simd_bitmask(resized) };

        // LLVM assumes bit order should match endianness
        if cfg!(target_endian = "big") {
            bitmask.reverse_bits(M)
        } else {
            bitmask
        }
    }

    #[inline]
    unsafe fn from_bitmask_impl<U: ReverseBits, const M: usize>(bitmask: U) -> Self
    where
        LaneCount<M>: SupportedLaneCount,
    {
        // LLVM assumes bit order should match endianness
        let bitmask = if cfg!(target_endian = "big") {
            bitmask.reverse_bits(M)
        } else {
            bitmask
        };

        // SAFETY: `bitmask` is the correct bitmask type for an `M`-lane mask
        let mask: Simd<T, M> = unsafe {
            intrinsics::simd_select_bitmask(
                bitmask,
                Simd::<T, M>::splat(T::TRUE),
                Simd::<T, M>::splat(T::FALSE),
            )
        };

        // SAFETY: `mask` only contains `T::TRUE` or `T::FALSE`
        unsafe { Self::from_int_unchecked(mask.resize::<N>(T::FALSE)) }
    }

    #[inline]
    pub(crate) fn to_bitmask_integer(self) -> u64 {
        // TODO modify simd_bitmask to zero-extend output, making this unnecessary
        if N <= 8 {
            // Safety: bitmask matches length
            unsafe { self.to_bitmask_impl::<u8, 8>() as u64 }
        } else if N <= 16 {
            // Safety: bitmask matches length
            unsafe { self.to_bitmask_impl::<u16, 16>() as u64 }
        } else if N <= 32 {
            // Safety: bitmask matches length
            unsafe { self.to_bitmask_impl::<u32, 32>() as u64 }
        } else {
            // Safety: bitmask matches length
            unsafe { self.to_bitmask_impl::<u64, 64>() }
        }
    }
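
    // Lane counts that are not exactly 8, 16, 32, or 64 are padded with
    // `T::FALSE` lanes before the bits are extracted, so the unused high bits
    // of the returned integer are zero. Illustrative sketch (assumes a 3-lane
    // `i8` mask is supported):
    //
    //     Mask::<i8, 3>::splat(true).to_bitmask_integer() == 0b111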

    #[inline]
    pub(crate) fn from_bitmask_integer(bitmask: u64) -> Self {
        // TODO modify simd_select_bitmask to truncate input, making this unnecessary
        if N <= 8 {
            // Safety: bitmask matches length
            unsafe { Self::from_bitmask_impl::<u8, 8>(bitmask as u8) }
        } else if N <= 16 {
            // Safety: bitmask matches length
            unsafe { Self::from_bitmask_impl::<u16, 16>(bitmask as u16) }
        } else if N <= 32 {
            // Safety: bitmask matches length
            unsafe { Self::from_bitmask_impl::<u32, 32>(bitmask as u32) }
        } else {
            // Safety: bitmask matches length
            unsafe { Self::from_bitmask_impl::<u64, 64>(bitmask) }
        }
    }
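
    // The inverse of `to_bitmask_integer`: the padded mask is resized back down
    // to `N` lanes, so bits at positions `N` and above are ignored. Illustrative
    // sketch (assumes a 3-lane `i8` mask is supported):
    //
    //     Mask::<i8, 3>::from_bitmask_integer(0b101) == Mask::from_bitmask_integer(0b1101)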

    #[inline]
    #[must_use = "method returns a new bool and does not mutate the original value"]
    pub fn any(self) -> bool {
        // Safety: use `self` as an integer vector
        unsafe { intrinsics::simd_reduce_any(self.to_int()) }
    }

    #[inline]
    #[must_use = "method returns a new bool and does not mutate the original value"]
    pub fn all(self) -> bool {
        // Safety: use `self` as an integer vector
        unsafe { intrinsics::simd_reduce_all(self.to_int()) }
    }
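
    // Reduction sketch (illustrative): `any` is true if at least one lane is
    // set, `all` only if every lane is set.
    //
    //     Mask::<i32, 4>::splat(false).any() == false
    //     Mask::<i32, 4>::splat(true).all() == true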
}

impl<T, const N: usize> From<Mask<T, N>> for Simd<T, N>
where
    T: MaskElement,
    LaneCount<N>: SupportedLaneCount,
{
    #[inline]
    fn from(value: Mask<T, N>) -> Self {
        value.0
    }
}

impl<T, const N: usize> core::ops::BitAnd for Mask<T, N>
where
    T: MaskElement,
    LaneCount<N>: SupportedLaneCount,
{
    type Output = Self;
    #[inline]
    #[must_use = "method returns a new mask and does not mutate the original value"]
    fn bitand(self, rhs: Self) -> Self {
        // Safety: `self` is an integer vector
        unsafe { Self(intrinsics::simd_and(self.0, rhs.0)) }
    }
}

impl<T, const N: usize> core::ops::BitOr for Mask<T, N>
where
    T: MaskElement,
    LaneCount<N>: SupportedLaneCount,
{
    type Output = Self;
    #[inline]
    #[must_use = "method returns a new mask and does not mutate the original value"]
    fn bitor(self, rhs: Self) -> Self {
        // Safety: `self` is an integer vector
        unsafe { Self(intrinsics::simd_or(self.0, rhs.0)) }
    }
}

impl<T, const N: usize> core::ops::BitXor for Mask<T, N>
where
    T: MaskElement,
    LaneCount<N>: SupportedLaneCount,
{
    type Output = Self;
    #[inline]
    #[must_use = "method returns a new mask and does not mutate the original value"]
    fn bitxor(self, rhs: Self) -> Self {
        // Safety: `self` is an integer vector
        unsafe { Self(intrinsics::simd_xor(self.0, rhs.0)) }
    }
}

impl<T, const N: usize> core::ops::Not for Mask<T, N>
where
    T: MaskElement,
    LaneCount<N>: SupportedLaneCount,
{
    type Output = Self;
    #[inline]
    #[must_use = "method returns a new mask and does not mutate the original value"]
    fn not(self) -> Self::Output {
        Self::splat(true) ^ self
    }
}
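
// `!mask` is implemented as XOR against an all-true mask: flipping every bit of
// a lane turns -1 (true) into 0 (false) and vice versa. Illustrative sketch:
//
//     !Mask::<i32, 4>::splat(true) == Mask::splat(false)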