//! Masks that take up full SIMD vector registers.

use crate::simd::{LaneCount, MaskElement, Simd, SupportedLaneCount};
4
5#[repr(transparent)]
6pub struct Mask<T, const N: usize>(Simd<T, N>)
7where
8 T: MaskElement,
9 LaneCount<N>: SupportedLaneCount;
10
11impl<T, const N: usize> Copy for Mask<T, N>
12where
13 T: MaskElement,
14 LaneCount<N>: SupportedLaneCount,
15{
16}
17
18impl<T, const N: usize> Clone for Mask<T, N>
19where
20 T: MaskElement,
21 LaneCount<N>: SupportedLaneCount,
22{
23 #[inline]
24 #[must_use = "method returns a new mask and does not mutate the original value"]
25 fn clone(&self) -> Self {
26 *self
27 }
28}
29
30impl<T, const N: usize> PartialEq for Mask<T, N>
31where
32 T: MaskElement + PartialEq,
33 LaneCount<N>: SupportedLaneCount,
34{
35 #[inline]
36 fn eq(&self, other: &Self) -> bool {
37 self.0.eq(&other.0)
38 }
39}
40
41impl<T, const N: usize> PartialOrd for Mask<T, N>
42where
43 T: MaskElement + PartialOrd,
44 LaneCount<N>: SupportedLaneCount,
45{
46 #[inline]
47 fn partial_cmp(&self, other: &Self) -> Option<core::cmp::Ordering> {
48 self.0.partial_cmp(&other.0)
49 }
50}
51
52impl<T, const N: usize> Eq for Mask<T, N>
53where
54 T: MaskElement + Eq,
55 LaneCount<N>: SupportedLaneCount,
56{
57}
58
59impl<T, const N: usize> Ord for Mask<T, N>
60where
61 T: MaskElement + Ord,
62 LaneCount<N>: SupportedLaneCount,
63{
64 #[inline]
65 fn cmp(&self, other: &Self) -> core::cmp::Ordering {
66 self.0.cmp(&other.0)
67 }
68}
69
// Used for bitmask bit order workaround
pub(crate) trait ReverseBits {
    /// Reverse the least significant `n` bits of `self`.
    /// (Remaining bits must be 0.)
    fn reverse_bits(self, n: usize) -> Self;
}

macro_rules! impl_reverse_bits {
    { $($int:ty),* } => {
        $(
            impl ReverseBits for $int {
                #[inline(always)]
                fn reverse_bits(self, n: usize) -> Self {
                    // Reverse the full width, then shift the reversed bits
                    // back down so only the low `n` bits are populated.
                    let rev = <$int>::reverse_bits(self);
                    let bitsize = core::mem::size_of::<$int>() * 8;
                    if n < bitsize {
                        // Shift things back to the right
                        rev >> (bitsize - n)
                    } else {
                        rev
                    }
                }
            }
        )*
    }
}

impl_reverse_bits! { u8, u16, u32, u64 }
98
99impl<T, const N: usize> Mask<T, N>
100where
101 T: MaskElement,
102 LaneCount<N>: SupportedLaneCount,
103{
104 #[inline]
105 #[must_use = "method returns a new mask and does not mutate the original value"]
106 pub fn splat(value: bool) -> Self {
107 Self(Simd::splat(if value { T::TRUE } else { T::FALSE }))
108 }
109
110 #[inline]
111 #[must_use = "method returns a new bool and does not mutate the original value"]
112 pub unsafe fn test_unchecked(&self, lane: usize) -> bool {
113 T::eq(self.0[lane], T::TRUE)
114 }
115
116 #[inline]
117 pub unsafe fn set_unchecked(&mut self, lane: usize, value: bool) {
118 self.0[lane] = if value { T::TRUE } else { T::FALSE }
119 }
120
121 #[inline]
122 #[must_use = "method returns a new vector and does not mutate the original value"]
123 pub fn to_int(self) -> Simd<T, N> {
124 self.0
125 }
126
127 #[inline]
128 #[must_use = "method returns a new mask and does not mutate the original value"]
129 pub unsafe fn from_int_unchecked(value: Simd<T, N>) -> Self {
130 Self(value)
131 }
132
133 #[inline]
134 #[must_use = "method returns a new mask and does not mutate the original value"]
135 pub fn convert<U>(self) -> Mask<U, N>
136 where
137 U: MaskElement,
138 {
139 // Safety: masks are simply integer vectors of 0 and -1, and we can cast the element type.
140 unsafe { Mask(core::intrinsics::simd::simd_cast(self.0)) }
141 }
142
143 #[inline]
144 #[must_use = "method returns a new vector and does not mutate the original value"]
145 pub fn to_bitmask_vector(self) -> Simd<u8, N> {
146 let mut bitmask = Simd::splat(0);
147
148 // Safety: Bytes is the right size array
149 unsafe {
150 // Compute the bitmask
151 let mut bytes: <LaneCount<N> as SupportedLaneCount>::BitMask =
152 core::intrinsics::simd::simd_bitmask(self.0);
153
154 // LLVM assumes bit order should match endianness
155 if cfg!(target_endian = "big") {
156 for x in bytes.as_mut() {
157 *x = x.reverse_bits()
158 }
159 if N % 8 > 0 {
160 bytes.as_mut()[N / 8] >>= 8 - N % 8;
161 }
162 }
163
164 bitmask.as_mut_array()[..bytes.as_ref().len()].copy_from_slice(bytes.as_ref());
165 }
166
167 bitmask
168 }
169
170 #[inline]
171 #[must_use = "method returns a new mask and does not mutate the original value"]
172 pub fn from_bitmask_vector(bitmask: Simd<u8, N>) -> Self {
173 let mut bytes = <LaneCount<N> as SupportedLaneCount>::BitMask::default();
174
175 // Safety: Bytes is the right size array
176 unsafe {
177 let len = bytes.as_ref().len();
178 bytes.as_mut().copy_from_slice(&bitmask.as_array()[..len]);
179
180 // LLVM assumes bit order should match endianness
181 if cfg!(target_endian = "big") {
182 for x in bytes.as_mut() {
183 *x = x.reverse_bits();
184 }
185 if N % 8 > 0 {
186 bytes.as_mut()[N / 8] >>= 8 - N % 8;
187 }
188 }
189
190 // Compute the regular mask
191 Self::from_int_unchecked(core::intrinsics::simd::simd_select_bitmask(
192 bytes,
193 Self::splat(true).to_int(),
194 Self::splat(false).to_int(),
195 ))
196 }
197 }
198
199 #[inline]
200 unsafe fn to_bitmask_impl<U: ReverseBits, const M: usize>(self) -> U
201 where
202 LaneCount<M>: SupportedLaneCount,
203 {
204 let resized = self.to_int().resize::<M>(T::FALSE);
205
206 // Safety: `resized` is an integer vector with length M, which must match T
207 let bitmask: U = unsafe { core::intrinsics::simd::simd_bitmask(resized) };
208
209 // LLVM assumes bit order should match endianness
210 if cfg!(target_endian = "big") {
211 bitmask.reverse_bits(M)
212 } else {
213 bitmask
214 }
215 }
216
217 #[inline]
218 unsafe fn from_bitmask_impl<U: ReverseBits, const M: usize>(bitmask: U) -> Self
219 where
220 LaneCount<M>: SupportedLaneCount,
221 {
222 // LLVM assumes bit order should match endianness
223 let bitmask = if cfg!(target_endian = "big") {
224 bitmask.reverse_bits(M)
225 } else {
226 bitmask
227 };
228
229 // SAFETY: `mask` is the correct bitmask type for a u64 bitmask
230 let mask: Simd<T, M> = unsafe {
231 core::intrinsics::simd::simd_select_bitmask(
232 bitmask,
233 Simd::<T, M>::splat(T::TRUE),
234 Simd::<T, M>::splat(T::FALSE),
235 )
236 };
237
238 // SAFETY: `mask` only contains `T::TRUE` or `T::FALSE`
239 unsafe { Self::from_int_unchecked(mask.resize::<N>(T::FALSE)) }
240 }
241
242 #[inline]
243 pub(crate) fn to_bitmask_integer(self) -> u64 {
244 // TODO modify simd_bitmask to zero-extend output, making this unnecessary
245 if N <= 8 {
246 // Safety: bitmask matches length
247 unsafe { self.to_bitmask_impl::<u8, 8>() as u64 }
248 } else if N <= 16 {
249 // Safety: bitmask matches length
250 unsafe { self.to_bitmask_impl::<u16, 16>() as u64 }
251 } else if N <= 32 {
252 // Safety: bitmask matches length
253 unsafe { self.to_bitmask_impl::<u32, 32>() as u64 }
254 } else {
255 // Safety: bitmask matches length
256 unsafe { self.to_bitmask_impl::<u64, 64>() }
257 }
258 }
259
260 #[inline]
261 pub(crate) fn from_bitmask_integer(bitmask: u64) -> Self {
262 // TODO modify simd_bitmask_select to truncate input, making this unnecessary
263 if N <= 8 {
264 // Safety: bitmask matches length
265 unsafe { Self::from_bitmask_impl::<u8, 8>(bitmask as u8) }
266 } else if N <= 16 {
267 // Safety: bitmask matches length
268 unsafe { Self::from_bitmask_impl::<u16, 16>(bitmask as u16) }
269 } else if N <= 32 {
270 // Safety: bitmask matches length
271 unsafe { Self::from_bitmask_impl::<u32, 32>(bitmask as u32) }
272 } else {
273 // Safety: bitmask matches length
274 unsafe { Self::from_bitmask_impl::<u64, 64>(bitmask) }
275 }
276 }
277
278 #[inline]
279 #[must_use = "method returns a new bool and does not mutate the original value"]
280 pub fn any(self) -> bool {
281 // Safety: use `self` as an integer vector
282 unsafe { core::intrinsics::simd::simd_reduce_any(self.to_int()) }
283 }
284
285 #[inline]
286 #[must_use = "method returns a new vector and does not mutate the original value"]
287 pub fn all(self) -> bool {
288 // Safety: use `self` as an integer vector
289 unsafe { core::intrinsics::simd::simd_reduce_all(self.to_int()) }
290 }
291}
292
293impl<T, const N: usize> From<Mask<T, N>> for Simd<T, N>
294where
295 T: MaskElement,
296 LaneCount<N>: SupportedLaneCount,
297{
298 #[inline]
299 fn from(value: Mask<T, N>) -> Self {
300 value.0
301 }
302}
303
304impl<T, const N: usize> core::ops::BitAnd for Mask<T, N>
305where
306 T: MaskElement,
307 LaneCount<N>: SupportedLaneCount,
308{
309 type Output = Self;
310 #[inline]
311 #[must_use = "method returns a new mask and does not mutate the original value"]
312 fn bitand(self, rhs: Self) -> Self {
313 // Safety: `self` is an integer vector
314 unsafe { Self(core::intrinsics::simd::simd_and(self.0, y:rhs.0)) }
315 }
316}
317
318impl<T, const N: usize> core::ops::BitOr for Mask<T, N>
319where
320 T: MaskElement,
321 LaneCount<N>: SupportedLaneCount,
322{
323 type Output = Self;
324 #[inline]
325 #[must_use = "method returns a new mask and does not mutate the original value"]
326 fn bitor(self, rhs: Self) -> Self {
327 // Safety: `self` is an integer vector
328 unsafe { Self(core::intrinsics::simd::simd_or(self.0, y:rhs.0)) }
329 }
330}
331
332impl<T, const N: usize> core::ops::BitXor for Mask<T, N>
333where
334 T: MaskElement,
335 LaneCount<N>: SupportedLaneCount,
336{
337 type Output = Self;
338 #[inline]
339 #[must_use = "method returns a new mask and does not mutate the original value"]
340 fn bitxor(self, rhs: Self) -> Self {
341 // Safety: `self` is an integer vector
342 unsafe { Self(core::intrinsics::simd::simd_xor(self.0, y:rhs.0)) }
343 }
344}
345
346impl<T, const N: usize> core::ops::Not for Mask<T, N>
347where
348 T: MaskElement,
349 LaneCount<N>: SupportedLaneCount,
350{
351 type Output = Self;
352 #[inline]
353 #[must_use = "method returns a new mask and does not mutate the original value"]
354 fn not(self) -> Self::Output {
355 Self::splat(true) ^ self
356 }
357}
358