//! Masks that take up full SIMD vector registers.

use crate::simd::intrinsics;
use crate::simd::{LaneCount, MaskElement, Simd, SupportedLaneCount};

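/// A mask where each element is the full width of the corresponding vector
/// element: all bits set (`T::TRUE`, i.e. `-1`) for `true`, and all bits
/// clear (`T::FALSE`, i.e. `0`) for `false`.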
#[repr(transparent)]
pub struct Mask<T, const N: usize>(Simd<T, N>)
where
    T: MaskElement,
    LaneCount<N>: SupportedLaneCount;

impl<T, const N: usize> Copy for Mask<T, N>
where
    T: MaskElement,
    LaneCount<N>: SupportedLaneCount,
{
}

impl<T, const N: usize> Clone for Mask<T, N>
where
    T: MaskElement,
    LaneCount<N>: SupportedLaneCount,
{
    #[inline]
    #[must_use = "method returns a new mask and does not mutate the original value"]
    fn clone(&self) -> Self {
        *self
    }
}

impl<T, const N: usize> PartialEq for Mask<T, N>
where
    T: MaskElement + PartialEq,
    LaneCount<N>: SupportedLaneCount,
{
    #[inline]
    fn eq(&self, other: &Self) -> bool {
        self.0.eq(&other.0)
    }
}

impl<T, const N: usize> PartialOrd for Mask<T, N>
where
    T: MaskElement + PartialOrd,
    LaneCount<N>: SupportedLaneCount,
{
    #[inline]
    fn partial_cmp(&self, other: &Self) -> Option<core::cmp::Ordering> {
        self.0.partial_cmp(&other.0)
    }
}

impl<T, const N: usize> Eq for Mask<T, N>
where
    T: MaskElement + Eq,
    LaneCount<N>: SupportedLaneCount,
{
}

impl<T, const N: usize> Ord for Mask<T, N>
where
    T: MaskElement + Ord,
    LaneCount<N>: SupportedLaneCount,
{
    #[inline]
    fn cmp(&self, other: &Self) -> core::cmp::Ordering {
        self.0.cmp(&other.0)
    }
}

// Used for bitmask bit order workaround
pub(crate) trait ReverseBits {
    // Reverse the least significant `n` bits of `self`.
    // (Remaining bits must be 0.)
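    // For example, reversing the low 3 bits of `0b011u8` yields `0b110`: the
    // implementation below reverses all 8 bits to `0b1100_0000`, then shifts
    // right by `8 - 3 = 5`.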
    fn reverse_bits(self, n: usize) -> Self;
}

macro_rules! impl_reverse_bits {
    { $($int:ty),* } => {
        $(
            impl ReverseBits for $int {
                #[inline(always)]
                fn reverse_bits(self, n: usize) -> Self {
                    let rev = <$int>::reverse_bits(self);
                    let bitsize = core::mem::size_of::<$int>() * 8;
                    if n < bitsize {
                        // Shift things back to the right
                        rev >> (bitsize - n)
                    } else {
                        rev
                    }
                }
            }
        )*
    }
}

impl_reverse_bits! { u8, u16, u32, u64 }

impl<T, const N: usize> Mask<T, N>
where
    T: MaskElement,
    LaneCount<N>: SupportedLaneCount,
{
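    /// Constructs a mask with all elements set to the given value.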
    #[inline]
    #[must_use = "method returns a new mask and does not mutate the original value"]
    pub fn splat(value: bool) -> Self {
        Self(Simd::splat(if value { T::TRUE } else { T::FALSE }))
    }

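    /// Tests the value of the element at `lane`.
    ///
    /// # Safety
    /// `lane` must be less than the number of elements (`N`).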
    #[inline]
    #[must_use = "method returns a new bool and does not mutate the original value"]
    pub unsafe fn test_unchecked(&self, lane: usize) -> bool {
        T::eq(self.0[lane], T::TRUE)
    }

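    /// Sets the value of the element at `lane`.
    ///
    /// # Safety
    /// `lane` must be less than the number of elements (`N`).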
    #[inline]
    pub unsafe fn set_unchecked(&mut self, lane: usize, value: bool) {
        self.0[lane] = if value { T::TRUE } else { T::FALSE }
    }

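    /// Returns the underlying integer vector: `T::TRUE` (all bits set) for
    /// `true` elements and `T::FALSE` (all bits clear) for `false` elements.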
    #[inline]
    #[must_use = "method returns a new vector and does not mutate the original value"]
    pub fn to_int(self) -> Simd<T, N> {
        self.0
    }

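    /// Creates a mask from an integer vector without checking its contents.
    ///
    /// # Safety
    /// Every element of `value` must be `T::TRUE` or `T::FALSE`.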
    #[inline]
    #[must_use = "method returns a new mask and does not mutate the original value"]
    pub unsafe fn from_int_unchecked(value: Simd<T, N>) -> Self {
        Self(value)
    }

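    /// Converts the mask to an equivalent mask with a different element type.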
    #[inline]
    #[must_use = "method returns a new mask and does not mutate the original value"]
    pub fn convert<U>(self) -> Mask<U, N>
    where
        U: MaskElement,
    {
        // Safety: masks are simply integer vectors of 0 and -1, and we can cast the element type.
        unsafe { Mask(intrinsics::simd_cast(self.0)) }
    }

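    /// Packs the mask into a byte vector: bit `i` of the output corresponds
    /// to element `i`, and any bytes past the bitmask are zero.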
    #[inline]
    #[must_use = "method returns a new vector and does not mutate the original value"]
    pub fn to_bitmask_vector(self) -> Simd<u8, N> {
        let mut bitmask = Simd::splat(0);

        // Safety: Bytes is the right size array
        unsafe {
            // Compute the bitmask
            let mut bytes: <LaneCount<N> as SupportedLaneCount>::BitMask =
                intrinsics::simd_bitmask(self.0);

            // LLVM assumes bit order should match endianness
            if cfg!(target_endian = "big") {
                for x in bytes.as_mut() {
                    *x = x.reverse_bits()
                }
            }

            bitmask.as_mut_array()[..bytes.as_ref().len()].copy_from_slice(bytes.as_ref());
        }

        bitmask
    }

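    /// Unpacks a byte vector produced by [`Self::to_bitmask_vector`] back
    /// into a mask, reading one bit per element.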
    #[inline]
    #[must_use = "method returns a new mask and does not mutate the original value"]
    pub fn from_bitmask_vector(bitmask: Simd<u8, N>) -> Self {
        let mut bytes = <LaneCount<N> as SupportedLaneCount>::BitMask::default();

        // Safety: Bytes is the right size array
        unsafe {
            let len = bytes.as_ref().len();
            bytes.as_mut().copy_from_slice(&bitmask.as_array()[..len]);

            // LLVM assumes bit order should match endianness
            if cfg!(target_endian = "big") {
                for x in bytes.as_mut() {
                    *x = x.reverse_bits();
                }
            }

            // Compute the regular mask
            Self::from_int_unchecked(intrinsics::simd_select_bitmask(
                bytes,
                Self::splat(true).to_int(),
                Self::splat(false).to_int(),
            ))
        }
    }

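    /// Packs the mask into an integer bitmask, first padding the mask out to
    /// `M` elements with `false`.
    ///
    /// # Safety
    /// `U` must be an unsigned integer wide enough to hold `M` bits (the
    /// callers below pair each `U` with exactly `M` elements).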
    #[inline]
    unsafe fn to_bitmask_impl<U: ReverseBits, const M: usize>(self) -> U
    where
        LaneCount<M>: SupportedLaneCount,
    {
        let resized = self.to_int().resize::<M>(T::FALSE);

        // Safety: `resized` is an integer vector with length M, matching the bit width of `U`
        let bitmask: U = unsafe { intrinsics::simd_bitmask(resized) };

        // LLVM assumes bit order should match endianness
        if cfg!(target_endian = "big") {
            bitmask.reverse_bits(M)
        } else {
            bitmask
        }
    }

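    /// Expands an integer bitmask into a mask, resizing from `M` elements
    /// back down to `N`.
    ///
    /// # Safety
    /// `U` must be an unsigned integer wide enough to hold `M` bits.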
    #[inline]
    unsafe fn from_bitmask_impl<U: ReverseBits, const M: usize>(bitmask: U) -> Self
    where
        LaneCount<M>: SupportedLaneCount,
    {
        // LLVM assumes bit order should match endianness
        let bitmask = if cfg!(target_endian = "big") {
            bitmask.reverse_bits(M)
        } else {
            bitmask
        };

        // Safety: `bitmask` is the correct bitmask type for a vector of length M
        let mask: Simd<T, M> = unsafe {
            intrinsics::simd_select_bitmask(
                bitmask,
                Simd::<T, M>::splat(T::TRUE),
                Simd::<T, M>::splat(T::FALSE),
            )
        };

        // Safety: `mask` only contains `T::TRUE` or `T::FALSE`
        unsafe { Self::from_int_unchecked(mask.resize::<N>(T::FALSE)) }
    }

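    /// Packs the mask into a `u64`, one bit per element, by padding the mask
    /// out to the next supported bitmask width (8, 16, 32, or 64 elements).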
    #[inline]
    pub(crate) fn to_bitmask_integer(self) -> u64 {
        // TODO modify simd_bitmask to zero-extend output, making this unnecessary
        if N <= 8 {
            // Safety: bitmask matches length
            unsafe { self.to_bitmask_impl::<u8, 8>() as u64 }
        } else if N <= 16 {
            // Safety: bitmask matches length
            unsafe { self.to_bitmask_impl::<u16, 16>() as u64 }
        } else if N <= 32 {
            // Safety: bitmask matches length
            unsafe { self.to_bitmask_impl::<u32, 32>() as u64 }
        } else {
            // Safety: bitmask matches length
            unsafe { self.to_bitmask_impl::<u64, 64>() }
        }
    }

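    /// Inverse of [`Self::to_bitmask_integer`]: expands a `u64` bitmask into
    /// a mask, one bit per element.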
    #[inline]
    pub(crate) fn from_bitmask_integer(bitmask: u64) -> Self {
        // TODO modify simd_bitmask_select to truncate input, making this unnecessary
        if N <= 8 {
            // Safety: bitmask matches length
            unsafe { Self::from_bitmask_impl::<u8, 8>(bitmask as u8) }
        } else if N <= 16 {
            // Safety: bitmask matches length
            unsafe { Self::from_bitmask_impl::<u16, 16>(bitmask as u16) }
        } else if N <= 32 {
            // Safety: bitmask matches length
            unsafe { Self::from_bitmask_impl::<u32, 32>(bitmask as u32) }
        } else {
            // Safety: bitmask matches length
            unsafe { Self::from_bitmask_impl::<u64, 64>(bitmask) }
        }
    }

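    /// Returns true if any element of the mask is set.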
    #[inline]
    #[must_use = "method returns a new bool and does not mutate the original value"]
    pub fn any(self) -> bool {
        // Safety: use `self` as an integer vector
        unsafe { intrinsics::simd_reduce_any(self.to_int()) }
    }

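    /// Returns true if every element of the mask is set.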
    #[inline]
    #[must_use = "method returns a new bool and does not mutate the original value"]
    pub fn all(self) -> bool {
        // Safety: use `self` as an integer vector
        unsafe { intrinsics::simd_reduce_all(self.to_int()) }
    }
}

impl<T, const N: usize> From<Mask<T, N>> for Simd<T, N>
where
    T: MaskElement,
    LaneCount<N>: SupportedLaneCount,
{
    #[inline]
    fn from(value: Mask<T, N>) -> Self {
        value.0
    }
}

impl<T, const N: usize> core::ops::BitAnd for Mask<T, N>
where
    T: MaskElement,
    LaneCount<N>: SupportedLaneCount,
{
    type Output = Self;
    #[inline]
    #[must_use = "method returns a new mask and does not mutate the original value"]
    fn bitand(self, rhs: Self) -> Self {
        // Safety: `self` is an integer vector
        unsafe { Self(intrinsics::simd_and(self.0, rhs.0)) }
    }
}

impl<T, const N: usize> core::ops::BitOr for Mask<T, N>
where
    T: MaskElement,
    LaneCount<N>: SupportedLaneCount,
{
    type Output = Self;
    #[inline]
    #[must_use = "method returns a new mask and does not mutate the original value"]
    fn bitor(self, rhs: Self) -> Self {
        // Safety: `self` is an integer vector
        unsafe { Self(intrinsics::simd_or(self.0, rhs.0)) }
    }
}

impl<T, const N: usize> core::ops::BitXor for Mask<T, N>
where
    T: MaskElement,
    LaneCount<N>: SupportedLaneCount,
{
    type Output = Self;
    #[inline]
    #[must_use = "method returns a new mask and does not mutate the original value"]
    fn bitxor(self, rhs: Self) -> Self {
        // Safety: `self` is an integer vector
        unsafe { Self(intrinsics::simd_xor(self.0, rhs.0)) }
    }
}

impl<T, const N: usize> core::ops::Not for Mask<T, N>
where
    T: MaskElement,
    LaneCount<N>: SupportedLaneCount,
{
    type Output = Self;
    #[inline]
    #[must_use = "method returns a new mask and does not mutate the original value"]
    fn not(self) -> Self::Output {
        Self::splat(true) ^ self
    }
}