// Copyright 2018 Developers of the Rand project.
//
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
// https://www.apache.org/licenses/LICENSE-2.0> or the MIT license
// <LICENSE-MIT or https://opensource.org/licenses/MIT>, at your
// option. This file may not be copied, modified, or distributed
// except according to those terms.

//! Math helper functions

#[cfg(feature = "simd_support")]
use core::simd::prelude::*;
#[cfg(feature = "simd_support")]
use core::simd::{LaneCount, SimdElement, SupportedLaneCount};

/// Widening multiplication: compute the full double-width product of two
/// values and return it split into `(high, low)` halves.
pub(crate) trait WideningMultiply<RHS = Self> {
    type Output;

    fn wmul(self, x: RHS) -> Self::Output;
}

macro_rules! wmul_impl {
    ($ty:ty, $wide:ty, $shift:expr) => {
        impl WideningMultiply for $ty {
            type Output = ($ty, $ty);

            #[inline(always)]
            fn wmul(self, x: $ty) -> Self::Output {
                let tmp = (self as $wide) * (x as $wide);
                ((tmp >> $shift) as $ty, tmp as $ty)
            }
        }
    };

    // simd bulk implementation
    ($(($ty:ident, $wide:ty),)+, $shift:expr) => {
        $(
            impl WideningMultiply for $ty {
                type Output = ($ty, $ty);

                #[inline(always)]
                fn wmul(self, x: $ty) -> Self::Output {
                    // For supported vectors, this should compile to a couple
                    // of multiply & swizzle instructions (no actual casting).
                    // TODO: optimize
                    let y: $wide = self.cast();
                    let x: $wide = x.cast();
                    let tmp = y * x;
                    let hi: $ty = (tmp >> Simd::splat($shift)).cast();
                    let lo: $ty = tmp.cast();
                    (hi, lo)
                }
            }
        )+
    };
}
wmul_impl! { u8, u16, 8 }
wmul_impl! { u16, u32, 16 }
wmul_impl! { u32, u64, 32 }
wmul_impl! { u64, u128, 64 }
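
// A quick usage sketch (not in the upstream source): `wmul` returns the
// double-width product split into (high, low) halves.
#[cfg(test)]
#[test]
fn wmul_scalar_example() {
    // 2^31 * 4 = 2^33: high word 2, low word 0.
    assert_eq!(0x8000_0000u32.wmul(4), (2, 0));
    // The low half wraps: (2^32 - 1)^2 = (2^32 - 2) * 2^32 + 1.
    assert_eq!(u32::MAX.wmul(u32::MAX), (u32::MAX - 1, 1));
}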

// This code is a translation of the __mulddi3 function in LLVM's
// compiler-rt. It is an optimised variant of the common method
// `(a + b) * (c + d) = ac + ad + bc + bd`.
//
// For some reason LLVM can optimise the C version very well, but
// keeps shuffling registers in this Rust translation.
macro_rules! wmul_impl_large {
    ($ty:ty, $half:expr) => {
        impl WideningMultiply for $ty {
            type Output = ($ty, $ty);

            #[inline(always)]
            fn wmul(self, b: $ty) -> Self::Output {
                const LOWER_MASK: $ty = !0 >> $half;
                let mut low = (self & LOWER_MASK).wrapping_mul(b & LOWER_MASK);
                let mut t = low >> $half;
                low &= LOWER_MASK;
                t += (self >> $half).wrapping_mul(b & LOWER_MASK);
                low += (t & LOWER_MASK) << $half;
                let mut high = t >> $half;
                t = low >> $half;
                low &= LOWER_MASK;
                t += (b >> $half).wrapping_mul(self & LOWER_MASK);
                low += (t & LOWER_MASK) << $half;
                high += t >> $half;
                high += (self >> $half).wrapping_mul(b >> $half);

                (high, low)
            }
        }
    };

    // simd bulk implementation
    (($($ty:ty,)+) $scalar:ty, $half:expr) => {
        $(
            impl WideningMultiply for $ty {
                type Output = ($ty, $ty);

                #[inline(always)]
                fn wmul(self, b: $ty) -> Self::Output {
                    // needs wrapping multiplication
                    let lower_mask = <$ty>::splat(!0 >> $half);
                    let half = <$ty>::splat($half);
                    let mut low = (self & lower_mask) * (b & lower_mask);
                    let mut t = low >> half;
                    low &= lower_mask;
                    t += (self >> half) * (b & lower_mask);
                    low += (t & lower_mask) << half;
                    let mut high = t >> half;
                    t = low >> half;
                    low &= lower_mask;
                    t += (b >> half) * (self & lower_mask);
                    low += (t & lower_mask) << half;
                    high += t >> half;
                    high += (self >> half) * (b >> half);

                    (high, low)
                }
            }
        )+
    };
}
wmul_impl_large! { u128, 64 }
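
// A quick usage sketch (not in the upstream source): the split multiply above
// reproduces the exact 256-bit square of `u128::MAX`.
#[cfg(test)]
#[test]
fn wmul_u128_example() {
    // (2^128 - 1)^2 = (2^128 - 2) * 2^128 + 1
    assert_eq!(u128::MAX.wmul(u128::MAX), (u128::MAX - 1, 1));
}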

macro_rules! wmul_impl_usize {
    ($ty:ty) => {
        impl WideningMultiply for usize {
            type Output = (usize, usize);

            #[inline(always)]
            fn wmul(self, x: usize) -> Self::Output {
                let (high, low) = (self as $ty).wmul(x as $ty);
                (high as usize, low as usize)
            }
        }
    };
}
#[cfg(target_pointer_width = "16")]
wmul_impl_usize! { u16 }
#[cfg(target_pointer_width = "32")]
wmul_impl_usize! { u32 }
#[cfg(target_pointer_width = "64")]
wmul_impl_usize! { u64 }
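
// A quick usage sketch (not in the upstream source): `usize` delegates to the
// unsigned integer of the same width (u64 on 64-bit targets).
#[cfg(all(test, target_pointer_width = "64"))]
#[test]
fn wmul_usize_example() {
    // usize::MAX * 2 = 2^65 - 2: high word 1, low word usize::MAX - 1.
    assert_eq!(usize::MAX.wmul(2), (1, usize::MAX - 1));
}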

#[cfg(feature = "simd_support")]
mod simd_wmul {
    use super::*;
    #[cfg(target_arch = "x86")]
    use core::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use core::arch::x86_64::*;

    wmul_impl! {
        (u8x4, u16x4),
        (u8x8, u16x8),
        (u8x16, u16x16),
        (u8x32, u16x32),
        (u8x64, Simd<u16, 64>),,
        8
    }

    wmul_impl! { (u16x2, u32x2),, 16 }
    wmul_impl! { (u16x4, u32x4),, 16 }
    #[cfg(not(target_feature = "sse2"))]
    wmul_impl! { (u16x8, u32x8),, 16 }
    #[cfg(not(target_feature = "avx2"))]
    wmul_impl! { (u16x16, u32x16),, 16 }
    #[cfg(not(target_feature = "avx512bw"))]
    wmul_impl! { (u16x32, Simd<u32, 32>),, 16 }

    // 16-bit lane widths allow use of the x86 `mulhi` instructions, which
    // means `wmul` can be implemented with only two instructions.
    #[allow(unused_macros)]
    macro_rules! wmul_impl_16 {
        ($ty:ident, $mulhi:ident, $mullo:ident) => {
            impl WideningMultiply for $ty {
                type Output = ($ty, $ty);

                #[inline(always)]
                fn wmul(self, x: $ty) -> Self::Output {
                    let hi = unsafe { $mulhi(self.into(), x.into()) }.into();
                    let lo = unsafe { $mullo(self.into(), x.into()) }.into();
                    (hi, lo)
                }
            }
        };
    }

    #[cfg(target_feature = "sse2")]
    wmul_impl_16! { u16x8, _mm_mulhi_epu16, _mm_mullo_epi16 }
    #[cfg(target_feature = "avx2")]
    wmul_impl_16! { u16x16, _mm256_mulhi_epu16, _mm256_mullo_epi16 }
    #[cfg(target_feature = "avx512bw")]
    wmul_impl_16! { u16x32, _mm512_mulhi_epu16, _mm512_mullo_epi16 }
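
    // A quick consistency sketch (not in the upstream source): whichever
    // u16x8 path is compiled in (intrinsic or generic), its result should
    // match the scalar implementation lane by lane.
    #[cfg(test)]
    #[test]
    fn wmul_u16x8_matches_scalar() {
        let a = u16x8::from_array([0, 1, 2, 3, 1000, 40000, u16::MAX, 12345]);
        let b = u16x8::splat(513);
        let (hi, lo) = a.wmul(b);
        for i in 0..8 {
            let expected = a.as_array()[i].wmul(b.as_array()[i]);
            assert_eq!((hi.as_array()[i], lo.as_array()[i]), expected);
        }
    }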

    wmul_impl! {
        (u32x2, u64x2),
        (u32x4, u64x4),
        (u32x8, u64x8),
        (u32x16, Simd<u64, 16>),,
        32
    }

    wmul_impl_large! { (u64x2, u64x4, u64x8,) u64, 32 }
}
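
// A quick consistency sketch (not in the upstream source): the u64 lanes go
// through the compiler-rt style split multiply above; cross-check one vector
// against the scalar u64 implementation.
#[cfg(all(test, feature = "simd_support"))]
#[test]
fn wmul_u64x2_matches_scalar() {
    let a = u64x2::from_array([0x0123_4567_89ab_cdef, u64::MAX]);
    let b = u64x2::splat(0xfedc_ba98_7654_3210);
    let (hi, lo) = a.wmul(b);
    for i in 0..2 {
        let expected = a.as_array()[i].wmul(b.as_array()[i]);
        assert_eq!((hi.as_array()[i], lo.as_array()[i]), expected);
    }
}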

/// Helper trait when dealing with scalar and SIMD floating point types.
pub(crate) trait FloatSIMDUtils {
    // `PartialOrd` for vectors compares lexicographically. We want to compare all
    // the individual SIMD lanes instead, and get the combined result over all
    // lanes. This is possible using something like `a.lt(b).all()`, but we
    // implement it as a trait so we can write the same code for `f32` and `f64`.
    // Only the comparison functions we need are implemented.
    fn all_lt(self, other: Self) -> bool;
    fn all_le(self, other: Self) -> bool;
    fn all_finite(self) -> bool;

    type Mask;
    fn gt_mask(self, other: Self) -> Self::Mask;

    // Decrease all lanes where the mask is `true` to the next lower value
    // representable by the floating-point type. At least one of the lanes
    // must be set.
    fn decrease_masked(self, mask: Self::Mask) -> Self;

    // Convert from int value. Conversion is done while retaining the numerical
    // value, not by retaining the binary representation.
    type UInt;
    fn cast_from_int(i: Self::UInt) -> Self;
}

#[cfg(test)]
pub(crate) trait FloatSIMDScalarUtils: FloatSIMDUtils {
    type Scalar;

    fn replace(self, index: usize, new_value: Self::Scalar) -> Self;
    fn extract(self, index: usize) -> Self::Scalar;
}

/// Implement functions on f32/f64 to give them APIs similar to SIMD types
pub(crate) trait FloatAsSIMD: Sized {
    #[cfg(test)]
    const LEN: usize = 1;

    #[inline(always)]
    fn splat(scalar: Self) -> Self {
        scalar
    }
}

pub(crate) trait IntAsSIMD: Sized {
    #[inline(always)]
    fn splat(scalar: Self) -> Self {
        scalar
    }
}

impl IntAsSIMD for u32 {}
impl IntAsSIMD for u64 {}

pub(crate) trait BoolAsSIMD: Sized {
    fn any(self) -> bool;
}

impl BoolAsSIMD for bool {
    #[inline(always)]
    fn any(self) -> bool {
        self
    }
}

macro_rules! scalar_float_impl {
    ($ty:ident, $uty:ident) => {
        impl FloatSIMDUtils for $ty {
            type Mask = bool;
            type UInt = $uty;

            #[inline(always)]
            fn all_lt(self, other: Self) -> bool {
                self < other
            }

            #[inline(always)]
            fn all_le(self, other: Self) -> bool {
                self <= other
            }

            #[inline(always)]
            fn all_finite(self) -> bool {
                self.is_finite()
            }

            #[inline(always)]
            fn gt_mask(self, other: Self) -> Self::Mask {
                self > other
            }

            #[inline(always)]
            fn decrease_masked(self, mask: Self::Mask) -> Self {
                debug_assert!(mask, "At least one lane must be set");
                <$ty>::from_bits(self.to_bits() - 1)
            }

            #[inline]
            fn cast_from_int(i: Self::UInt) -> Self {
                i as $ty
            }
        }

        #[cfg(test)]
        impl FloatSIMDScalarUtils for $ty {
            type Scalar = $ty;

            #[inline]
            fn replace(self, index: usize, new_value: Self::Scalar) -> Self {
                debug_assert_eq!(index, 0);
                new_value
            }

            #[inline]
            fn extract(self, index: usize) -> Self::Scalar {
                debug_assert_eq!(index, 0);
                self
            }
        }

        impl FloatAsSIMD for $ty {}
    };
}

scalar_float_impl!(f32, u32);
scalar_float_impl!(f64, u64);
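
// A small usage sketch (not in the upstream source): the scalar impls mirror
// the SIMD API with a single "lane".
#[cfg(test)]
#[test]
fn scalar_float_simd_utils_example() {
    assert!(1.0f64.all_lt(2.0) && 2.0f64.all_le(2.0));
    assert_eq!(f64::cast_from_int(3u64), 3.0);
    // With the (single-lane) mask set, step down to the next float below 1.0.
    assert!(1.0f32.decrease_masked(true) < 1.0);
}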

#[cfg(feature = "simd_support")]
macro_rules! simd_impl {
    ($fty:ident, $uty:ident) => {
        impl<const LANES: usize> FloatSIMDUtils for Simd<$fty, LANES>
        where
            LaneCount<LANES>: SupportedLaneCount,
        {
            type Mask = Mask<<$fty as SimdElement>::Mask, LANES>;
            type UInt = Simd<$uty, LANES>;

            #[inline(always)]
            fn all_lt(self, other: Self) -> bool {
                self.simd_lt(other).all()
            }

            #[inline(always)]
            fn all_le(self, other: Self) -> bool {
                self.simd_le(other).all()
            }

            #[inline(always)]
            fn all_finite(self) -> bool {
                self.is_finite().all()
            }

            #[inline(always)]
            fn gt_mask(self, other: Self) -> Self::Mask {
                self.simd_gt(other)
            }

            #[inline(always)]
            fn decrease_masked(self, mask: Self::Mask) -> Self {
                // Casting a mask into ints will produce all bits set for
                // true, and 0 for false. Adding that to the binary
                // representation of a float means subtracting one from
                // the binary representation, resulting in the next lower
                // value representable by $fty. This works even when the
                // current value is infinity.
                debug_assert!(mask.any(), "At least one lane must be set");
                Self::from_bits(self.to_bits() + mask.to_int().cast())
            }

            #[inline]
            fn cast_from_int(i: Self::UInt) -> Self {
                i.cast()
            }
        }

        #[cfg(test)]
        impl<const LANES: usize> FloatSIMDScalarUtils for Simd<$fty, LANES>
        where
            LaneCount<LANES>: SupportedLaneCount,
        {
            type Scalar = $fty;

            #[inline]
            fn replace(mut self, index: usize, new_value: Self::Scalar) -> Self {
                self.as_mut_array()[index] = new_value;
                self
            }

            #[inline]
            fn extract(self, index: usize) -> Self::Scalar {
                self.as_array()[index]
            }
        }
    };
}

#[cfg(feature = "simd_support")]
simd_impl!(f32, u32);
#[cfg(feature = "simd_support")]
simd_impl!(f64, u64);
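
// A small usage sketch (not in the upstream source): `decrease_masked` on a
// vector nudges only the masked lanes down to the next lower representable
// float, leaving the other lanes untouched.
#[cfg(all(test, feature = "simd_support"))]
#[test]
fn simd_decrease_masked_example() {
    let v = f32x4::from_array([1.0, 2.0, 3.0, 4.0]);
    let mask = v.gt_mask(f32x4::splat(2.5)); // lanes 2 and 3 are set
    let w = v.decrease_masked(mask);
    assert_eq!(w.as_array()[0], 1.0); // unmasked lane unchanged
    assert!(w.as_array()[3] < 4.0); // masked lane stepped down
}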