1 | use crate::simd::{ |
2 | cmp::SimdPartialEq, |
3 | intrinsics, |
4 | ptr::{SimdConstPtr, SimdMutPtr}, |
5 | LaneCount, Mask, Simd, SupportedLaneCount, |
6 | }; |
7 | |
8 | /// Parallel `PartialOrd`. |
9 | pub trait SimdPartialOrd: SimdPartialEq { |
10 | /// Test if each element is less than the corresponding element in `other`. |
11 | #[must_use = "method returns a new mask and does not mutate the original value" ] |
12 | fn simd_lt(self, other: Self) -> Self::Mask; |
13 | |
14 | /// Test if each element is less than or equal to the corresponding element in `other`. |
15 | #[must_use = "method returns a new mask and does not mutate the original value" ] |
16 | fn simd_le(self, other: Self) -> Self::Mask; |
17 | |
18 | /// Test if each element is greater than the corresponding element in `other`. |
19 | #[must_use = "method returns a new mask and does not mutate the original value" ] |
20 | fn simd_gt(self, other: Self) -> Self::Mask; |
21 | |
22 | /// Test if each element is greater than or equal to the corresponding element in `other`. |
23 | #[must_use = "method returns a new mask and does not mutate the original value" ] |
24 | fn simd_ge(self, other: Self) -> Self::Mask; |
25 | } |
26 | |
27 | /// Parallel `Ord`. |
28 | pub trait SimdOrd: SimdPartialOrd { |
29 | /// Returns the element-wise maximum with `other`. |
30 | #[must_use = "method returns a new vector and does not mutate the original value" ] |
31 | fn simd_max(self, other: Self) -> Self; |
32 | |
33 | /// Returns the element-wise minimum with `other`. |
34 | #[must_use = "method returns a new vector and does not mutate the original value" ] |
35 | fn simd_min(self, other: Self) -> Self; |
36 | |
37 | /// Restrict each element to a certain interval. |
38 | /// |
39 | /// For each element, returns `max` if `self` is greater than `max`, and `min` if `self` is |
40 | /// less than `min`. Otherwise returns `self`. |
41 | /// |
42 | /// # Panics |
43 | /// |
44 | /// Panics if `min > max` on any element. |
45 | #[must_use = "method returns a new vector and does not mutate the original value" ] |
46 | fn simd_clamp(self, min: Self, max: Self) -> Self; |
47 | } |
48 | |
49 | macro_rules! impl_integer { |
50 | { $($integer:ty),* } => { |
51 | $( |
52 | impl<const N: usize> SimdPartialOrd for Simd<$integer, N> |
53 | where |
54 | LaneCount<N>: SupportedLaneCount, |
55 | { |
56 | #[inline] |
57 | fn simd_lt(self, other: Self) -> Self::Mask { |
58 | // Safety: `self` is a vector, and the result of the comparison |
59 | // is always a valid mask. |
60 | unsafe { Mask::from_int_unchecked(intrinsics::simd_lt(self, other)) } |
61 | } |
62 | |
63 | #[inline] |
64 | fn simd_le(self, other: Self) -> Self::Mask { |
65 | // Safety: `self` is a vector, and the result of the comparison |
66 | // is always a valid mask. |
67 | unsafe { Mask::from_int_unchecked(intrinsics::simd_le(self, other)) } |
68 | } |
69 | |
70 | #[inline] |
71 | fn simd_gt(self, other: Self) -> Self::Mask { |
72 | // Safety: `self` is a vector, and the result of the comparison |
73 | // is always a valid mask. |
74 | unsafe { Mask::from_int_unchecked(intrinsics::simd_gt(self, other)) } |
75 | } |
76 | |
77 | #[inline] |
78 | fn simd_ge(self, other: Self) -> Self::Mask { |
79 | // Safety: `self` is a vector, and the result of the comparison |
80 | // is always a valid mask. |
81 | unsafe { Mask::from_int_unchecked(intrinsics::simd_ge(self, other)) } |
82 | } |
83 | } |
84 | |
85 | impl<const N: usize> SimdOrd for Simd<$integer, N> |
86 | where |
87 | LaneCount<N>: SupportedLaneCount, |
88 | { |
89 | #[inline] |
90 | fn simd_max(self, other: Self) -> Self { |
91 | self.simd_lt(other).select(other, self) |
92 | } |
93 | |
94 | #[inline] |
95 | fn simd_min(self, other: Self) -> Self { |
96 | self.simd_gt(other).select(other, self) |
97 | } |
98 | |
99 | #[inline] |
100 | #[track_caller] |
101 | fn simd_clamp(self, min: Self, max: Self) -> Self { |
102 | assert!( |
103 | min.simd_le(max).all(), |
104 | "each element in `min` must be less than or equal to the corresponding element in `max`" , |
105 | ); |
106 | self.simd_max(min).simd_min(max) |
107 | } |
108 | } |
109 | )* |
110 | } |
111 | } |
112 | |
113 | impl_integer! { u8, u16, u32, u64, usize, i8, i16, i32, i64, isize } |
114 | |
115 | macro_rules! impl_float { |
116 | { $($float:ty),* } => { |
117 | $( |
118 | impl<const N: usize> SimdPartialOrd for Simd<$float, N> |
119 | where |
120 | LaneCount<N>: SupportedLaneCount, |
121 | { |
122 | #[inline] |
123 | fn simd_lt(self, other: Self) -> Self::Mask { |
124 | // Safety: `self` is a vector, and the result of the comparison |
125 | // is always a valid mask. |
126 | unsafe { Mask::from_int_unchecked(intrinsics::simd_lt(self, other)) } |
127 | } |
128 | |
129 | #[inline] |
130 | fn simd_le(self, other: Self) -> Self::Mask { |
131 | // Safety: `self` is a vector, and the result of the comparison |
132 | // is always a valid mask. |
133 | unsafe { Mask::from_int_unchecked(intrinsics::simd_le(self, other)) } |
134 | } |
135 | |
136 | #[inline] |
137 | fn simd_gt(self, other: Self) -> Self::Mask { |
138 | // Safety: `self` is a vector, and the result of the comparison |
139 | // is always a valid mask. |
140 | unsafe { Mask::from_int_unchecked(intrinsics::simd_gt(self, other)) } |
141 | } |
142 | |
143 | #[inline] |
144 | fn simd_ge(self, other: Self) -> Self::Mask { |
145 | // Safety: `self` is a vector, and the result of the comparison |
146 | // is always a valid mask. |
147 | unsafe { Mask::from_int_unchecked(intrinsics::simd_ge(self, other)) } |
148 | } |
149 | } |
150 | )* |
151 | } |
152 | } |
153 | |
154 | impl_float! { f32, f64 } |
155 | |
156 | macro_rules! impl_mask { |
157 | { $($integer:ty),* } => { |
158 | $( |
159 | impl<const N: usize> SimdPartialOrd for Mask<$integer, N> |
160 | where |
161 | LaneCount<N>: SupportedLaneCount, |
162 | { |
163 | #[inline] |
164 | fn simd_lt(self, other: Self) -> Self::Mask { |
165 | // Safety: `self` is a vector, and the result of the comparison |
166 | // is always a valid mask. |
167 | unsafe { Self::from_int_unchecked(intrinsics::simd_lt(self.to_int(), other.to_int())) } |
168 | } |
169 | |
170 | #[inline] |
171 | fn simd_le(self, other: Self) -> Self::Mask { |
172 | // Safety: `self` is a vector, and the result of the comparison |
173 | // is always a valid mask. |
174 | unsafe { Self::from_int_unchecked(intrinsics::simd_le(self.to_int(), other.to_int())) } |
175 | } |
176 | |
177 | #[inline] |
178 | fn simd_gt(self, other: Self) -> Self::Mask { |
179 | // Safety: `self` is a vector, and the result of the comparison |
180 | // is always a valid mask. |
181 | unsafe { Self::from_int_unchecked(intrinsics::simd_gt(self.to_int(), other.to_int())) } |
182 | } |
183 | |
184 | #[inline] |
185 | fn simd_ge(self, other: Self) -> Self::Mask { |
186 | // Safety: `self` is a vector, and the result of the comparison |
187 | // is always a valid mask. |
188 | unsafe { Self::from_int_unchecked(intrinsics::simd_ge(self.to_int(), other.to_int())) } |
189 | } |
190 | } |
191 | |
192 | impl<const N: usize> SimdOrd for Mask<$integer, N> |
193 | where |
194 | LaneCount<N>: SupportedLaneCount, |
195 | { |
196 | #[inline] |
197 | fn simd_max(self, other: Self) -> Self { |
198 | self.simd_gt(other).select_mask(other, self) |
199 | } |
200 | |
201 | #[inline] |
202 | fn simd_min(self, other: Self) -> Self { |
203 | self.simd_lt(other).select_mask(other, self) |
204 | } |
205 | |
206 | #[inline] |
207 | #[track_caller] |
208 | fn simd_clamp(self, min: Self, max: Self) -> Self { |
209 | assert!( |
210 | min.simd_le(max).all(), |
211 | "each element in `min` must be less than or equal to the corresponding element in `max`" , |
212 | ); |
213 | self.simd_max(min).simd_min(max) |
214 | } |
215 | } |
216 | )* |
217 | } |
218 | } |
219 | |
220 | impl_mask! { i8, i16, i32, i64, isize } |
221 | |
222 | impl<T, const N: usize> SimdPartialOrd for Simd<*const T, N> |
223 | where |
224 | LaneCount<N>: SupportedLaneCount, |
225 | { |
226 | #[inline ] |
227 | fn simd_lt(self, other: Self) -> Self::Mask { |
228 | self.addr().simd_lt(other.addr()) |
229 | } |
230 | |
231 | #[inline ] |
232 | fn simd_le(self, other: Self) -> Self::Mask { |
233 | self.addr().simd_le(other.addr()) |
234 | } |
235 | |
236 | #[inline ] |
237 | fn simd_gt(self, other: Self) -> Self::Mask { |
238 | self.addr().simd_gt(other.addr()) |
239 | } |
240 | |
241 | #[inline ] |
242 | fn simd_ge(self, other: Self) -> Self::Mask { |
243 | self.addr().simd_ge(other.addr()) |
244 | } |
245 | } |
246 | |
247 | impl<T, const N: usize> SimdOrd for Simd<*const T, N> |
248 | where |
249 | LaneCount<N>: SupportedLaneCount, |
250 | { |
251 | #[inline ] |
252 | fn simd_max(self, other: Self) -> Self { |
253 | self.simd_lt(other).select(true_values:other, self) |
254 | } |
255 | |
256 | #[inline ] |
257 | fn simd_min(self, other: Self) -> Self { |
258 | self.simd_gt(other).select(true_values:other, self) |
259 | } |
260 | |
261 | #[inline ] |
262 | #[track_caller ] |
263 | fn simd_clamp(self, min: Self, max: Self) -> Self { |
264 | assert!( |
265 | min.simd_le(max).all(), |
266 | "each element in `min` must be less than or equal to the corresponding element in `max`" , |
267 | ); |
268 | self.simd_max(min).simd_min(max) |
269 | } |
270 | } |
271 | |
272 | impl<T, const N: usize> SimdPartialOrd for Simd<*mut T, N> |
273 | where |
274 | LaneCount<N>: SupportedLaneCount, |
275 | { |
276 | #[inline ] |
277 | fn simd_lt(self, other: Self) -> Self::Mask { |
278 | self.addr().simd_lt(other.addr()) |
279 | } |
280 | |
281 | #[inline ] |
282 | fn simd_le(self, other: Self) -> Self::Mask { |
283 | self.addr().simd_le(other.addr()) |
284 | } |
285 | |
286 | #[inline ] |
287 | fn simd_gt(self, other: Self) -> Self::Mask { |
288 | self.addr().simd_gt(other.addr()) |
289 | } |
290 | |
291 | #[inline ] |
292 | fn simd_ge(self, other: Self) -> Self::Mask { |
293 | self.addr().simd_ge(other.addr()) |
294 | } |
295 | } |
296 | |
297 | impl<T, const N: usize> SimdOrd for Simd<*mut T, N> |
298 | where |
299 | LaneCount<N>: SupportedLaneCount, |
300 | { |
301 | #[inline ] |
302 | fn simd_max(self, other: Self) -> Self { |
303 | self.simd_lt(other).select(true_values:other, self) |
304 | } |
305 | |
306 | #[inline ] |
307 | fn simd_min(self, other: Self) -> Self { |
308 | self.simd_gt(other).select(true_values:other, self) |
309 | } |
310 | |
311 | #[inline ] |
312 | #[track_caller ] |
313 | fn simd_clamp(self, min: Self, max: Self) -> Self { |
314 | assert!( |
315 | min.simd_le(max).all(), |
316 | "each element in `min` must be less than or equal to the corresponding element in `max`" , |
317 | ); |
318 | self.simd_max(min).simd_min(max) |
319 | } |
320 | } |
321 | |