1use crate::simd::{
2 cmp::SimdPartialEq,
3 intrinsics,
4 ptr::{SimdConstPtr, SimdMutPtr},
5 LaneCount, Mask, Simd, SupportedLaneCount,
6};
7
8/// Parallel `PartialOrd`.
9pub trait SimdPartialOrd: SimdPartialEq {
10 /// Test if each element is less than the corresponding element in `other`.
11 #[must_use = "method returns a new mask and does not mutate the original value"]
12 fn simd_lt(self, other: Self) -> Self::Mask;
13
14 /// Test if each element is less than or equal to the corresponding element in `other`.
15 #[must_use = "method returns a new mask and does not mutate the original value"]
16 fn simd_le(self, other: Self) -> Self::Mask;
17
18 /// Test if each element is greater than the corresponding element in `other`.
19 #[must_use = "method returns a new mask and does not mutate the original value"]
20 fn simd_gt(self, other: Self) -> Self::Mask;
21
22 /// Test if each element is greater than or equal to the corresponding element in `other`.
23 #[must_use = "method returns a new mask and does not mutate the original value"]
24 fn simd_ge(self, other: Self) -> Self::Mask;
25}
26
27/// Parallel `Ord`.
28pub trait SimdOrd: SimdPartialOrd {
29 /// Returns the element-wise maximum with `other`.
30 #[must_use = "method returns a new vector and does not mutate the original value"]
31 fn simd_max(self, other: Self) -> Self;
32
33 /// Returns the element-wise minimum with `other`.
34 #[must_use = "method returns a new vector and does not mutate the original value"]
35 fn simd_min(self, other: Self) -> Self;
36
37 /// Restrict each element to a certain interval.
38 ///
39 /// For each element, returns `max` if `self` is greater than `max`, and `min` if `self` is
40 /// less than `min`. Otherwise returns `self`.
41 ///
42 /// # Panics
43 ///
44 /// Panics if `min > max` on any element.
45 #[must_use = "method returns a new vector and does not mutate the original value"]
46 fn simd_clamp(self, min: Self, max: Self) -> Self;
47}
48
49macro_rules! impl_integer {
50 { $($integer:ty),* } => {
51 $(
52 impl<const N: usize> SimdPartialOrd for Simd<$integer, N>
53 where
54 LaneCount<N>: SupportedLaneCount,
55 {
56 #[inline]
57 fn simd_lt(self, other: Self) -> Self::Mask {
58 // Safety: `self` is a vector, and the result of the comparison
59 // is always a valid mask.
60 unsafe { Mask::from_int_unchecked(intrinsics::simd_lt(self, other)) }
61 }
62
63 #[inline]
64 fn simd_le(self, other: Self) -> Self::Mask {
65 // Safety: `self` is a vector, and the result of the comparison
66 // is always a valid mask.
67 unsafe { Mask::from_int_unchecked(intrinsics::simd_le(self, other)) }
68 }
69
70 #[inline]
71 fn simd_gt(self, other: Self) -> Self::Mask {
72 // Safety: `self` is a vector, and the result of the comparison
73 // is always a valid mask.
74 unsafe { Mask::from_int_unchecked(intrinsics::simd_gt(self, other)) }
75 }
76
77 #[inline]
78 fn simd_ge(self, other: Self) -> Self::Mask {
79 // Safety: `self` is a vector, and the result of the comparison
80 // is always a valid mask.
81 unsafe { Mask::from_int_unchecked(intrinsics::simd_ge(self, other)) }
82 }
83 }
84
85 impl<const N: usize> SimdOrd for Simd<$integer, N>
86 where
87 LaneCount<N>: SupportedLaneCount,
88 {
89 #[inline]
90 fn simd_max(self, other: Self) -> Self {
91 self.simd_lt(other).select(other, self)
92 }
93
94 #[inline]
95 fn simd_min(self, other: Self) -> Self {
96 self.simd_gt(other).select(other, self)
97 }
98
99 #[inline]
100 #[track_caller]
101 fn simd_clamp(self, min: Self, max: Self) -> Self {
102 assert!(
103 min.simd_le(max).all(),
104 "each element in `min` must be less than or equal to the corresponding element in `max`",
105 );
106 self.simd_max(min).simd_min(max)
107 }
108 }
109 )*
110 }
111}
112
113impl_integer! { u8, u16, u32, u64, usize, i8, i16, i32, i64, isize }
114
115macro_rules! impl_float {
116 { $($float:ty),* } => {
117 $(
118 impl<const N: usize> SimdPartialOrd for Simd<$float, N>
119 where
120 LaneCount<N>: SupportedLaneCount,
121 {
122 #[inline]
123 fn simd_lt(self, other: Self) -> Self::Mask {
124 // Safety: `self` is a vector, and the result of the comparison
125 // is always a valid mask.
126 unsafe { Mask::from_int_unchecked(intrinsics::simd_lt(self, other)) }
127 }
128
129 #[inline]
130 fn simd_le(self, other: Self) -> Self::Mask {
131 // Safety: `self` is a vector, and the result of the comparison
132 // is always a valid mask.
133 unsafe { Mask::from_int_unchecked(intrinsics::simd_le(self, other)) }
134 }
135
136 #[inline]
137 fn simd_gt(self, other: Self) -> Self::Mask {
138 // Safety: `self` is a vector, and the result of the comparison
139 // is always a valid mask.
140 unsafe { Mask::from_int_unchecked(intrinsics::simd_gt(self, other)) }
141 }
142
143 #[inline]
144 fn simd_ge(self, other: Self) -> Self::Mask {
145 // Safety: `self` is a vector, and the result of the comparison
146 // is always a valid mask.
147 unsafe { Mask::from_int_unchecked(intrinsics::simd_ge(self, other)) }
148 }
149 }
150 )*
151 }
152}
153
154impl_float! { f32, f64 }
155
156macro_rules! impl_mask {
157 { $($integer:ty),* } => {
158 $(
159 impl<const N: usize> SimdPartialOrd for Mask<$integer, N>
160 where
161 LaneCount<N>: SupportedLaneCount,
162 {
163 #[inline]
164 fn simd_lt(self, other: Self) -> Self::Mask {
165 // Safety: `self` is a vector, and the result of the comparison
166 // is always a valid mask.
167 unsafe { Self::from_int_unchecked(intrinsics::simd_lt(self.to_int(), other.to_int())) }
168 }
169
170 #[inline]
171 fn simd_le(self, other: Self) -> Self::Mask {
172 // Safety: `self` is a vector, and the result of the comparison
173 // is always a valid mask.
174 unsafe { Self::from_int_unchecked(intrinsics::simd_le(self.to_int(), other.to_int())) }
175 }
176
177 #[inline]
178 fn simd_gt(self, other: Self) -> Self::Mask {
179 // Safety: `self` is a vector, and the result of the comparison
180 // is always a valid mask.
181 unsafe { Self::from_int_unchecked(intrinsics::simd_gt(self.to_int(), other.to_int())) }
182 }
183
184 #[inline]
185 fn simd_ge(self, other: Self) -> Self::Mask {
186 // Safety: `self` is a vector, and the result of the comparison
187 // is always a valid mask.
188 unsafe { Self::from_int_unchecked(intrinsics::simd_ge(self.to_int(), other.to_int())) }
189 }
190 }
191
192 impl<const N: usize> SimdOrd for Mask<$integer, N>
193 where
194 LaneCount<N>: SupportedLaneCount,
195 {
196 #[inline]
197 fn simd_max(self, other: Self) -> Self {
198 self.simd_gt(other).select_mask(other, self)
199 }
200
201 #[inline]
202 fn simd_min(self, other: Self) -> Self {
203 self.simd_lt(other).select_mask(other, self)
204 }
205
206 #[inline]
207 #[track_caller]
208 fn simd_clamp(self, min: Self, max: Self) -> Self {
209 assert!(
210 min.simd_le(max).all(),
211 "each element in `min` must be less than or equal to the corresponding element in `max`",
212 );
213 self.simd_max(min).simd_min(max)
214 }
215 }
216 )*
217 }
218}
219
220impl_mask! { i8, i16, i32, i64, isize }
221
222impl<T, const N: usize> SimdPartialOrd for Simd<*const T, N>
223where
224 LaneCount<N>: SupportedLaneCount,
225{
226 #[inline]
227 fn simd_lt(self, other: Self) -> Self::Mask {
228 self.addr().simd_lt(other.addr())
229 }
230
231 #[inline]
232 fn simd_le(self, other: Self) -> Self::Mask {
233 self.addr().simd_le(other.addr())
234 }
235
236 #[inline]
237 fn simd_gt(self, other: Self) -> Self::Mask {
238 self.addr().simd_gt(other.addr())
239 }
240
241 #[inline]
242 fn simd_ge(self, other: Self) -> Self::Mask {
243 self.addr().simd_ge(other.addr())
244 }
245}
246
247impl<T, const N: usize> SimdOrd for Simd<*const T, N>
248where
249 LaneCount<N>: SupportedLaneCount,
250{
251 #[inline]
252 fn simd_max(self, other: Self) -> Self {
253 self.simd_lt(other).select(true_values:other, self)
254 }
255
256 #[inline]
257 fn simd_min(self, other: Self) -> Self {
258 self.simd_gt(other).select(true_values:other, self)
259 }
260
261 #[inline]
262 #[track_caller]
263 fn simd_clamp(self, min: Self, max: Self) -> Self {
264 assert!(
265 min.simd_le(max).all(),
266 "each element in `min` must be less than or equal to the corresponding element in `max`",
267 );
268 self.simd_max(min).simd_min(max)
269 }
270}
271
272impl<T, const N: usize> SimdPartialOrd for Simd<*mut T, N>
273where
274 LaneCount<N>: SupportedLaneCount,
275{
276 #[inline]
277 fn simd_lt(self, other: Self) -> Self::Mask {
278 self.addr().simd_lt(other.addr())
279 }
280
281 #[inline]
282 fn simd_le(self, other: Self) -> Self::Mask {
283 self.addr().simd_le(other.addr())
284 }
285
286 #[inline]
287 fn simd_gt(self, other: Self) -> Self::Mask {
288 self.addr().simd_gt(other.addr())
289 }
290
291 #[inline]
292 fn simd_ge(self, other: Self) -> Self::Mask {
293 self.addr().simd_ge(other.addr())
294 }
295}
296
297impl<T, const N: usize> SimdOrd for Simd<*mut T, N>
298where
299 LaneCount<N>: SupportedLaneCount,
300{
301 #[inline]
302 fn simd_max(self, other: Self) -> Self {
303 self.simd_lt(other).select(true_values:other, self)
304 }
305
306 #[inline]
307 fn simd_min(self, other: Self) -> Self {
308 self.simd_gt(other).select(true_values:other, self)
309 }
310
311 #[inline]
312 #[track_caller]
313 fn simd_clamp(self, min: Self, max: Self) -> Self {
314 assert!(
315 min.simd_le(max).all(),
316 "each element in `min` must be less than or equal to the corresponding element in `max`",
317 );
318 self.simd_max(min).simd_min(max)
319 }
320}
321