Warning: This file is not a C or C++ file. It does not have highlighting.
| 1 | //===-- Common header for FMA implementations -------------------*- C++ -*-===// |
|---|---|
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | |
| 9 | #ifndef LLVM_LIBC_SRC___SUPPORT_FPUTIL_GENERIC_FMA_H |
| 10 | #define LLVM_LIBC_SRC___SUPPORT_FPUTIL_GENERIC_FMA_H |
| 11 | |
| 12 | #include "src/__support/CPP/bit.h" |
| 13 | #include "src/__support/CPP/limits.h" |
| 14 | #include "src/__support/CPP/type_traits.h" |
| 15 | #include "src/__support/FPUtil/BasicOperations.h" |
| 16 | #include "src/__support/FPUtil/FPBits.h" |
| 17 | #include "src/__support/FPUtil/cast.h" |
| 18 | #include "src/__support/FPUtil/dyadic_float.h" |
| 19 | #include "src/__support/FPUtil/rounding_mode.h" |
| 20 | #include "src/__support/big_int.h" |
| 21 | #include "src/__support/macros/attributes.h" // LIBC_INLINE |
| 22 | #include "src/__support/macros/config.h" |
| 23 | #include "src/__support/macros/optimization.h" // LIBC_UNLIKELY |
| 24 | |
| 25 | #include "hdr/fenv_macros.h" |
| 26 | |
| 27 | namespace LIBC_NAMESPACE_DECL { |
| 28 | namespace fputil { |
| 29 | namespace generic { |
| 30 | |
// Forward declaration of the generic fused multiply-add template.
// fma<OutType>(x, y, z) takes three inputs of floating point type InType and
// returns a value of floating point type OutType, where OutType is required to
// be no wider than InType. The definition appears later in this header; it is
// declared up front so that the float specialization can be written first.
template <typename OutType, typename InType>
LIBC_INLINE cpp::enable_if_t<cpp::is_floating_point_v<OutType> &&
                                 cpp::is_floating_point_v<InType> &&
                                 sizeof(OutType) <= sizeof(InType),
                             OutType>
fma(InType x, InType y, InType z);
| 37 | |
| 38 | // TODO(lntue): Implement fmaf that is correctly rounded to all rounding modes. |
| 39 | // The implementation below is only correct for the default rounding mode, |
| 40 | // round-to-nearest tie-to-even. |
| 41 | template <> LIBC_INLINE float fma<float>(float x, float y, float z) { |
| 42 | // Product is exact. |
| 43 | double prod = static_cast<double>(x) * static_cast<double>(y); |
| 44 | double z_d = static_cast<double>(z); |
| 45 | double sum = prod + z_d; |
| 46 | fputil::FPBits<double> bit_prod(prod), bitz(z_d), bit_sum(sum); |
| 47 | |
| 48 | if (!(bit_sum.is_inf_or_nan() || bit_sum.is_zero())) { |
| 49 | // Since the sum is computed in double precision, rounding might happen |
| 50 | // (for instance, when bitz.exponent > bit_prod.exponent + 5, or |
| 51 | // bit_prod.exponent > bitz.exponent + 40). In that case, when we round |
| 52 | // the sum back to float, double rounding error might occur. |
| 53 | // A concrete example of this phenomenon is as follows: |
| 54 | // x = y = 1 + 2^(-12), z = 2^(-53) |
| 55 | // The exact value of x*y + z is 1 + 2^(-11) + 2^(-24) + 2^(-53) |
| 56 | // So when rounding to float, fmaf(x, y, z) = 1 + 2^(-11) + 2^(-23) |
| 57 | // On the other hand, with the default rounding mode, |
| 58 | // double(x*y + z) = 1 + 2^(-11) + 2^(-24) |
| 59 | // and casting again to float gives us: |
| 60 | // float(double(x*y + z)) = 1 + 2^(-11). |
| 61 | // |
| 62 | // In order to correct this possible double rounding error, first we use |
| 63 | // Dekker's 2Sum algorithm to find t such that sum - t = prod + z exactly, |
| 64 | // assuming the (default) rounding mode is round-to-the-nearest, |
| 65 | // tie-to-even. Moreover, t satisfies the condition that t < eps(sum), |
| 66 | // i.e., t.exponent < sum.exponent - 52. So if t is not 0, meaning rounding |
| 67 | // occurs when computing the sum, we just need to use t to adjust (any) last |
| 68 | // bit of sum, so that the sticky bits used when rounding sum to float are |
| 69 | // correct (when it matters). |
| 70 | fputil::FPBits<double> t( |
| 71 | (bit_prod.get_biased_exponent() >= bitz.get_biased_exponent()) |
| 72 | ? ((bit_sum.get_val() - bit_prod.get_val()) - bitz.get_val()) |
| 73 | : ((bit_sum.get_val() - bitz.get_val()) - bit_prod.get_val())); |
| 74 | |
| 75 | // Update sticky bits if t != 0.0 and the least (52 - 23 - 1 = 28) bits are |
| 76 | // zero. |
| 77 | if (!t.is_zero() && ((bit_sum.get_mantissa() & 0xfff'ffffULL) == 0)) { |
| 78 | if (bit_sum.sign() != t.sign()) |
| 79 | bit_sum.set_mantissa(bit_sum.get_mantissa() + 1); |
| 80 | else if (bit_sum.get_mantissa()) |
| 81 | bit_sum.set_mantissa(bit_sum.get_mantissa() - 1); |
| 82 | } |
| 83 | } |
| 84 | |
| 85 | return static_cast<float>(bit_sum.get_val()); |
| 86 | } |
| 87 | |
| 88 | namespace internal { |
| 89 | |
| 90 | // Extract the sticky bits and shift the `mantissa` to the right by |
| 91 | // `shift_length`. |
| 92 | template <typename T> |
| 93 | LIBC_INLINE cpp::enable_if_t<is_unsigned_integral_or_big_int_v<T>, bool> |
| 94 | shift_mantissa(int shift_length, T &mant) { |
| 95 | if (shift_length >= cpp::numeric_limits<T>::digits) { |
| 96 | mant = 0; |
| 97 | return true; // prod_mant is non-zero. |
| 98 | } |
| 99 | T mask = (T(1) << shift_length) - 1; |
| 100 | bool sticky_bits = (mant & mask) != 0; |
| 101 | mant >>= shift_length; |
| 102 | return sticky_bits; |
| 103 | } |
| 104 | |
| 105 | } // namespace internal |
| 106 | |
| 107 | template <typename OutType, typename InType> |
| 108 | LIBC_INLINE cpp::enable_if_t<cpp::is_floating_point_v<OutType> && |
| 109 | cpp::is_floating_point_v<InType> && |
| 110 | sizeof(OutType) <= sizeof(InType), |
| 111 | OutType> |
| 112 | fma(InType x, InType y, InType z) { |
| 113 | using OutFPBits = FPBits<OutType>; |
| 114 | using OutStorageType = typename OutFPBits::StorageType; |
| 115 | using InFPBits = FPBits<InType>; |
| 116 | using InStorageType = typename InFPBits::StorageType; |
| 117 | |
| 118 | constexpr int IN_EXPLICIT_MANT_LEN = InFPBits::FRACTION_LEN + 1; |
| 119 | constexpr size_t PROD_LEN = 2 * IN_EXPLICIT_MANT_LEN; |
| 120 | constexpr size_t TMP_RESULT_LEN = cpp::bit_ceil(PROD_LEN + 1); |
| 121 | using TmpResultType = UInt<TMP_RESULT_LEN>; |
| 122 | using DyadicFloat = DyadicFloat<TMP_RESULT_LEN>; |
| 123 | |
| 124 | InFPBits x_bits(x), y_bits(y), z_bits(z); |
| 125 | |
| 126 | if (LIBC_UNLIKELY(x_bits.is_nan() || y_bits.is_nan() || z_bits.is_nan())) { |
| 127 | if (x_bits.is_nan() || y_bits.is_nan()) { |
| 128 | if (x_bits.is_signaling_nan() || y_bits.is_signaling_nan() || |
| 129 | z_bits.is_signaling_nan()) |
| 130 | raise_except_if_required(FE_INVALID); |
| 131 | |
| 132 | if (x_bits.is_quiet_nan()) { |
| 133 | InStorageType x_payload = x_bits.get_mantissa(); |
| 134 | x_payload >>= InFPBits::FRACTION_LEN - OutFPBits::FRACTION_LEN; |
| 135 | return OutFPBits::quiet_nan(x_bits.sign(), |
| 136 | static_cast<OutStorageType>(x_payload)) |
| 137 | .get_val(); |
| 138 | } |
| 139 | |
| 140 | if (y_bits.is_quiet_nan()) { |
| 141 | InStorageType y_payload = y_bits.get_mantissa(); |
| 142 | y_payload >>= InFPBits::FRACTION_LEN - OutFPBits::FRACTION_LEN; |
| 143 | return OutFPBits::quiet_nan(y_bits.sign(), |
| 144 | static_cast<OutStorageType>(y_payload)) |
| 145 | .get_val(); |
| 146 | } |
| 147 | |
| 148 | if (z_bits.is_quiet_nan()) { |
| 149 | InStorageType z_payload = z_bits.get_mantissa(); |
| 150 | z_payload >>= InFPBits::FRACTION_LEN - OutFPBits::FRACTION_LEN; |
| 151 | return OutFPBits::quiet_nan(z_bits.sign(), |
| 152 | static_cast<OutStorageType>(z_payload)) |
| 153 | .get_val(); |
| 154 | } |
| 155 | |
| 156 | return OutFPBits::quiet_nan().get_val(); |
| 157 | } |
| 158 | } |
| 159 | |
| 160 | if (LIBC_UNLIKELY(x == 0 || y == 0 || z == 0)) |
| 161 | return cast<OutType>(x * y + z); |
| 162 | |
| 163 | int x_exp = 0; |
| 164 | int y_exp = 0; |
| 165 | int z_exp = 0; |
| 166 | |
| 167 | // Denormal scaling = 2^(fraction length). |
| 168 | constexpr InStorageType IMPLICIT_MASK = |
| 169 | InFPBits::SIG_MASK - InFPBits::FRACTION_MASK; |
| 170 | |
| 171 | constexpr InType DENORMAL_SCALING = |
| 172 | InFPBits::create_value( |
| 173 | Sign::POS, InFPBits::FRACTION_LEN + InFPBits::EXP_BIAS, IMPLICIT_MASK) |
| 174 | .get_val(); |
| 175 | |
| 176 | // Normalize denormal inputs. |
| 177 | if (LIBC_UNLIKELY(InFPBits(x).is_subnormal())) { |
| 178 | x_exp -= InFPBits::FRACTION_LEN; |
| 179 | x *= DENORMAL_SCALING; |
| 180 | } |
| 181 | if (LIBC_UNLIKELY(InFPBits(y).is_subnormal())) { |
| 182 | y_exp -= InFPBits::FRACTION_LEN; |
| 183 | y *= DENORMAL_SCALING; |
| 184 | } |
| 185 | if (LIBC_UNLIKELY(InFPBits(z).is_subnormal())) { |
| 186 | z_exp -= InFPBits::FRACTION_LEN; |
| 187 | z *= DENORMAL_SCALING; |
| 188 | } |
| 189 | |
| 190 | x_bits = InFPBits(x); |
| 191 | y_bits = InFPBits(y); |
| 192 | z_bits = InFPBits(z); |
| 193 | const Sign z_sign = z_bits.sign(); |
| 194 | Sign prod_sign = (x_bits.sign() == y_bits.sign()) ? Sign::POS : Sign::NEG; |
| 195 | x_exp += x_bits.get_biased_exponent(); |
| 196 | y_exp += y_bits.get_biased_exponent(); |
| 197 | z_exp += z_bits.get_biased_exponent(); |
| 198 | |
| 199 | if (LIBC_UNLIKELY(x_exp == InFPBits::MAX_BIASED_EXPONENT || |
| 200 | y_exp == InFPBits::MAX_BIASED_EXPONENT || |
| 201 | z_exp == InFPBits::MAX_BIASED_EXPONENT)) |
| 202 | return cast<OutType>(x * y + z); |
| 203 | |
| 204 | // Extract mantissa and append hidden leading bits. |
| 205 | InStorageType x_mant = x_bits.get_explicit_mantissa(); |
| 206 | InStorageType y_mant = y_bits.get_explicit_mantissa(); |
| 207 | TmpResultType z_mant = z_bits.get_explicit_mantissa(); |
| 208 | |
| 209 | // If the exponent of the product x*y > the exponent of z, then no extra |
| 210 | // precision beside the entire product x*y is needed. On the other hand, when |
| 211 | // the exponent of z >= the exponent of the product x*y, the worst-case that |
| 212 | // we need extra precision is when there is cancellation and the most |
| 213 | // significant bit of the product is aligned exactly with the second most |
| 214 | // significant bit of z: |
| 215 | // z : 10aa...a |
| 216 | // - prod : 1bb...bb....b |
| 217 | // In that case, in order to store the exact result, we need at least |
| 218 | // (Length of prod) - (Fraction length of z) |
| 219 | // = 2*(Length of input explicit mantissa) - (Fraction length of z) bits. |
| 220 | // Overall, before aligning the mantissas and exponents, we can simply left- |
| 221 | // shift the mantissa of z by that amount. After that, it is enough to align |
| 222 | // the least significant bit, given that we keep track of the round and sticky |
| 223 | // bits after the least significant bit. |
| 224 | |
| 225 | TmpResultType prod_mant = TmpResultType(x_mant) * y_mant; |
| 226 | int prod_lsb_exp = |
| 227 | x_exp + y_exp - (InFPBits::EXP_BIAS + 2 * InFPBits::FRACTION_LEN); |
| 228 | |
| 229 | constexpr int RESULT_MIN_LEN = PROD_LEN - InFPBits::FRACTION_LEN; |
| 230 | z_mant <<= RESULT_MIN_LEN; |
| 231 | int z_lsb_exp = z_exp - (InFPBits::FRACTION_LEN + RESULT_MIN_LEN); |
| 232 | bool sticky_bits = false; |
| 233 | bool z_shifted = false; |
| 234 | |
| 235 | // Align exponents. |
| 236 | if (prod_lsb_exp < z_lsb_exp) { |
| 237 | sticky_bits = internal::shift_mantissa(z_lsb_exp - prod_lsb_exp, prod_mant); |
| 238 | prod_lsb_exp = z_lsb_exp; |
| 239 | } else if (z_lsb_exp < prod_lsb_exp) { |
| 240 | z_shifted = true; |
| 241 | sticky_bits = internal::shift_mantissa(prod_lsb_exp - z_lsb_exp, z_mant); |
| 242 | } |
| 243 | |
| 244 | // Perform the addition: |
| 245 | // (-1)^prod_sign * prod_mant + (-1)^z_sign * z_mant. |
| 246 | // The final result will be stored in prod_sign and prod_mant. |
| 247 | if (prod_sign == z_sign) { |
| 248 | // Effectively an addition. |
| 249 | prod_mant += z_mant; |
| 250 | } else { |
| 251 | // Subtraction cases. |
| 252 | if (prod_mant >= z_mant) { |
| 253 | if (z_shifted && sticky_bits) { |
| 254 | // Add 1 more to the subtrahend so that the sticky bits remain |
| 255 | // positive. This would simplify the rounding logic. |
| 256 | ++z_mant; |
| 257 | } |
| 258 | prod_mant -= z_mant; |
| 259 | } else { |
| 260 | if (!z_shifted && sticky_bits) { |
| 261 | // Add 1 more to the subtrahend so that the sticky bits remain |
| 262 | // positive. This would simplify the rounding logic. |
| 263 | ++prod_mant; |
| 264 | } |
| 265 | prod_mant = z_mant - prod_mant; |
| 266 | prod_sign = z_sign; |
| 267 | } |
| 268 | } |
| 269 | |
| 270 | if (prod_mant == 0) { |
| 271 | // When there is exact cancellation, i.e., x*y == -z exactly, return -0.0 if |
| 272 | // rounding downward and +0.0 for other rounding modes. |
| 273 | if (quick_get_round() == FE_DOWNWARD) |
| 274 | prod_sign = Sign::NEG; |
| 275 | else |
| 276 | prod_sign = Sign::POS; |
| 277 | } |
| 278 | |
| 279 | DyadicFloat result(prod_sign, prod_lsb_exp - InFPBits::EXP_BIAS, prod_mant); |
| 280 | result.mantissa |= static_cast<unsigned int>(sticky_bits); |
| 281 | return result.template as<OutType, /*ShouldSignalExceptions=*/true>(); |
| 282 | } |
| 283 | |
| 284 | } // namespace generic |
| 285 | } // namespace fputil |
| 286 | } // namespace LIBC_NAMESPACE_DECL |
| 287 | |
| 288 | #endif // LLVM_LIBC_SRC___SUPPORT_FPUTIL_GENERIC_FMA_H |
| 289 |
Warning: This file is not a C or C++ file. It does not have highlighting.
