1 | //===-- Square root of IEEE 754 floating point numbers ----------*- C++ -*-===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #ifndef LLVM_LIBC_SRC___SUPPORT_FPUTIL_GENERIC_SQRT_H |
10 | #define LLVM_LIBC_SRC___SUPPORT_FPUTIL_GENERIC_SQRT_H |
11 | |
12 | #include "sqrt_80_bit_long_double.h" |
13 | #include "src/__support/CPP/bit.h" // countl_zero |
14 | #include "src/__support/CPP/type_traits.h" |
15 | #include "src/__support/FPUtil/FEnvImpl.h" |
16 | #include "src/__support/FPUtil/FPBits.h" |
17 | #include "src/__support/FPUtil/rounding_mode.h" |
18 | #include "src/__support/common.h" |
19 | #include "src/__support/uint128.h" |
20 | |
21 | namespace LIBC_NAMESPACE { |
22 | namespace fputil { |
23 | |
24 | namespace internal { |
25 | |
// SpecialLongDouble<T>::VALUE decides whether sqrt() below forwards to
// x86::sqrt instead of running the generic shift-and-add algorithm.
// It is false for every type, except `long double` when the build defines
// LIBC_TYPES_LONG_DOUBLE_IS_X86_FLOAT80.
template <typename T> struct SpecialLongDouble {
  static constexpr bool VALUE{false};
};

#ifdef LIBC_TYPES_LONG_DOUBLE_IS_X86_FLOAT80
template <> struct SpecialLongDouble<long double> {
  static constexpr bool VALUE{true};
};
#endif // LIBC_TYPES_LONG_DOUBLE_IS_X86_FLOAT80
35 | |
36 | template <typename T> |
37 | LIBC_INLINE void normalize(int &exponent, |
38 | typename FPBits<T>::StorageType &mantissa) { |
39 | const int shift = |
40 | cpp::countl_zero(mantissa) - |
41 | (8 * static_cast<int>(sizeof(mantissa)) - 1 - FPBits<T>::FRACTION_LEN); |
42 | exponent -= shift; |
43 | mantissa <<= shift; |
44 | } |
45 | |
46 | #ifdef LIBC_TYPES_LONG_DOUBLE_IS_FLOAT64 |
47 | template <> |
48 | LIBC_INLINE void normalize<long double>(int &exponent, uint64_t &mantissa) { |
49 | normalize<double>(exponent, mantissa); |
50 | } |
51 | #elif !defined(LIBC_TYPES_LONG_DOUBLE_IS_X86_FLOAT80) |
52 | template <> |
53 | LIBC_INLINE void normalize<long double>(int &exponent, UInt128 &mantissa) { |
54 | const uint64_t hi_bits = static_cast<uint64_t>(mantissa >> 64); |
55 | const int shift = |
56 | hi_bits ? (cpp::countl_zero(hi_bits) - 15) |
57 | : (cpp::countl_zero(static_cast<uint64_t>(mantissa)) + 49); |
58 | exponent -= shift; |
59 | mantissa <<= shift; |
60 | } |
61 | #endif |
62 | |
63 | } // namespace internal |
64 | |
65 | // Correctly rounded IEEE 754 SQRT for all rounding modes. |
66 | // Shift-and-add algorithm. |
67 | template <typename T> |
68 | LIBC_INLINE cpp::enable_if_t<cpp::is_floating_point_v<T>, T> sqrt(T x) { |
69 | |
70 | if constexpr (internal::SpecialLongDouble<T>::VALUE) { |
71 | // Special 80-bit long double. |
72 | return x86::sqrt(x); |
73 | } else { |
74 | // IEEE floating points formats. |
75 | using FPBits_t = typename fputil::FPBits<T>; |
76 | using StorageType = typename FPBits_t::StorageType; |
77 | constexpr StorageType ONE = StorageType(1) << FPBits_t::FRACTION_LEN; |
78 | constexpr auto FLT_NAN = FPBits_t::quiet_nan().get_val(); |
79 | |
80 | FPBits_t bits(x); |
81 | |
82 | if (bits == FPBits_t::inf(Sign::POS) || bits.is_zero() || bits.is_nan()) { |
83 | // sqrt(+Inf) = +Inf |
84 | // sqrt(+0) = +0 |
85 | // sqrt(-0) = -0 |
86 | // sqrt(NaN) = NaN |
87 | // sqrt(-NaN) = -NaN |
88 | return x; |
89 | } else if (bits.is_neg()) { |
90 | // sqrt(-Inf) = NaN |
91 | // sqrt(-x) = NaN |
92 | return FLT_NAN; |
93 | } else { |
94 | int x_exp = bits.get_exponent(); |
95 | StorageType x_mant = bits.get_mantissa(); |
96 | |
97 | // Step 1a: Normalize denormal input and append hidden bit to the mantissa |
98 | if (bits.is_subnormal()) { |
99 | ++x_exp; // let x_exp be the correct exponent of ONE bit. |
100 | internal::normalize<T>(x_exp, x_mant); |
101 | } else { |
102 | x_mant |= ONE; |
103 | } |
104 | |
105 | // Step 1b: Make sure the exponent is even. |
106 | if (x_exp & 1) { |
107 | --x_exp; |
108 | x_mant <<= 1; |
109 | } |
110 | |
111 | // After step 1b, x = 2^(x_exp) * x_mant, where x_exp is even, and |
112 | // 1 <= x_mant < 4. So sqrt(x) = 2^(x_exp / 2) * y, with 1 <= y < 2. |
113 | // Notice that the output of sqrt is always in the normal range. |
114 | // To perform shift-and-add algorithm to find y, let denote: |
115 | // y(n) = 1.y_1 y_2 ... y_n, we can define the nth residue to be: |
116 | // r(n) = 2^n ( x_mant - y(n)^2 ). |
117 | // That leads to the following recurrence formula: |
118 | // r(n) = 2*r(n-1) - y_n*[ 2*y(n-1) + 2^(-n-1) ] |
119 | // with the initial conditions: y(0) = 1, and r(0) = x - 1. |
120 | // So the nth digit y_n of the mantissa of sqrt(x) can be found by: |
121 | // y_n = 1 if 2*r(n-1) >= 2*y(n - 1) + 2^(-n-1) |
122 | // 0 otherwise. |
123 | StorageType y = ONE; |
124 | StorageType r = x_mant - ONE; |
125 | |
126 | for (StorageType current_bit = ONE >> 1; current_bit; current_bit >>= 1) { |
127 | r <<= 1; |
128 | StorageType tmp = (y << 1) + current_bit; // 2*y(n - 1) + 2^(-n-1) |
129 | if (r >= tmp) { |
130 | r -= tmp; |
131 | y += current_bit; |
132 | } |
133 | } |
134 | |
135 | // We compute one more iteration in order to round correctly. |
136 | bool lsb = static_cast<bool>(y & 1); // Least significant bit |
137 | bool rb = false; // Round bit |
138 | r <<= 2; |
139 | StorageType tmp = (y << 2) + 1; |
140 | if (r >= tmp) { |
141 | r -= tmp; |
142 | rb = true; |
143 | } |
144 | |
145 | // Remove hidden bit and append the exponent field. |
146 | x_exp = ((x_exp >> 1) + FPBits_t::EXP_BIAS); |
147 | |
148 | y = (y - ONE) | |
149 | (static_cast<StorageType>(x_exp) << FPBits_t::FRACTION_LEN); |
150 | |
151 | switch (quick_get_round()) { |
152 | case FE_TONEAREST: |
153 | // Round to nearest, ties to even |
154 | if (rb && (lsb || (r != 0))) |
155 | ++y; |
156 | break; |
157 | case FE_UPWARD: |
158 | if (rb || (r != 0)) |
159 | ++y; |
160 | break; |
161 | } |
162 | |
163 | return cpp::bit_cast<T>(y); |
164 | } |
165 | } |
166 | } |
167 | |
168 | } // namespace fputil |
169 | } // namespace LIBC_NAMESPACE |
170 | |
171 | #endif // LLVM_LIBC_SRC___SUPPORT_FPUTIL_GENERIC_SQRT_H |
172 | |