/// Fused multiply-add: evaluates `(self * a) + b` as a single operation.
///
/// A fused implementation rounds only once, so its result is more accurate
/// than performing the multiplication and the addition as two separate
/// (unfused) steps.
///
/// When the target architecture offers a dedicated `fma` CPU instruction,
/// `mul_add` can also be faster than the unfused equivalent.
///
/// The operand types `A` and `B` default to `Self`, but implementors are free
/// to pick different types.
///
/// # Example
///
/// ```
/// use std::f32;
///
/// let m = 10.0_f32;
/// let x = 4.0_f32;
/// let b = 60.0_f32;
///
/// // 100.0
/// let abs_difference = (m.mul_add(x, b) - (m*x + b)).abs();
///
/// assert!(abs_difference <= 100.0 * f32::EPSILON);
/// ```
pub trait MulAdd<A = Self, B = Self> {
    /// Type produced by the fused multiply-add.
    type Output;

    /// Computes `(self * a) + b` as one fused operation.
    fn mul_add(self, a: A, b: B) -> Self::Output;
}
30 | |
/// In-place fused multiply-add: `*self = (*self * a) + b`.
///
/// The operand types `A` and `B` default to `Self`.
pub trait MulAddAssign<A = Self, B = Self> {
    /// Replaces `*self` with the result of `(*self * a) + b`.
    fn mul_add_assign(&mut self, a: A, b: B);
}
36 | |
// `f32` multiply-add; only built with the `std` or `libm` feature.
// The computation is delegated to this crate's `Float::mul_add`.
#[cfg(any(feature = "std", feature = "libm"))]
impl MulAdd for f32 {
    type Output = f32;

    #[inline]
    fn mul_add(self, a: f32, b: f32) -> f32 {
        crate::Float::mul_add(self, a, b)
    }
}
46 | |
// `f64` multiply-add; only built with the `std` or `libm` feature.
// The computation is delegated to this crate's `Float::mul_add`.
#[cfg(any(feature = "std", feature = "libm"))]
impl MulAdd for f64 {
    type Output = f64;

    #[inline]
    fn mul_add(self, a: f64, b: f64) -> f64 {
        crate::Float::mul_add(self, a, b)
    }
}
56 | |
57 | macro_rules! mul_add_impl { |
58 | ($trait_name:ident for $($t:ty)*) => {$( |
59 | impl $trait_name for $t { |
60 | type Output = Self; |
61 | |
62 | #[inline] |
63 | fn mul_add(self, a: Self, b: Self) -> Self::Output { |
64 | (self * a) + b |
65 | } |
66 | } |
67 | )*} |
68 | } |
69 | |
70 | mul_add_impl!(MulAdd for isize i8 i16 i32 i64 i128); |
71 | mul_add_impl!(MulAdd for usize u8 u16 u32 u64 u128); |
72 | |
// In-place `f32` multiply-add; only built with the `std` or `libm` feature.
// The computation is delegated to this crate's `Float::mul_add`.
#[cfg(any(feature = "std", feature = "libm"))]
impl MulAddAssign for f32 {
    #[inline]
    fn mul_add_assign(&mut self, a: f32, b: f32) {
        let current = *self;
        *self = crate::Float::mul_add(current, a, b);
    }
}
80 | |
// In-place `f64` multiply-add; only built with the `std` or `libm` feature.
// The computation is delegated to this crate's `Float::mul_add`.
#[cfg(any(feature = "std", feature = "libm"))]
impl MulAddAssign for f64 {
    #[inline]
    fn mul_add_assign(&mut self, a: f64, b: f64) {
        let current = *self;
        *self = crate::Float::mul_add(current, a, b);
    }
}
88 | |
89 | macro_rules! mul_add_assign_impl { |
90 | ($trait_name:ident for $($t:ty)*) => {$( |
91 | impl $trait_name for $t { |
92 | #[inline] |
93 | fn mul_add_assign(&mut self, a: Self, b: Self) { |
94 | *self = (*self * a) + b |
95 | } |
96 | } |
97 | )*} |
98 | } |
99 | |
100 | mul_add_assign_impl!(MulAddAssign for isize i8 i16 i32 i64 i128); |
101 | mul_add_assign_impl!(MulAddAssign for usize u8 u16 u32 u64 u128); |
102 | |
#[cfg(test)]
mod tests {
    use super::*;

    /// `MulAdd` on every integer type that has an impl, including the
    /// 128-bit ones, must agree with the plain `m * x + b` expression.
    #[test]
    fn mul_add_integer() {
        macro_rules! test_mul_add {
            ($($t:ident)+) => {
                $(
                    {
                        let m: $t = 2;
                        let x: $t = 3;
                        let b: $t = 4;

                        assert_eq!(MulAdd::mul_add(m, x, b), m * x + b);
                    }
                )+
            };
        }

        test_mul_add!(usize u8 u16 u32 u64 u128 isize i8 i16 i32 i64 i128);
    }

    /// `MulAddAssign` on every integer type that has an impl must leave the
    /// same value in place that `m * x + b` produces.
    #[test]
    fn mul_add_assign_integer() {
        macro_rules! test_mul_add_assign {
            ($($t:ident)+) => {
                $(
                    {
                        let m: $t = 2;
                        let x: $t = 3;
                        let b: $t = 4;

                        let mut value = m;
                        MulAddAssign::mul_add_assign(&mut value, x, b);

                        assert_eq!(value, m * x + b);
                    }
                )+
            };
        }

        test_mul_add_assign!(usize u8 u16 u32 u64 u128 isize i8 i16 i32 i64 i128);
    }

    /// The float `MulAdd` result may differ from the unfused expression by
    /// rounding only, so compare within a small multiple of `EPSILON`.
    #[test]
    #[cfg(feature = "std")]
    fn mul_add_float() {
        macro_rules! test_mul_add {
            ($($t:ident)+) => {
                $(
                    {
                        use core::$t;

                        let m: $t = 12.0;
                        let x: $t = 3.4;
                        let b: $t = 5.6;

                        let abs_difference = (MulAdd::mul_add(m, x, b) - (m*x + b)).abs();

                        assert!(abs_difference <= 46.4 * $t::EPSILON);
                    }
                )+
            };
        }

        test_mul_add!(f32 f64);
    }

    /// Same tolerance check as `mul_add_float`, but through the in-place
    /// `MulAddAssign` entry point.
    #[test]
    #[cfg(feature = "std")]
    fn mul_add_assign_float() {
        macro_rules! test_mul_add_assign {
            ($($t:ident)+) => {
                $(
                    {
                        use core::$t;

                        let m: $t = 12.0;
                        let x: $t = 3.4;
                        let b: $t = 5.6;

                        let mut value = m;
                        MulAddAssign::mul_add_assign(&mut value, x, b);

                        let abs_difference = (value - (m*x + b)).abs();

                        assert!(abs_difference <= 46.4 * $t::EPSILON);
                    }
                )+
            };
        }

        test_mul_add_assign!(f32 f64);
    }
}
150 | |