1use crate::simd::{
2 cmp::SimdPartialEq,
3 ptr::{SimdConstPtr, SimdMutPtr},
4 LaneCount, Mask, Simd, SupportedLaneCount,
5};
6
7/// Parallel `PartialOrd`.
8pub trait SimdPartialOrd: SimdPartialEq {
9 /// Test if each element is less than the corresponding element in `other`.
10 #[must_use = "method returns a new mask and does not mutate the original value"]
11 fn simd_lt(self, other: Self) -> Self::Mask;
12
13 /// Test if each element is less than or equal to the corresponding element in `other`.
14 #[must_use = "method returns a new mask and does not mutate the original value"]
15 fn simd_le(self, other: Self) -> Self::Mask;
16
17 /// Test if each element is greater than the corresponding element in `other`.
18 #[must_use = "method returns a new mask and does not mutate the original value"]
19 fn simd_gt(self, other: Self) -> Self::Mask;
20
21 /// Test if each element is greater than or equal to the corresponding element in `other`.
22 #[must_use = "method returns a new mask and does not mutate the original value"]
23 fn simd_ge(self, other: Self) -> Self::Mask;
24}
25
26/// Parallel `Ord`.
27pub trait SimdOrd: SimdPartialOrd {
28 /// Returns the element-wise maximum with `other`.
29 #[must_use = "method returns a new vector and does not mutate the original value"]
30 fn simd_max(self, other: Self) -> Self;
31
32 /// Returns the element-wise minimum with `other`.
33 #[must_use = "method returns a new vector and does not mutate the original value"]
34 fn simd_min(self, other: Self) -> Self;
35
36 /// Restrict each element to a certain interval.
37 ///
38 /// For each element, returns `max` if `self` is greater than `max`, and `min` if `self` is
39 /// less than `min`. Otherwise returns `self`.
40 ///
41 /// # Panics
42 ///
43 /// Panics if `min > max` on any element.
44 #[must_use = "method returns a new vector and does not mutate the original value"]
45 fn simd_clamp(self, min: Self, max: Self) -> Self;
46}
47
48macro_rules! impl_integer {
49 { $($integer:ty),* } => {
50 $(
51 impl<const N: usize> SimdPartialOrd for Simd<$integer, N>
52 where
53 LaneCount<N>: SupportedLaneCount,
54 {
55 #[inline]
56 fn simd_lt(self, other: Self) -> Self::Mask {
57 // Safety: `self` is a vector, and the result of the comparison
58 // is always a valid mask.
59 unsafe { Mask::from_int_unchecked(core::intrinsics::simd::simd_lt(self, other)) }
60 }
61
62 #[inline]
63 fn simd_le(self, other: Self) -> Self::Mask {
64 // Safety: `self` is a vector, and the result of the comparison
65 // is always a valid mask.
66 unsafe { Mask::from_int_unchecked(core::intrinsics::simd::simd_le(self, other)) }
67 }
68
69 #[inline]
70 fn simd_gt(self, other: Self) -> Self::Mask {
71 // Safety: `self` is a vector, and the result of the comparison
72 // is always a valid mask.
73 unsafe { Mask::from_int_unchecked(core::intrinsics::simd::simd_gt(self, other)) }
74 }
75
76 #[inline]
77 fn simd_ge(self, other: Self) -> Self::Mask {
78 // Safety: `self` is a vector, and the result of the comparison
79 // is always a valid mask.
80 unsafe { Mask::from_int_unchecked(core::intrinsics::simd::simd_ge(self, other)) }
81 }
82 }
83
84 impl<const N: usize> SimdOrd for Simd<$integer, N>
85 where
86 LaneCount<N>: SupportedLaneCount,
87 {
88 #[inline]
89 fn simd_max(self, other: Self) -> Self {
90 self.simd_lt(other).select(other, self)
91 }
92
93 #[inline]
94 fn simd_min(self, other: Self) -> Self {
95 self.simd_gt(other).select(other, self)
96 }
97
98 #[inline]
99 #[track_caller]
100 fn simd_clamp(self, min: Self, max: Self) -> Self {
101 assert!(
102 min.simd_le(max).all(),
103 "each element in `min` must be less than or equal to the corresponding element in `max`",
104 );
105 self.simd_max(min).simd_min(max)
106 }
107 }
108 )*
109 }
110}
111
112impl_integer! { u8, u16, u32, u64, usize, i8, i16, i32, i64, isize }
113
114macro_rules! impl_float {
115 { $($float:ty),* } => {
116 $(
117 impl<const N: usize> SimdPartialOrd for Simd<$float, N>
118 where
119 LaneCount<N>: SupportedLaneCount,
120 {
121 #[inline]
122 fn simd_lt(self, other: Self) -> Self::Mask {
123 // Safety: `self` is a vector, and the result of the comparison
124 // is always a valid mask.
125 unsafe { Mask::from_int_unchecked(core::intrinsics::simd::simd_lt(self, other)) }
126 }
127
128 #[inline]
129 fn simd_le(self, other: Self) -> Self::Mask {
130 // Safety: `self` is a vector, and the result of the comparison
131 // is always a valid mask.
132 unsafe { Mask::from_int_unchecked(core::intrinsics::simd::simd_le(self, other)) }
133 }
134
135 #[inline]
136 fn simd_gt(self, other: Self) -> Self::Mask {
137 // Safety: `self` is a vector, and the result of the comparison
138 // is always a valid mask.
139 unsafe { Mask::from_int_unchecked(core::intrinsics::simd::simd_gt(self, other)) }
140 }
141
142 #[inline]
143 fn simd_ge(self, other: Self) -> Self::Mask {
144 // Safety: `self` is a vector, and the result of the comparison
145 // is always a valid mask.
146 unsafe { Mask::from_int_unchecked(core::intrinsics::simd::simd_ge(self, other)) }
147 }
148 }
149 )*
150 }
151}
152
153impl_float! { f32, f64 }
154
155macro_rules! impl_mask {
156 { $($integer:ty),* } => {
157 $(
158 impl<const N: usize> SimdPartialOrd for Mask<$integer, N>
159 where
160 LaneCount<N>: SupportedLaneCount,
161 {
162 #[inline]
163 fn simd_lt(self, other: Self) -> Self::Mask {
164 // Safety: `self` is a vector, and the result of the comparison
165 // is always a valid mask.
166 unsafe { Self::from_int_unchecked(core::intrinsics::simd::simd_lt(self.to_int(), other.to_int())) }
167 }
168
169 #[inline]
170 fn simd_le(self, other: Self) -> Self::Mask {
171 // Safety: `self` is a vector, and the result of the comparison
172 // is always a valid mask.
173 unsafe { Self::from_int_unchecked(core::intrinsics::simd::simd_le(self.to_int(), other.to_int())) }
174 }
175
176 #[inline]
177 fn simd_gt(self, other: Self) -> Self::Mask {
178 // Safety: `self` is a vector, and the result of the comparison
179 // is always a valid mask.
180 unsafe { Self::from_int_unchecked(core::intrinsics::simd::simd_gt(self.to_int(), other.to_int())) }
181 }
182
183 #[inline]
184 fn simd_ge(self, other: Self) -> Self::Mask {
185 // Safety: `self` is a vector, and the result of the comparison
186 // is always a valid mask.
187 unsafe { Self::from_int_unchecked(core::intrinsics::simd::simd_ge(self.to_int(), other.to_int())) }
188 }
189 }
190
191 impl<const N: usize> SimdOrd for Mask<$integer, N>
192 where
193 LaneCount<N>: SupportedLaneCount,
194 {
195 #[inline]
196 fn simd_max(self, other: Self) -> Self {
197 self.simd_gt(other).select_mask(other, self)
198 }
199
200 #[inline]
201 fn simd_min(self, other: Self) -> Self {
202 self.simd_lt(other).select_mask(other, self)
203 }
204
205 #[inline]
206 #[track_caller]
207 fn simd_clamp(self, min: Self, max: Self) -> Self {
208 assert!(
209 min.simd_le(max).all(),
210 "each element in `min` must be less than or equal to the corresponding element in `max`",
211 );
212 self.simd_max(min).simd_min(max)
213 }
214 }
215 )*
216 }
217}
218
219impl_mask! { i8, i16, i32, i64, isize }
220
221impl<T, const N: usize> SimdPartialOrd for Simd<*const T, N>
222where
223 LaneCount<N>: SupportedLaneCount,
224{
225 #[inline]
226 fn simd_lt(self, other: Self) -> Self::Mask {
227 self.addr().simd_lt(other.addr())
228 }
229
230 #[inline]
231 fn simd_le(self, other: Self) -> Self::Mask {
232 self.addr().simd_le(other.addr())
233 }
234
235 #[inline]
236 fn simd_gt(self, other: Self) -> Self::Mask {
237 self.addr().simd_gt(other.addr())
238 }
239
240 #[inline]
241 fn simd_ge(self, other: Self) -> Self::Mask {
242 self.addr().simd_ge(other.addr())
243 }
244}
245
246impl<T, const N: usize> SimdOrd for Simd<*const T, N>
247where
248 LaneCount<N>: SupportedLaneCount,
249{
250 #[inline]
251 fn simd_max(self, other: Self) -> Self {
252 self.simd_lt(other).select(true_values:other, self)
253 }
254
255 #[inline]
256 fn simd_min(self, other: Self) -> Self {
257 self.simd_gt(other).select(true_values:other, self)
258 }
259
260 #[inline]
261 #[track_caller]
262 fn simd_clamp(self, min: Self, max: Self) -> Self {
263 assert!(
264 min.simd_le(max).all(),
265 "each element in `min` must be less than or equal to the corresponding element in `max`",
266 );
267 self.simd_max(min).simd_min(max)
268 }
269}
270
271impl<T, const N: usize> SimdPartialOrd for Simd<*mut T, N>
272where
273 LaneCount<N>: SupportedLaneCount,
274{
275 #[inline]
276 fn simd_lt(self, other: Self) -> Self::Mask {
277 self.addr().simd_lt(other.addr())
278 }
279
280 #[inline]
281 fn simd_le(self, other: Self) -> Self::Mask {
282 self.addr().simd_le(other.addr())
283 }
284
285 #[inline]
286 fn simd_gt(self, other: Self) -> Self::Mask {
287 self.addr().simd_gt(other.addr())
288 }
289
290 #[inline]
291 fn simd_ge(self, other: Self) -> Self::Mask {
292 self.addr().simd_ge(other.addr())
293 }
294}
295
296impl<T, const N: usize> SimdOrd for Simd<*mut T, N>
297where
298 LaneCount<N>: SupportedLaneCount,
299{
300 #[inline]
301 fn simd_max(self, other: Self) -> Self {
302 self.simd_lt(other).select(true_values:other, self)
303 }
304
305 #[inline]
306 fn simd_min(self, other: Self) -> Self {
307 self.simd_gt(other).select(true_values:other, self)
308 }
309
310 #[inline]
311 #[track_caller]
312 fn simd_clamp(self, min: Self, max: Self) -> Self {
313 assert!(
314 min.simd_le(max).all(),
315 "each element in `min` must be less than or equal to the corresponding element in `max`",
316 );
317 self.simd_max(min).simd_min(max)
318 }
319}
320