/// Fused multiply-add: evaluates `(self * a) + b` while rounding only once,
/// which gives a more accurate result than a separate multiplication
/// followed by an addition.
///
/// On target architectures that provide a dedicated `fma` CPU instruction,
/// `mul_add` can also be faster than the unfused sequence.
///
/// The type parameters `A` and `B` default to `Self`, but an implementation
/// is free to pick other types for them.
///
/// # Example
///
/// ```
/// use std::f32;
///
/// let m = 10.0_f32;
/// let x = 4.0_f32;
/// let b = 60.0_f32;
///
/// // 100.0
/// let abs_difference = (m.mul_add(x, b) - (m*x + b)).abs();
///
/// assert!(abs_difference <= 100.0 * f32::EPSILON);
/// ```
pub trait MulAdd<A = Self, B = Self> {
    /// The type produced by the fused multiply-add.
    type Output;

    /// Computes `(self * a) + b` as one fused operation.
    fn mul_add(self, a: A, b: B) -> Self::Output;
}
30 | |
/// Fused multiply-add with assignment: the in-place counterpart of a fused
/// multiply-add, where the result is stored back into `self`.
///
/// The type parameters `A` and `B` default to `Self`.
pub trait MulAddAssign<A = Self, B = Self> {
    /// Performs the fused multiply-add and assigns the result to `self`.
    fn mul_add_assign(&mut self, a: A, b: B);
}
36 | |
// `f32` fused multiply-add; only compiled when the `std` or `libm` feature is
// enabled. The work is delegated to the crate-root `Float::mul_add`.
#[cfg(any(feature = "std", feature = "libm"))]
impl MulAdd<f32, f32> for f32 {
    type Output = f32;

    #[inline]
    fn mul_add(self, a: f32, b: f32) -> f32 {
        <f32 as ::Float>::mul_add(self, a, b)
    }
}
46 | |
// `f64` fused multiply-add; only compiled when the `std` or `libm` feature is
// enabled. The work is delegated to the crate-root `Float::mul_add`.
#[cfg(any(feature = "std", feature = "libm"))]
impl MulAdd<f64, f64> for f64 {
    type Output = f64;

    #[inline]
    fn mul_add(self, a: f64, b: f64) -> f64 {
        <f64 as ::Float>::mul_add(self, a, b)
    }
}
56 | |
57 | macro_rules! mul_add_impl { |
58 | ($trait_name:ident for $($t:ty)*) => {$( |
59 | impl $trait_name for $t { |
60 | type Output = Self; |
61 | |
62 | #[inline] |
63 | fn mul_add(self, a: Self, b: Self) -> Self::Output { |
64 | (self * a) + b |
65 | } |
66 | } |
67 | )*} |
68 | } |
69 | |
70 | mul_add_impl!(MulAdd for isize usize i8 u8 i16 u16 i32 u32 i64 u64); |
71 | #[cfg (has_i128)] |
72 | mul_add_impl!(MulAdd for i128 u128); |
73 | |
// In-place `f32` fused multiply-add; only compiled when the `std` or `libm`
// feature is enabled. The value is computed by the crate-root
// `Float::mul_add` and written back through the mutable reference.
#[cfg(any(feature = "std", feature = "libm"))]
impl MulAddAssign<f32, f32> for f32 {
    #[inline]
    fn mul_add_assign(&mut self, a: f32, b: f32) {
        let fused = <f32 as ::Float>::mul_add(*self, a, b);
        *self = fused;
    }
}
81 | |
// In-place `f64` fused multiply-add; only compiled when the `std` or `libm`
// feature is enabled. The value is computed by the crate-root
// `Float::mul_add` and written back through the mutable reference.
#[cfg(any(feature = "std", feature = "libm"))]
impl MulAddAssign<f64, f64> for f64 {
    #[inline]
    fn mul_add_assign(&mut self, a: f64, b: f64) {
        let fused = <f64 as ::Float>::mul_add(*self, a, b);
        *self = fused;
    }
}
89 | |
90 | macro_rules! mul_add_assign_impl { |
91 | ($trait_name:ident for $($t:ty)*) => {$( |
92 | impl $trait_name for $t { |
93 | #[inline] |
94 | fn mul_add_assign(&mut self, a: Self, b: Self) { |
95 | *self = (*self * a) + b |
96 | } |
97 | } |
98 | )*} |
99 | } |
100 | |
101 | mul_add_assign_impl!(MulAddAssign for isize usize i8 u8 i16 u16 i32 u32 i64 u64); |
102 | #[cfg (has_i128)] |
103 | mul_add_assign_impl!(MulAddAssign for i128 u128); |
104 | |
#[cfg(test)]
mod tests {
    use super::*;

    /// `MulAdd` on the primitive integer types must agree with the plain
    /// `m * x + b` expression.
    #[test]
    fn mul_add_integer() {
        macro_rules! test_mul_add {
            ($($t:ident)+) => {
                $(
                    {
                        let m: $t = 2;
                        let x: $t = 3;
                        let b: $t = 4;

                        assert_eq!(MulAdd::mul_add(m, x, b), (m*x + b));
                    }
                )+
            };
        }

        test_mul_add!(usize u8 u16 u32 u64 isize i8 i16 i32 i64);
    }

    /// `MulAddAssign` on the primitive integer types must leave `self` equal
    /// to the plain `m * x + b` expression.
    #[test]
    fn mul_add_assign_integer() {
        macro_rules! test_mul_add_assign {
            ($($t:ident)+) => {
                $(
                    {
                        let mut acc: $t = 2;
                        let x: $t = 3;
                        let b: $t = 4;
                        // Capture the expected value before `acc` is overwritten.
                        let expected: $t = acc * x + b;

                        MulAddAssign::mul_add_assign(&mut acc, x, b);

                        assert_eq!(acc, expected);
                    }
                )+
            };
        }

        test_mul_add_assign!(usize u8 u16 u32 u64 isize i8 i16 i32 i64);
    }

    /// `MulAdd` on floats may differ from the unfused expression by rounding
    /// only, so the comparison uses a tolerance scaled by the result (46.4).
    #[test]
    #[cfg(feature = "std")]
    fn mul_add_float() {
        macro_rules! test_mul_add {
            ($($t:ident)+) => {
                $(
                    {
                        use core::$t;

                        let m: $t = 12.0;
                        let x: $t = 3.4;
                        let b: $t = 5.6;

                        let abs_difference = (MulAdd::mul_add(m, x, b) - (m*x + b)).abs();

                        assert!(abs_difference <= 46.4 * $t::EPSILON);
                    }
                )+
            };
        }

        test_mul_add!(f32 f64);
    }

    /// `MulAddAssign` on floats, checked against the unfused expression with
    /// the same tolerance as `mul_add_float`.
    #[test]
    #[cfg(feature = "std")]
    fn mul_add_assign_float() {
        macro_rules! test_mul_add_assign {
            ($($t:ident)+) => {
                $(
                    {
                        use core::$t;

                        let m: $t = 12.0;
                        let x: $t = 3.4;
                        let b: $t = 5.6;

                        let mut acc: $t = m;
                        MulAddAssign::mul_add_assign(&mut acc, x, b);

                        let abs_difference = (acc - (m*x + b)).abs();

                        assert!(abs_difference <= 46.4 * $t::EPSILON);
                    }
                )+
            };
        }

        test_mul_add_assign!(f32 f64);
    }
}
152 | |