/// Fused multiply-add. Computes `(self * a) + b` with only one rounding
/// error, yielding a more accurate result than an unfused multiply-add.
///
/// On a target architecture with a dedicated `fma` CPU instruction,
/// `mul_add` can also be more performant than an unfused multiply-add.
///
/// The operand types `A` and `B` both default to `Self`, but an
/// implementation is free to pick different types for either of them.
///
/// # Example
///
/// ```
/// use std::f32;
///
/// let m = 10.0_f32;
/// let x = 4.0_f32;
/// let b = 60.0_f32;
///
/// // Both the fused and the unfused form evaluate to 100.0 here.
/// let abs_difference = (m.mul_add(x, b) - (m*x + b)).abs();
///
/// assert!(abs_difference <= 100.0 * f32::EPSILON);
/// ```
pub trait MulAdd<A = Self, B = Self> {
    /// The type produced by the fused multiply-add.
    type Output;

    /// Computes `(self * a) + b` as a fused multiply-add, consuming `self`
    /// and both operands.
    fn mul_add(self, a: A, b: B) -> Self::Output;
}
| 30 | |
/// The fused multiply-add assignment operation `*self = (*self * a) + b`.
///
/// The operand types `A` and `B` both default to `Self`.
pub trait MulAddAssign<A = Self, B = Self> {
    /// Overwrites `*self` in place with the result of `(*self * a) + b`.
    fn mul_add_assign(&mut self, a: A, b: B);
}
| 36 | |
// `MulAdd` for `f32`: forwards to this crate's `Float::mul_add`.
// Only compiled when the `std` or `libm` feature is enabled.
#[cfg(any(feature = "std", feature = "libm"))]
impl MulAdd<f32, f32> for f32 {
    type Output = f32;

    #[inline]
    fn mul_add(self, a: f32, b: f32) -> f32 {
        <f32 as crate::Float>::mul_add(self, a, b)
    }
}
| 46 | |
// `MulAdd` for `f64`: forwards to this crate's `Float::mul_add`.
// Only compiled when the `std` or `libm` feature is enabled.
#[cfg(any(feature = "std", feature = "libm"))]
impl MulAdd<f64, f64> for f64 {
    type Output = f64;

    #[inline]
    fn mul_add(self, a: f64, b: f64) -> f64 {
        <f64 as crate::Float>::mul_add(self, a, b)
    }
}
| 56 | |
| 57 | macro_rules! mul_add_impl { |
| 58 | ($trait_name:ident for $($t:ty)*) => {$( |
| 59 | impl $trait_name for $t { |
| 60 | type Output = Self; |
| 61 | |
| 62 | #[inline] |
| 63 | fn mul_add(self, a: Self, b: Self) -> Self::Output { |
| 64 | (self * a) + b |
| 65 | } |
| 66 | } |
| 67 | )*} |
| 68 | } |
| 69 | |
| 70 | mul_add_impl!(MulAdd for isize i8 i16 i32 i64 i128); |
| 71 | mul_add_impl!(MulAdd for usize u8 u16 u32 u64 u128); |
| 72 | |
// `MulAddAssign` for `f32`: stores the result of this crate's
// `Float::mul_add` back into `*self`.
// Only compiled when the `std` or `libm` feature is enabled.
#[cfg(any(feature = "std", feature = "libm"))]
impl MulAddAssign<f32, f32> for f32 {
    #[inline]
    fn mul_add_assign(&mut self, a: f32, b: f32) {
        *self = <f32 as crate::Float>::mul_add(*self, a, b)
    }
}
| 80 | |
// `MulAddAssign` for `f64`: stores the result of this crate's
// `Float::mul_add` back into `*self`.
// Only compiled when the `std` or `libm` feature is enabled.
#[cfg(any(feature = "std", feature = "libm"))]
impl MulAddAssign<f64, f64> for f64 {
    #[inline]
    fn mul_add_assign(&mut self, a: f64, b: f64) {
        *self = <f64 as crate::Float>::mul_add(*self, a, b)
    }
}
| 88 | |
| 89 | macro_rules! mul_add_assign_impl { |
| 90 | ($trait_name:ident for $($t:ty)*) => {$( |
| 91 | impl $trait_name for $t { |
| 92 | #[inline] |
| 93 | fn mul_add_assign(&mut self, a: Self, b: Self) { |
| 94 | *self = (*self * a) + b |
| 95 | } |
| 96 | } |
| 97 | )*} |
| 98 | } |
| 99 | |
| 100 | mul_add_assign_impl!(MulAddAssign for isize i8 i16 i32 i64 i128); |
| 101 | mul_add_assign_impl!(MulAddAssign for usize u8 u16 u32 u64 u128); |
| 102 | |
#[cfg(test)]
mod tests {
    use super::*;

    /// `MulAdd` on every primitive integer type, including the 128-bit ones,
    /// must agree with the plain `m * x + b` expression.
    #[test]
    fn mul_add_integer() {
        macro_rules! test_mul_add {
            ($($t:ident)+) => {
                $(
                    {
                        let m: $t = 2;
                        let x: $t = 3;
                        let b: $t = 4;

                        assert_eq!(MulAdd::mul_add(m, x, b), (m*x + b));
                    }
                )+
            };
        }

        test_mul_add!(usize u8 u16 u32 u64 u128 isize i8 i16 i32 i64 i128);
    }

    /// `MulAddAssign` on every primitive integer type must leave the same
    /// value in place that the plain `m * x + b` expression produces.
    #[test]
    fn mul_add_assign_integer() {
        macro_rules! test_mul_add_assign {
            ($($t:ident)+) => {
                $(
                    {
                        let m: $t = 2;
                        let x: $t = 3;
                        let b: $t = 4;

                        let mut acc = m;
                        MulAddAssign::mul_add_assign(&mut acc, x, b);

                        assert_eq!(acc, (m*x + b));
                    }
                )+
            };
        }

        test_mul_add_assign!(usize u8 u16 u32 u64 u128 isize i8 i16 i32 i64 i128);
    }

    /// `MulAdd` on floats may differ from the unfused expression by rounding
    /// only, so the comparison is made against a small multiple of `EPSILON`.
    #[test]
    #[cfg(feature = "std")]
    fn mul_add_float() {
        macro_rules! test_mul_add {
            ($($t:ident)+) => {
                $(
                    {
                        use core::$t;

                        let m: $t = 12.0;
                        let x: $t = 3.4;
                        let b: $t = 5.6;

                        let abs_difference = (MulAdd::mul_add(m, x, b) - (m*x + b)).abs();

                        assert!(abs_difference <= 46.4 * $t::EPSILON);
                    }
                )+
            };
        }

        test_mul_add!(f32 f64);
    }

    /// `MulAddAssign` on floats is held to the same tolerance as `MulAdd`.
    #[test]
    #[cfg(feature = "std")]
    fn mul_add_assign_float() {
        macro_rules! test_mul_add_assign {
            ($($t:ident)+) => {
                $(
                    {
                        use core::$t;

                        let m: $t = 12.0;
                        let x: $t = 3.4;
                        let b: $t = 5.6;

                        let mut acc = m;
                        MulAddAssign::mul_add_assign(&mut acc, x, b);

                        let abs_difference = (acc - (m*x + b)).abs();

                        assert!(abs_difference <= 46.4 * $t::EPSILON);
                    }
                )+
            };
        }

        test_mul_add_assign!(f32 f64);
    }
}
| 150 | |