1 | //===-- Square root of x86 long double numbers ------------------*- C++ -*-===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #ifndef LLVM_LIBC_SRC___SUPPORT_FPUTIL_GENERIC_SQRT_80_BIT_LONG_DOUBLE_H |
10 | #define LLVM_LIBC_SRC___SUPPORT_FPUTIL_GENERIC_SQRT_80_BIT_LONG_DOUBLE_H |
11 | |
12 | #include "src/__support/CPP/bit.h" |
13 | #include "src/__support/FPUtil/FEnvImpl.h" |
14 | #include "src/__support/FPUtil/FPBits.h" |
15 | #include "src/__support/FPUtil/rounding_mode.h" |
16 | #include "src/__support/common.h" |
17 | #include "src/__support/uint128.h" |
18 | |
19 | namespace LIBC_NAMESPACE { |
20 | namespace fputil { |
21 | namespace x86 { |
22 | |
23 | LIBC_INLINE void normalize(int &exponent, UInt128 &mantissa) { |
24 | const unsigned int shift = static_cast<unsigned int>( |
25 | cpp::countl_zero(value: static_cast<uint64_t>(mantissa)) - |
26 | (8 * sizeof(uint64_t) - 1 - FPBits<long double>::FRACTION_LEN)); |
27 | exponent -= shift; |
28 | mantissa <<= shift; |
29 | } |
30 | |
// Forward declaration of x86::sqrt. The `if constexpr` dispatch in sqrt.h
// still requires this name to be declared even in configurations where that
// branch is discarded and the function is never used. The definition below
// is only compiled when LIBC_TYPES_LONG_DOUBLE_IS_X86_FLOAT80 is defined.
LIBC_INLINE long double sqrt(long double x);
34 | |
// Correctly rounded square root of an x86 80-bit extended-precision long
// double, honoring the rounding mode reported by quick_get_round().
//
// Algorithm: digit-by-digit (shift-and-subtract) square root. The 64-bit
// result significand (explicit integer bit + 63 fraction bits) is produced
// one bit per loop iteration. One extra iteration yields the round bit, and
// the leftover remainder acts as the sticky bit.
#if defined(LIBC_TYPES_LONG_DOUBLE_IS_X86_FLOAT80)
LIBC_INLINE long double sqrt(long double x) {
  using LDBits = FPBits<long double>;
  using StorageType = typename LDBits::StorageType;
  // 1.0 in the fixed-point significand format: the explicit integer bit.
  constexpr StorageType ONE = StorageType(1) << int(LDBits::FRACTION_LEN);
  constexpr auto LDNAN = LDBits::quiet_nan().get_val();

  LDBits bits(x);

  if (bits == LDBits::inf(Sign::POS) || bits.is_zero() || bits.is_nan()) {
    // sqrt(+Inf) = +Inf
    // sqrt(+0) = +0
    // sqrt(-0) = -0
    // sqrt(NaN) = NaN
    // sqrt(-NaN) = -NaN
    return x;
  }
  if (bits.is_neg()) {
    // sqrt(-Inf) = NaN
    // sqrt(-x) = NaN
    return LDNAN;
  }

  int x_exp = bits.get_explicit_exponent();
  StorageType x_mant = bits.get_mantissa();

  // Step 1a: Fold the explicit integer bit into x_mant, or normalize a
  // subnormal input so that the integer bit becomes set.
  if (bits.get_implicit_bit()) {
    x_mant |= ONE;
  } else if (bits.is_subnormal()) {
    normalize(x_exp, x_mant);
  }

  // Step 1b: Make sure the exponent is even.
  if (x_exp & 1) {
    --x_exp;
    x_mant <<= 1;
  }

  // After step 1b, x = 2^x_exp * (x_mant / ONE), where x_exp is even and
  // 1 <= x_mant / ONE < 4. So sqrt(x) = 2^(x_exp / 2) * y with 1 <= y < 2;
  // in particular the result is always in the normal range.
  //
  // Let y(n) = 1.y_1 y_2 ... y_n (binary) and define the n-th residue
  //   r(n) = 2^n * (x_mant - y(n)^2).
  // Substituting y(n) = y(n-1) + y_n * 2^(-n) gives the recurrence
  //   r(n) = 2*r(n-1) - y_n * [2*y(n-1) + 2^(-n)]
  // with y(0) = 1 and r(0) = x_mant - 1. Hence the n-th digit is
  //   y_n = 1 if 2*r(n-1) >= 2*y(n-1) + 2^(-n), and 0 otherwise.
  // Below, y, r and current_bit (= 2^(-n)) are fixed-point values scaled by
  // ONE.
  StorageType y = ONE;
  StorageType r = x_mant - ONE;

  for (StorageType current_bit = ONE >> 1; current_bit; current_bit >>= 1) {
    r <<= 1;
    StorageType tmp = (y << 1) + current_bit; // 2*y(n-1) + 2^(-n)
    if (r >= tmp) {
      r -= tmp;
      y += current_bit;
    }
  }

  // One more iteration computes the digit just below the last kept bit (the
  // round bit). Its weight is half a unit of the scaled representation, so
  // both sides are doubled: r is shifted by 2 and compared to 4*y + 1.
  bool lsb = static_cast<bool>(y & 1); // Least significant bit
  bool rb = false;                     // Round bit
  r <<= 2;
  StorageType tmp = (y << 2) + 1;
  if (r >= tmp) {
    r -= tmp;
    rb = true;
  }
  // From here on, r != 0 means the exact root has nonzero bits below the
  // round bit (sticky bit).

  // Biased exponent of the result.
  int out_exp = (x_exp >> 1) + LDBits::EXP_BIAS;

  // The result is positive, so FE_DOWNWARD and FE_TOWARDZERO both truncate
  // and need no adjustment.
  switch (quick_get_round()) {
  case FE_TONEAREST:
    // Round to nearest, ties to even
    if (rb && (lsb || (r != 0)))
      ++y;
    break;
  case FE_UPWARD:
    if (rb || (r != 0))
      ++y;
    break;
  }

  // Rounding up can carry out of the 64-bit significand and turn y into
  // exactly 2.0 (e.g. sqrt(LDBL_MAX) under FE_UPWARD). Move that carry into
  // the exponent; otherwise it would be lost when only the fraction bits are
  // stored below. y is exactly 2 * ONE in that case, so the shift drops no
  // set bits.
  if (y >= (ONE << 1)) {
    y >>= 1;
    ++out_exp;
  }

  // Assemble the output. y is now in [1, 2), so the explicit integer bit is
  // always set.
  LDBits out(0.0L);
  out.set_biased_exponent(out_exp);
  out.set_implicit_bit(1);
  out.set_mantissa(y & (ONE - 1));

  return out.get_val();
}
#endif // LIBC_TYPES_LONG_DOUBLE_IS_X86_FLOAT80
134 | |
135 | } // namespace x86 |
136 | } // namespace fputil |
137 | } // namespace LIBC_NAMESPACE |
138 | |
139 | #endif // LLVM_LIBC_SRC___SUPPORT_FPUTIL_GENERIC_SQRT_80_BIT_LONG_DOUBLE_H |
140 | |