//! Masks that take up full SIMD vector registers.

use crate::simd::{LaneCount, MaskElement, Simd, SupportedLaneCount};

/// A mask where each lane is represented by a single vector element of the same
/// width as the corresponding vector element (`T::TRUE` or `T::FALSE`).
#[repr(transparent)]
pub struct Mask<T, const N: usize>(Simd<T, N>)
where
    T: MaskElement,
    LaneCount<N>: SupportedLaneCount;

impl<T, const N: usize> Copy for Mask<T, N>
where
    T: MaskElement,
    LaneCount<N>: SupportedLaneCount,
{
}

impl<T, const N: usize> Clone for Mask<T, N>
where
    T: MaskElement,
    LaneCount<N>: SupportedLaneCount,
{
    #[inline]
    #[must_use = "method returns a new mask and does not mutate the original value"]
    fn clone(&self) -> Self {
        *self
    }
}

impl<T, const N: usize> PartialEq for Mask<T, N>
where
    T: MaskElement + PartialEq,
    LaneCount<N>: SupportedLaneCount,
{
    #[inline]
    fn eq(&self, other: &Self) -> bool {
        self.0.eq(&other.0)
    }
}

impl<T, const N: usize> PartialOrd for Mask<T, N>
where
    T: MaskElement + PartialOrd,
    LaneCount<N>: SupportedLaneCount,
{
    #[inline]
    fn partial_cmp(&self, other: &Self) -> Option<core::cmp::Ordering> {
        self.0.partial_cmp(&other.0)
    }
}

impl<T, const N: usize> Eq for Mask<T, N>
where
    T: MaskElement + Eq,
    LaneCount<N>: SupportedLaneCount,
{
}

impl<T, const N: usize> Ord for Mask<T, N>
where
    T: MaskElement + Ord,
    LaneCount<N>: SupportedLaneCount,
{
    #[inline]
    fn cmp(&self, other: &Self) -> core::cmp::Ordering {
        self.0.cmp(&other.0)
    }
}

// Used for bitmask bit order workaround
pub(crate) trait ReverseBits {
    // Reverse the least significant `n` bits of `self`.
    // (Remaining bits must be 0.)
    fn reverse_bits(self, n: usize) -> Self;
}

macro_rules! impl_reverse_bits {
    { $($int:ty),* } => {
        $(
            impl ReverseBits for $int {
                #[inline(always)]
                fn reverse_bits(self, n: usize) -> Self {
                    let rev = <$int>::reverse_bits(self);
                    let bitsize = core::mem::size_of::<$int>() * 8;
                    if n < bitsize {
                        // Shift things back to the right
                        rev >> (bitsize - n)
                    } else {
                        rev
                    }
                }
            }
        )*
    }
}

impl_reverse_bits! { u8, u16, u32, u64 }
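
// Worked example (a sketch, using `u8` for concreteness): for a 4-lane mask,
// lane 0 lives in the least significant bit, so the bitmask `0b0000_1011` has
// lanes 0, 1, and 3 set. `reverse_bits(4)` first reverses all 8 bits
// (`0b0000_1011` -> `0b1101_0000`), then shifts right by `8 - 4 = 4` bits,
// giving `0b0000_1101`: the low 4 bits reversed in place, with the high bits
// still zero as the trait requires.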

impl<T, const N: usize> Mask<T, N>
where
    T: MaskElement,
    LaneCount<N>: SupportedLaneCount,
{
    /// Constructs a mask with `value` in all lanes.
    #[inline]
    #[must_use = "method returns a new mask and does not mutate the original value"]
    pub fn splat(value: bool) -> Self {
        Self(Simd::splat(if value { T::TRUE } else { T::FALSE }))
    }

    /// Returns the value of the lane at index `lane`.
    ///
    /// # Safety
    /// `lane` must be less than `N`.
    #[inline]
    #[must_use = "method returns a new bool and does not mutate the original value"]
    pub unsafe fn test_unchecked(&self, lane: usize) -> bool {
        T::eq(self.0[lane], T::TRUE)
    }

    /// Sets the value of the lane at index `lane`.
    ///
    /// # Safety
    /// `lane` must be less than `N`.
    #[inline]
    pub unsafe fn set_unchecked(&mut self, lane: usize, value: bool) {
        self.0[lane] = if value { T::TRUE } else { T::FALSE }
    }

    /// Returns the underlying integer vector, with `T::TRUE` or `T::FALSE` in each lane.
    #[inline]
    #[must_use = "method returns a new vector and does not mutate the original value"]
    pub fn to_int(self) -> Simd<T, N> {
        self.0
    }

    /// Creates a mask from an integer vector.
    ///
    /// # Safety
    /// Every element of `value` must be `T::TRUE` or `T::FALSE`.
    #[inline]
    #[must_use = "method returns a new mask and does not mutate the original value"]
    pub unsafe fn from_int_unchecked(value: Simd<T, N>) -> Self {
        Self(value)
    }

    /// Converts the mask to a mask with a different element type.
    #[inline]
    #[must_use = "method returns a new mask and does not mutate the original value"]
    pub fn convert<U>(self) -> Mask<U, N>
    where
        U: MaskElement,
    {
        // Safety: masks are simply integer vectors of 0 and -1, and we can cast the element type.
        unsafe { Mask(core::intrinsics::simd::simd_cast(self.0)) }
    }
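
    // For example (an illustrative sketch): `convert::<i32>` turns a
    // `Mask<i8, 4>` into a `Mask<i32, 4>` by sign-extending each lane, so
    // `0` stays `0` and `-1` stays `-1`.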

    /// Converts the mask to a vector of bytes containing its bitmask, with the
    /// bit for lane 0 in the least significant bit of the first byte.
    #[inline]
    #[must_use = "method returns a new vector and does not mutate the original value"]
    pub fn to_bitmask_vector(self) -> Simd<u8, N> {
        let mut bitmask = Simd::splat(0);

        // Safety: `BitMask` is an appropriately sized byte array for an `N`-lane bitmask.
        unsafe {
            // Compute the bitmask
            let mut bytes: <LaneCount<N> as SupportedLaneCount>::BitMask =
                core::intrinsics::simd::simd_bitmask(self.0);

            // LLVM assumes bit order should match endianness
            if cfg!(target_endian = "big") {
                for x in bytes.as_mut() {
                    *x = x.reverse_bits();
                }
                if N % 8 > 0 {
                    bytes.as_mut()[N / 8] >>= 8 - N % 8;
                }
            }

            bitmask.as_mut_array()[..bytes.as_ref().len()].copy_from_slice(bytes.as_ref());
        }

        bitmask
    }
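
    // Layout note (an illustrative sketch): for a `Mask::<i8, 4>` with lanes
    // [true, false, true, true], the returned vector starts with the byte
    // `0b0000_1101` (lanes 0, 2, and 3 set) and all remaining bytes are 0.
    // The big-endian branch above exists so this layout is the same on every
    // target.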

    /// Converts a vector of bytes to a mask, reading the bit for lane 0 from
    /// the least significant bit of the first byte.
    #[inline]
    #[must_use = "method returns a new mask and does not mutate the original value"]
    pub fn from_bitmask_vector(bitmask: Simd<u8, N>) -> Self {
        let mut bytes = <LaneCount<N> as SupportedLaneCount>::BitMask::default();

        // Safety: `BitMask` is an appropriately sized byte array for an `N`-lane bitmask.
        unsafe {
            let len = bytes.as_ref().len();
            bytes.as_mut().copy_from_slice(&bitmask.as_array()[..len]);

            // LLVM assumes bit order should match endianness
            if cfg!(target_endian = "big") {
                for x in bytes.as_mut() {
                    *x = x.reverse_bits();
                }
                if N % 8 > 0 {
                    bytes.as_mut()[N / 8] >>= 8 - N % 8;
                }
            }

            // Compute the regular mask
            Self::from_int_unchecked(core::intrinsics::simd::simd_select_bitmask(
                bytes,
                Self::splat(true).to_int(),
                Self::splat(false).to_int(),
            ))
        }
    }

    #[inline]
    unsafe fn to_bitmask_impl<U: ReverseBits, const M: usize>(self) -> U
    where
        LaneCount<M>: SupportedLaneCount,
    {
        // Pad the mask out to `M` lanes so the bitmask fills `U` exactly.
        let resized = self.to_int().resize::<M>(T::FALSE);

        // Safety: `resized` is an integer vector with length M, which must match the width of `U`
        let bitmask: U = unsafe { core::intrinsics::simd::simd_bitmask(resized) };

        // LLVM assumes bit order should match endianness
        if cfg!(target_endian = "big") {
            bitmask.reverse_bits(M)
        } else {
            bitmask
        }
    }

    #[inline]
    unsafe fn from_bitmask_impl<U: ReverseBits, const M: usize>(bitmask: U) -> Self
    where
        LaneCount<M>: SupportedLaneCount,
    {
        // LLVM assumes bit order should match endianness
        let bitmask = if cfg!(target_endian = "big") {
            bitmask.reverse_bits(M)
        } else {
            bitmask
        };

        // SAFETY: `bitmask` is the correct bitmask type for a vector of length M
        let mask: Simd<T, M> = unsafe {
            core::intrinsics::simd::simd_select_bitmask(
                bitmask,
                Simd::<T, M>::splat(T::TRUE),
                Simd::<T, M>::splat(T::FALSE),
            )
        };

        // SAFETY: `mask` only contains `T::TRUE` or `T::FALSE`
        unsafe { Self::from_int_unchecked(mask.resize::<N>(T::FALSE)) }
    }

    #[inline]
    pub(crate) fn to_bitmask_integer(self) -> u64 {
        // TODO modify simd_bitmask to zero-extend output, making this unnecessary
        if N <= 8 {
            // Safety: bitmask matches length
            unsafe { self.to_bitmask_impl::<u8, 8>() as u64 }
        } else if N <= 16 {
            // Safety: bitmask matches length
            unsafe { self.to_bitmask_impl::<u16, 16>() as u64 }
        } else if N <= 32 {
            // Safety: bitmask matches length
            unsafe { self.to_bitmask_impl::<u32, 32>() as u64 }
        } else {
            // Safety: bitmask matches length
            unsafe { self.to_bitmask_impl::<u64, 64>() }
        }
    }
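
    // For example (an illustrative sketch): a `Mask<i32, 4>` takes the `N <= 8`
    // branch, is padded to 8 lanes with `T::FALSE`, and yields a `u8` whose low
    // 4 bits hold lanes 0..4; the `as u64` cast then zero-extends the result.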

    #[inline]
    pub(crate) fn from_bitmask_integer(bitmask: u64) -> Self {
        // TODO modify simd_bitmask_select to truncate input, making this unnecessary
        if N <= 8 {
            // Safety: bitmask matches length
            unsafe { Self::from_bitmask_impl::<u8, 8>(bitmask as u8) }
        } else if N <= 16 {
            // Safety: bitmask matches length
            unsafe { Self::from_bitmask_impl::<u16, 16>(bitmask as u16) }
        } else if N <= 32 {
            // Safety: bitmask matches length
            unsafe { Self::from_bitmask_impl::<u32, 32>(bitmask as u32) }
        } else {
            // Safety: bitmask matches length
            unsafe { Self::from_bitmask_impl::<u64, 64>(bitmask) }
        }
    }

    /// Returns true if any lane is set.
    #[inline]
    #[must_use = "method returns a new bool and does not mutate the original value"]
    pub fn any(self) -> bool {
        // Safety: use `self` as an integer vector
        unsafe { core::intrinsics::simd::simd_reduce_any(self.to_int()) }
    }

    /// Returns true if all lanes are set.
    #[inline]
    #[must_use = "method returns a new bool and does not mutate the original value"]
    pub fn all(self) -> bool {
        // Safety: use `self` as an integer vector
        unsafe { core::intrinsics::simd::simd_reduce_all(self.to_int()) }
    }
}

impl<T, const N: usize> From<Mask<T, N>> for Simd<T, N>
where
    T: MaskElement,
    LaneCount<N>: SupportedLaneCount,
{
    #[inline]
    fn from(value: Mask<T, N>) -> Self {
        value.0
    }
}

impl<T, const N: usize> core::ops::BitAnd for Mask<T, N>
where
    T: MaskElement,
    LaneCount<N>: SupportedLaneCount,
{
    type Output = Self;
    #[inline]
    #[must_use = "method returns a new mask and does not mutate the original value"]
    fn bitand(self, rhs: Self) -> Self {
        // Safety: `self` is an integer vector
        unsafe { Self(core::intrinsics::simd::simd_and(self.0, rhs.0)) }
    }
}

impl<T, const N: usize> core::ops::BitOr for Mask<T, N>
where
    T: MaskElement,
    LaneCount<N>: SupportedLaneCount,
{
    type Output = Self;
    #[inline]
    #[must_use = "method returns a new mask and does not mutate the original value"]
    fn bitor(self, rhs: Self) -> Self {
        // Safety: `self` is an integer vector
        unsafe { Self(core::intrinsics::simd::simd_or(self.0, rhs.0)) }
    }
}

impl<T, const N: usize> core::ops::BitXor for Mask<T, N>
where
    T: MaskElement,
    LaneCount<N>: SupportedLaneCount,
{
    type Output = Self;
    #[inline]
    #[must_use = "method returns a new mask and does not mutate the original value"]
    fn bitxor(self, rhs: Self) -> Self {
        // Safety: `self` is an integer vector
        unsafe { Self(core::intrinsics::simd::simd_xor(self.0, rhs.0)) }
    }
}
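
// For example (an illustrative sketch): for two 4-lane masks,
// `[true, true, false, false] & [true, false, true, false]` yields
// `[true, false, false, false]`, and `Not` below flips every lane by
// XOR-ing against an all-true mask.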

impl<T, const N: usize> core::ops::Not for Mask<T, N>
where
    T: MaskElement,
    LaneCount<N>: SupportedLaneCount,
{
    type Output = Self;
    #[inline]
    #[must_use = "method returns a new mask and does not mutate the original value"]
    fn not(self) -> Self::Output {
        Self::splat(true) ^ self
    }
}
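
// A minimal sanity-check sketch of the bitmask round-trip. This assumes the
// file is built as a crate whose unit tests run under `cargo test` (as in the
// standalone portable-simd repository); the in-tree `core` copy is tested
// externally instead.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn bitmask_round_trip() {
        let mut mask = Mask::<i8, 8>::splat(false);
        // Lane indices 0 and 3 are less than 8, so the unchecked setter is sound.
        unsafe {
            mask.set_unchecked(0, true);
            mask.set_unchecked(3, true);
        }
        // Lane 0 maps to the least significant bit of the integer bitmask.
        assert_eq!(mask.to_bitmask_integer(), 0b1001);
        // Round-tripping through the integer bitmask reproduces the same mask.
        assert!(Mask::<i8, 8>::from_bitmask_integer(0b1001) == mask);
        assert!(mask.any());
        assert!(!mask.all());
    }
}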