1// Copyright 2018 Developers of the Rand project.
2//
3// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
4// https://www.apache.org/licenses/LICENSE-2.0> or the MIT license
5// <LICENSE-MIT or https://opensource.org/licenses/MIT>, at your
6// option. This file may not be copied, modified, or distributed
7// except according to those terms.
8
9//! Math helper functions
10
11#[cfg(feature = "simd_support")] use packed_simd::*;
12
13
/// Multiplication whose result is wide enough to hold the whole product.
///
/// The implementations in this module all use a `(high, low)` tuple of the
/// operand type as `Output`: the upper and the lower half of the full
/// product, in that order.
pub(crate) trait WideningMultiply<Rhs = Self> {
    /// Type used to return the full-width product.
    type Output;

    /// Multiplies `self` by `x` and returns the full-width product.
    fn wmul(self, x: Rhs) -> Self::Output;
}
19
20macro_rules! wmul_impl {
21 ($ty:ty, $wide:ty, $shift:expr) => {
22 impl WideningMultiply for $ty {
23 type Output = ($ty, $ty);
24
25 #[inline(always)]
26 fn wmul(self, x: $ty) -> Self::Output {
27 let tmp = (self as $wide) * (x as $wide);
28 ((tmp >> $shift) as $ty, tmp as $ty)
29 }
30 }
31 };
32
33 // simd bulk implementation
34 ($(($ty:ident, $wide:ident),)+, $shift:expr) => {
35 $(
36 impl WideningMultiply for $ty {
37 type Output = ($ty, $ty);
38
39 #[inline(always)]
40 fn wmul(self, x: $ty) -> Self::Output {
41 // For supported vectors, this should compile to a couple
42 // supported multiply & swizzle instructions (no actual
43 // casting).
44 // TODO: optimize
45 let y: $wide = self.cast();
46 let x: $wide = x.cast();
47 let tmp = y * x;
48 let hi: $ty = (tmp >> $shift).cast();
49 let lo: $ty = tmp.cast();
50 (hi, lo)
51 }
52 }
53 )+
54 };
55}
56wmul_impl! { u8, u16, 8 }
57wmul_impl! { u16, u32, 16 }
58wmul_impl! { u32, u64, 32 }
59wmul_impl! { u64, u128, 64 }
60
61// This code is a translation of the __mulddi3 function in LLVM's
62// compiler-rt. It is an optimised variant of the common method
63// `(a + b) * (c + d) = ac + ad + bc + bd`.
64//
65// For some reason LLVM can optimise the C version very well, but
66// keeps shuffling registers in this Rust translation.
67macro_rules! wmul_impl_large {
68 ($ty:ty, $half:expr) => {
69 impl WideningMultiply for $ty {
70 type Output = ($ty, $ty);
71
72 #[inline(always)]
73 fn wmul(self, b: $ty) -> Self::Output {
74 const LOWER_MASK: $ty = !0 >> $half;
75 let mut low = (self & LOWER_MASK).wrapping_mul(b & LOWER_MASK);
76 let mut t = low >> $half;
77 low &= LOWER_MASK;
78 t += (self >> $half).wrapping_mul(b & LOWER_MASK);
79 low += (t & LOWER_MASK) << $half;
80 let mut high = t >> $half;
81 t = low >> $half;
82 low &= LOWER_MASK;
83 t += (b >> $half).wrapping_mul(self & LOWER_MASK);
84 low += (t & LOWER_MASK) << $half;
85 high += t >> $half;
86 high += (self >> $half).wrapping_mul(b >> $half);
87
88 (high, low)
89 }
90 }
91 };
92
93 // simd bulk implementation
94 (($($ty:ty,)+) $scalar:ty, $half:expr) => {
95 $(
96 impl WideningMultiply for $ty {
97 type Output = ($ty, $ty);
98
99 #[inline(always)]
100 fn wmul(self, b: $ty) -> Self::Output {
101 // needs wrapping multiplication
102 const LOWER_MASK: $scalar = !0 >> $half;
103 let mut low = (self & LOWER_MASK) * (b & LOWER_MASK);
104 let mut t = low >> $half;
105 low &= LOWER_MASK;
106 t += (self >> $half) * (b & LOWER_MASK);
107 low += (t & LOWER_MASK) << $half;
108 let mut high = t >> $half;
109 t = low >> $half;
110 low &= LOWER_MASK;
111 t += (b >> $half) * (self & LOWER_MASK);
112 low += (t & LOWER_MASK) << $half;
113 high += t >> $half;
114 high += (self >> $half) * (b >> $half);
115
116 (high, low)
117 }
118 }
119 )+
120 };
121}
122wmul_impl_large! { u128, 64 }
123
124macro_rules! wmul_impl_usize {
125 ($ty:ty) => {
126 impl WideningMultiply for usize {
127 type Output = (usize, usize);
128
129 #[inline(always)]
130 fn wmul(self, x: usize) -> Self::Output {
131 let (high, low) = (self as $ty).wmul(x as $ty);
132 (high as usize, low as usize)
133 }
134 }
135 };
136}
137#[cfg(target_pointer_width = "16")]
138wmul_impl_usize! { u16 }
139#[cfg(target_pointer_width = "32")]
140wmul_impl_usize! { u32 }
141#[cfg(target_pointer_width = "64")]
142wmul_impl_usize! { u64 }
143
// `WideningMultiply` implementations for the `packed_simd` unsigned integer
// vector types. Only compiled when the `simd_support` feature is enabled.
#[cfg(feature = "simd_support")]
mod simd_wmul {
    use super::*;
    #[cfg(target_arch = "x86")] use core::arch::x86::*;
    #[cfg(target_arch = "x86_64")] use core::arch::x86_64::*;

    // 8-bit lanes: multiply in the 16-bit vector with the same lane count.
    wmul_impl! {
        (u8x2, u16x2),
        (u8x4, u16x4),
        (u8x8, u16x8),
        (u8x16, u16x16),
        (u8x32, u16x32),,
        8
    }

    // 16-bit lanes: multiply in the 32-bit vector with the same lane count.
    // `u16x8` and `u16x16` only take this generic path when the matching x86
    // target feature is not enabled; otherwise `wmul_impl_16` below provides
    // their implementation.
    wmul_impl! {
        (u16x2, u32x2),
        (u16x4, u32x4),,
        16
    }
    #[cfg(not(target_feature = "sse2"))]
    wmul_impl! { (u16x8, u32x8),, 16 }
    #[cfg(not(target_feature = "avx2"))]
    wmul_impl! { (u16x16, u32x16),, 16 }

    // 16-bit lane widths allow use of the x86 `mulhi` instructions, which
    // means `wmul` can be implemented with only two instructions.
    //
    // The macro goes unused on targets where neither of the cfg-gated
    // invocations below is compiled, hence the `allow`.
    #[allow(unused_macros)]
    macro_rules! wmul_impl_16 {
        ($ty:ident, $reg:ident, $mulhi:ident, $mullo:ident) => {
            impl WideningMultiply for $ty {
                type Output = ($ty, $ty);

                #[inline(always)]
                fn wmul(self, x: $ty) -> Self::Output {
                    let rhs = $reg::from_bits(x);
                    let lhs = $reg::from_bits(self);
                    // SAFETY: every invocation of this macro in this module
                    // is gated on the `target_feature` that provides the
                    // intrinsic it passes in.
                    let high = unsafe { $mulhi(lhs, rhs) };
                    // SAFETY: as above.
                    let low = unsafe { $mullo(lhs, rhs) };
                    ($ty::from_bits(high), $ty::from_bits(low))
                }
            }
        };
    }

    #[cfg(target_feature = "sse2")]
    wmul_impl_16! { u16x8, __m128i, _mm_mulhi_epu16, _mm_mullo_epi16 }
    #[cfg(target_feature = "avx2")]
    wmul_impl_16! { u16x16, __m256i, _mm256_mulhi_epu16, _mm256_mullo_epi16 }
    // FIXME: there are no `__m512i` types in stdsimd yet, so `wmul::<u16x32>`
    // cannot use the same implementation.

    // 32-bit lanes: multiply in the 64-bit vector with the same lane count.
    wmul_impl! {
        (u32x2, u64x2),
        (u32x4, u64x4),
        (u32x8, u64x8),,
        32
    }

    // The remaining vector types go through the half-width algorithm, one
    // invocation per lane type.
    // TODO: optimize, this seems to seriously slow things down
    wmul_impl_large! { (u8x64,) u8, 4 }
    wmul_impl_large! { (u16x32,) u16, 8 }
    wmul_impl_large! { (u32x16,) u32, 16 }
    wmul_impl_large! { (u64x2, u64x4, u64x8,) u64, 32 }
}
206
/// Helper trait when dealing with scalar and SIMD floating point types.
///
/// `PartialOrd` for vectors compares lexicographically. What is wanted here
/// is a comparison of the individual SIMD lanes with the result combined over
/// all lanes. That can be written as something like `a.lt(b).all()`, but
/// having it as a trait lets the same code be written for `f32` and `f64`.
/// Only the comparison functions that are needed are provided.
pub(crate) trait FloatSIMDUtils {
    /// Result type of the lane-wise `*_mask` comparisons.
    type Mask;

    /// Unsigned integer type accepted by `cast_from_int`.
    type UInt;

    /// Whether `self < other` holds in all lanes.
    fn all_lt(self, other: Self) -> bool;

    /// Whether `self <= other` holds in all lanes.
    fn all_le(self, other: Self) -> bool;

    /// Whether all lanes hold a finite value.
    fn all_finite(self) -> bool;

    /// Lane-wise test for a finite value.
    fn finite_mask(self) -> Self::Mask;

    /// Lane-wise `self > other`.
    fn gt_mask(self, other: Self) -> Self::Mask;

    /// Lane-wise `self >= other`.
    fn ge_mask(self, other: Self) -> Self::Mask;

    /// Decrease all lanes where the mask is `true` to the next lower value
    /// representable by the floating-point type. At least one of the lanes
    /// must be set.
    fn decrease_masked(self, mask: Self::Mask) -> Self;

    /// Convert from int value. Conversion is done while retaining the
    /// numerical value, not by retaining the binary representation.
    fn cast_from_int(i: Self::UInt) -> Self;
}
233
/// Implement functions available in std builds but missing from core primitives
///
/// The three methods mirror the float classification predicates of the same
/// names on the standard float types.
//
// NOTE(review): `std` is a bare cfg name, not `feature = "std"`, so this
// predicate holds unless the build passes `--cfg std` explicitly. Presumably
// `not(feature = "std")` was intended; confirm before changing, since the
// matching impls carry the same attribute.
#[cfg(not(std))]
// False positive: We are following `std` here.
#[allow(clippy::wrong_self_convention)]
pub(crate) trait Float: Sized {
    /// Whether the value is NaN.
    fn is_nan(self) -> bool;

    /// Whether the value is positive or negative infinity.
    fn is_infinite(self) -> bool;

    /// Whether the value is neither NaN nor infinite.
    fn is_finite(self) -> bool;
}
243
/// Implement functions on f32/f64 to give them APIs similar to SIMD types
///
/// Every method has a default body that treats the implementing type as a
/// vector with exactly one lane, so an empty `impl` block is sufficient.
pub(crate) trait FloatAsSIMD: Sized {
    /// Number of lanes; always `1` with the default body.
    #[inline(always)]
    fn lanes() -> usize {
        1
    }

    /// Builds a value with every lane set to `scalar`, i.e. `scalar` itself.
    #[inline(always)]
    fn splat(scalar: Self) -> Self {
        scalar
    }

    /// Returns the lane at `index`. Lane `0` is the only one there is, which
    /// is checked in builds with debug assertions enabled.
    #[inline(always)]
    fn extract(self, index: usize) -> Self {
        debug_assert_eq!(index, 0);
        self
    }

    /// Returns `self` with the lane at `index` replaced by `new_value`, which
    /// for a single lane is just `new_value`. `index` must be `0`; this is
    /// checked in builds with debug assertions enabled.
    #[inline(always)]
    fn replace(self, index: usize, new_value: Self) -> Self {
        debug_assert_eq!(index, 0);
        new_value
    }
}
265
/// Gives `bool` the reduction methods `any`, `all` and `none`, so that a
/// plain `bool` can be used in the same code as a multi-lane mask.
pub(crate) trait BoolAsSIMD: Sized {
    /// Whether at least one lane is `true`.
    fn any(self) -> bool;

    /// Whether every lane is `true`.
    fn all(self) -> bool;

    /// Whether no lane is `true`.
    fn none(self) -> bool;
}

// A `bool` is treated as a mask with a single lane: `any` and `all` are both
// the value itself and `none` is its negation.
impl BoolAsSIMD for bool {
    #[inline(always)]
    fn any(self) -> bool {
        self
    }

    #[inline(always)]
    fn all(self) -> bool {
        self
    }

    #[inline(always)]
    fn none(self) -> bool {
        !self.any()
    }
}
288
289macro_rules! scalar_float_impl {
290 ($ty:ident, $uty:ident) => {
291 #[cfg(not(std))]
292 impl Float for $ty {
293 #[inline]
294 fn is_nan(self) -> bool {
295 self != self
296 }
297
298 #[inline]
299 fn is_infinite(self) -> bool {
300 self == ::core::$ty::INFINITY || self == ::core::$ty::NEG_INFINITY
301 }
302
303 #[inline]
304 fn is_finite(self) -> bool {
305 !(self.is_nan() || self.is_infinite())
306 }
307 }
308
309 impl FloatSIMDUtils for $ty {
310 type Mask = bool;
311 type UInt = $uty;
312
313 #[inline(always)]
314 fn all_lt(self, other: Self) -> bool {
315 self < other
316 }
317
318 #[inline(always)]
319 fn all_le(self, other: Self) -> bool {
320 self <= other
321 }
322
323 #[inline(always)]
324 fn all_finite(self) -> bool {
325 self.is_finite()
326 }
327
328 #[inline(always)]
329 fn finite_mask(self) -> Self::Mask {
330 self.is_finite()
331 }
332
333 #[inline(always)]
334 fn gt_mask(self, other: Self) -> Self::Mask {
335 self > other
336 }
337
338 #[inline(always)]
339 fn ge_mask(self, other: Self) -> Self::Mask {
340 self >= other
341 }
342
343 #[inline(always)]
344 fn decrease_masked(self, mask: Self::Mask) -> Self {
345 debug_assert!(mask, "At least one lane must be set");
346 <$ty>::from_bits(self.to_bits() - 1)
347 }
348
349 #[inline]
350 fn cast_from_int(i: Self::UInt) -> Self {
351 i as $ty
352 }
353 }
354
355 impl FloatAsSIMD for $ty {}
356 };
357}
358
359scalar_float_impl!(f32, u32);
360scalar_float_impl!(f64, u64);
361
362
// Implements `FloatSIMDUtils` for the float vector type `$ty`.
//
// * `$f_scalar` is the scalar lane type, used to look up the infinity
//   constants;
// * `$mty` is the mask type used as `Mask`;
// * `$uty` is the unsigned integer vector type used as `UInt` and for
//   reinterpreting the bits of `$ty`.
#[cfg(feature = "simd_support")]
macro_rules! simd_impl {
    ($ty:ident, $f_scalar:ident, $mty:ident, $uty:ident) => {
        impl FloatSIMDUtils for $ty {
            type Mask = $mty;
            type UInt = $uty;

            #[inline(always)]
            fn all_lt(self, other: Self) -> bool {
                self.lt(other).all()
            }

            #[inline(always)]
            fn all_le(self, other: Self) -> bool {
                self.le(other).all()
            }

            #[inline(always)]
            fn all_finite(self) -> bool {
                self.finite_mask().all()
            }

            #[inline(always)]
            fn finite_mask(self) -> Self::Mask {
                // This can possibly be done faster by checking bit patterns
                let lower = $ty::splat(::core::$f_scalar::NEG_INFINITY);
                let upper = $ty::splat(::core::$f_scalar::INFINITY);
                self.gt(lower) & self.lt(upper)
            }

            #[inline(always)]
            fn gt_mask(self, other: Self) -> Self::Mask {
                self.gt(other)
            }

            #[inline(always)]
            fn ge_mask(self, other: Self) -> Self::Mask {
                self.ge(other)
            }

            #[inline(always)]
            fn decrease_masked(self, mask: Self::Mask) -> Self {
                // Casting a mask into ints will produce all bits set for
                // true, and 0 for false. Adding that to the binary
                // representation of a float means subtracting one from
                // the binary representation, resulting in the next lower
                // value representable by $ty. This works even when the
                // current value is infinity.
                debug_assert!(mask.any(), "At least one lane must be set");
                let bits = <$uty>::from_bits(self);
                let step = <$uty>::from_bits(mask);
                <$ty>::from_bits(bits + step)
            }

            #[inline]
            fn cast_from_int(i: Self::UInt) -> Self {
                i.cast()
            }
        }
    };
}

#[cfg(feature = "simd_support")]
simd_impl! { f32x2, f32, m32x2, u32x2 }
#[cfg(feature = "simd_support")]
simd_impl! { f32x4, f32, m32x4, u32x4 }
#[cfg(feature = "simd_support")]
simd_impl! { f32x8, f32, m32x8, u32x8 }
#[cfg(feature = "simd_support")]
simd_impl! { f32x16, f32, m32x16, u32x16 }
#[cfg(feature = "simd_support")]
simd_impl! { f64x2, f64, m64x2, u64x2 }
#[cfg(feature = "simd_support")]
simd_impl! { f64x4, f64, m64x4, u64x4 }
#[cfg(feature = "simd_support")]
simd_impl! { f64x8, f64, m64x8, u64x8 }
430