//! Masks that take up full SIMD vector registers.

use crate::simd::{LaneCount, MaskElement, Simd, SupportedLaneCount};
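
/// A mask stored as a full-width SIMD vector: each lane is all ones
/// (`T::TRUE`, i.e. `-1`) when set and all zeros (`T::FALSE`) when unset.
/// `repr(transparent)` gives it the same layout as the wrapped `Simd<T, N>`.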
#[repr(transparent)]
pub(crate) struct Mask<T, const N: usize>(Simd<T, N>)
where
    T: MaskElement,
    LaneCount<N>: SupportedLaneCount;

impl<T, const N: usize> Copy for Mask<T, N>
where
    T: MaskElement,
    LaneCount<N>: SupportedLaneCount,
{
}

impl<T, const N: usize> Clone for Mask<T, N>
where
    T: MaskElement,
    LaneCount<N>: SupportedLaneCount,
{
    #[inline]
    fn clone(&self) -> Self {
        *self
    }
}

impl<T, const N: usize> PartialEq for Mask<T, N>
where
    T: MaskElement + PartialEq,
    LaneCount<N>: SupportedLaneCount,
{
    #[inline]
    fn eq(&self, other: &Self) -> bool {
        self.0.eq(&other.0)
    }
}

impl<T, const N: usize> PartialOrd for Mask<T, N>
where
    T: MaskElement + PartialOrd,
    LaneCount<N>: SupportedLaneCount,
{
    #[inline]
    fn partial_cmp(&self, other: &Self) -> Option<core::cmp::Ordering> {
        self.0.partial_cmp(&other.0)
    }
}

impl<T, const N: usize> Eq for Mask<T, N>
where
    T: MaskElement + Eq,
    LaneCount<N>: SupportedLaneCount,
{
}

impl<T, const N: usize> Ord for Mask<T, N>
where
    T: MaskElement + Ord,
    LaneCount<N>: SupportedLaneCount,
{
    #[inline]
    fn cmp(&self, other: &Self) -> core::cmp::Ordering {
        self.0.cmp(&other.0)
    }
}

// Used for bitmask bit order workaround
pub(crate) trait ReverseBits {
    // Reverse the least significant `n` bits of `self`.
    // (Remaining bits must be 0.)
    fn reverse_bits(self, n: usize) -> Self;
}

macro_rules! impl_reverse_bits {
    { $($int:ty),* } => {
        $(
            impl ReverseBits for $int {
                #[inline(always)]
                fn reverse_bits(self, n: usize) -> Self {
                    let rev = <$int>::reverse_bits(self);
                    let bitsize = size_of::<$int>() * 8;
                    if n < bitsize {
                        // Shift things back to the right
                        rev >> (bitsize - n)
                    } else {
                        rev
                    }
                }
            }
        )*
    }
}

impl_reverse_bits! { u8, u16, u32, u64 }
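
// Illustrative only (not part of the implementation): reversing the low 4
// bits of a `u8` whose upper bits are zero, as the bitmask conversions below
// do on big-endian targets.
//
//     // <u8>::reverse_bits(0b0000_1101) == 0b1011_0000; shifting right by
//     // 8 - 4 moves the reversed run back into the low bits:
//     assert_eq!(ReverseBits::reverse_bits(0b1101u8, 4), 0b1011);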

impl<T, const N: usize> Mask<T, N>
where
    T: MaskElement,
    LaneCount<N>: SupportedLaneCount,
{
    #[inline]
    #[must_use = "method returns a new mask and does not mutate the original value"]
    pub(crate) fn splat(value: bool) -> Self {
        Self(Simd::splat(if value { T::TRUE } else { T::FALSE }))
    }
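
    // Safety: for both `_unchecked` accessors below, the caller must ensure
    // `lane < N`.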
    #[inline]
    #[must_use = "method returns a new bool and does not mutate the original value"]
    pub(crate) unsafe fn test_unchecked(&self, lane: usize) -> bool {
        T::eq(self.0[lane], T::TRUE)
    }

    #[inline]
    pub(crate) unsafe fn set_unchecked(&mut self, lane: usize, value: bool) {
        self.0[lane] = if value { T::TRUE } else { T::FALSE }
    }

    #[inline]
    #[must_use = "method returns a new vector and does not mutate the original value"]
    pub(crate) fn to_int(self) -> Simd<T, N> {
        self.0
    }
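
    // Safety: the caller must ensure every lane of `value` is `T::TRUE` or
    // `T::FALSE`; any other bit pattern breaks the mask invariant.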
    #[inline]
    #[must_use = "method returns a new mask and does not mutate the original value"]
    pub(crate) unsafe fn from_int_unchecked(value: Simd<T, N>) -> Self {
        Self(value)
    }

    #[inline]
    #[must_use = "method returns a new mask and does not mutate the original value"]
    pub(crate) fn convert<U>(self) -> Mask<U, N>
    where
        U: MaskElement,
    {
        // Safety: masks are simply integer vectors of 0 and -1, and we can cast the element type.
        unsafe { Mask(core::intrinsics::simd::simd_cast(self.0)) }
    }
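
    // In the bitmasks produced and consumed below, bit `i` corresponds to
    // lane `i`, with lane 0 in the least significant bit; the `ReverseBits`
    // workaround keeps that order on big-endian targets.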
    #[inline]
    unsafe fn to_bitmask_impl<U: ReverseBits, const M: usize>(self) -> U
    where
        LaneCount<M>: SupportedLaneCount,
    {
        let resized = self.to_int().resize::<M>(T::FALSE);

        // Safety: `resized` is an integer vector with length M, matching the
        // bit width of the bitmask type `U`
        let bitmask: U = unsafe { core::intrinsics::simd::simd_bitmask(resized) };

        // LLVM assumes bit order should match endianness
        if cfg!(target_endian = "big") {
            bitmask.reverse_bits(M)
        } else {
            bitmask
        }
    }

    #[inline]
    unsafe fn from_bitmask_impl<U: ReverseBits, const M: usize>(bitmask: U) -> Self
    where
        LaneCount<M>: SupportedLaneCount,
    {
        // LLVM assumes bit order should match endianness
        let bitmask = if cfg!(target_endian = "big") {
            bitmask.reverse_bits(M)
        } else {
            bitmask
        };

        // Safety: `bitmask` provides one bit for each of the M lanes being selected
        let mask: Simd<T, M> = unsafe {
            core::intrinsics::simd::simd_select_bitmask(
                bitmask,
                Simd::<T, M>::splat(T::TRUE),
                Simd::<T, M>::splat(T::FALSE),
            )
        };

        // Safety: `mask` only contains `T::TRUE` or `T::FALSE`
        unsafe { Self::from_int_unchecked(mask.resize::<N>(T::FALSE)) }
    }
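
    // `simd_bitmask` requires the bitmask width to match the lane count, so
    // `N` is rounded up to the next supported width and the result is
    // zero-extended to `u64`. Illustrative only: a 4-lane mask
    // `[true, false, true, true]` produces `0b1101`.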
    #[inline]
    pub(crate) fn to_bitmask_integer(self) -> u64 {
        // TODO modify simd_bitmask to zero-extend output, making this unnecessary
        if N <= 8 {
            // Safety: bitmask matches length
            unsafe { self.to_bitmask_impl::<u8, 8>() as u64 }
        } else if N <= 16 {
            // Safety: bitmask matches length
            unsafe { self.to_bitmask_impl::<u16, 16>() as u64 }
        } else if N <= 32 {
            // Safety: bitmask matches length
            unsafe { self.to_bitmask_impl::<u32, 32>() as u64 }
        } else {
            // Safety: bitmask matches length
            unsafe { self.to_bitmask_impl::<u64, 64>() }
        }
    }
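
    // Inverse of `to_bitmask_integer`: truncate the `u64` to the matching
    // width, select full lanes from the bits, then resize back to `N` lanes.
    // Illustrative only: `0b0101` yields the 4-lane mask
    // `[true, false, true, false]`.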
    #[inline]
    pub(crate) fn from_bitmask_integer(bitmask: u64) -> Self {
        // TODO modify simd_select_bitmask to truncate input, making this unnecessary
        if N <= 8 {
            // Safety: bitmask matches length
            unsafe { Self::from_bitmask_impl::<u8, 8>(bitmask as u8) }
        } else if N <= 16 {
            // Safety: bitmask matches length
            unsafe { Self::from_bitmask_impl::<u16, 16>(bitmask as u16) }
        } else if N <= 32 {
            // Safety: bitmask matches length
            unsafe { Self::from_bitmask_impl::<u32, 32>(bitmask as u32) }
        } else {
            // Safety: bitmask matches length
            unsafe { Self::from_bitmask_impl::<u64, 64>(bitmask) }
        }
    }
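
    // Reductions: `any` is true if at least one lane is set, `all` only if
    // every lane is set.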
    #[inline]
    #[must_use = "method returns a new bool and does not mutate the original value"]
    pub(crate) fn any(self) -> bool {
        // Safety: use `self` as an integer vector
        unsafe { core::intrinsics::simd::simd_reduce_any(self.to_int()) }
    }

    #[inline]
    #[must_use = "method returns a new bool and does not mutate the original value"]
    pub(crate) fn all(self) -> bool {
        // Safety: use `self` as an integer vector
        unsafe { core::intrinsics::simd::simd_reduce_all(self.to_int()) }
    }
}

impl<T, const N: usize> From<Mask<T, N>> for Simd<T, N>
where
    T: MaskElement,
    LaneCount<N>: SupportedLaneCount,
{
    #[inline]
    fn from(value: Mask<T, N>) -> Self {
        value.0
    }
}
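
// Because every lane is all ones or all zeros, the lanewise integer AND, OR,
// and XOR below coincide with the boolean operations on the mask.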
impl<T, const N: usize> core::ops::BitAnd for Mask<T, N>
where
    T: MaskElement,
    LaneCount<N>: SupportedLaneCount,
{
    type Output = Self;
    #[inline]
    fn bitand(self, rhs: Self) -> Self {
        // Safety: `self` is an integer vector
        unsafe { Self(core::intrinsics::simd::simd_and(self.0, rhs.0)) }
    }
}

impl<T, const N: usize> core::ops::BitOr for Mask<T, N>
where
    T: MaskElement,
    LaneCount<N>: SupportedLaneCount,
{
    type Output = Self;
    #[inline]
    fn bitor(self, rhs: Self) -> Self {
        // Safety: `self` is an integer vector
        unsafe { Self(core::intrinsics::simd::simd_or(self.0, rhs.0)) }
    }
}

impl<T, const N: usize> core::ops::BitXor for Mask<T, N>
where
    T: MaskElement,
    LaneCount<N>: SupportedLaneCount,
{
    type Output = Self;
    #[inline]
    fn bitxor(self, rhs: Self) -> Self {
        // Safety: `self` is an integer vector
        unsafe { Self(core::intrinsics::simd::simd_xor(self.0, rhs.0)) }
    }
}
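
// XOR with an all-true mask flips every lane: `T::TRUE` is all ones, so
// `splat(true) ^ self` computes lanewise NOT without a dedicated intrinsic.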
impl<T, const N: usize> core::ops::Not for Mask<T, N>
where
    T: MaskElement,
    LaneCount<N>: SupportedLaneCount,
{
    type Output = Self;
    #[inline]
    fn not(self) -> Self::Output {
        Self::splat(true) ^ self
    }
}