1//===-- Square root of IEEE 754 floating point numbers ----------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#ifndef LLVM_LIBC_SRC___SUPPORT_FPUTIL_GENERIC_SQRT_H
10#define LLVM_LIBC_SRC___SUPPORT_FPUTIL_GENERIC_SQRT_H
11
12#include "sqrt_80_bit_long_double.h"
13#include "src/__support/CPP/bit.h" // countl_zero
14#include "src/__support/CPP/type_traits.h"
15#include "src/__support/FPUtil/FEnvImpl.h"
16#include "src/__support/FPUtil/FPBits.h"
17#include "src/__support/FPUtil/rounding_mode.h"
18#include "src/__support/common.h"
19#include "src/__support/uint128.h"
20
21namespace LIBC_NAMESPACE {
22namespace fputil {
23
24namespace internal {
25
26template <typename T> struct SpecialLongDouble {
27 static constexpr bool VALUE = false;
28};
29
30#if defined(LIBC_TYPES_LONG_DOUBLE_IS_X86_FLOAT80)
31template <> struct SpecialLongDouble<long double> {
32 static constexpr bool VALUE = true;
33};
34#endif // LIBC_TYPES_LONG_DOUBLE_IS_X86_FLOAT80
35
36template <typename T>
37LIBC_INLINE void normalize(int &exponent,
38 typename FPBits<T>::StorageType &mantissa) {
39 const int shift =
40 cpp::countl_zero(mantissa) -
41 (8 * static_cast<int>(sizeof(mantissa)) - 1 - FPBits<T>::FRACTION_LEN);
42 exponent -= shift;
43 mantissa <<= shift;
44}
45
46#ifdef LIBC_TYPES_LONG_DOUBLE_IS_FLOAT64
47template <>
48LIBC_INLINE void normalize<long double>(int &exponent, uint64_t &mantissa) {
49 normalize<double>(exponent, mantissa);
50}
51#elif !defined(LIBC_TYPES_LONG_DOUBLE_IS_X86_FLOAT80)
52template <>
53LIBC_INLINE void normalize<long double>(int &exponent, UInt128 &mantissa) {
54 const uint64_t hi_bits = static_cast<uint64_t>(mantissa >> 64);
55 const int shift =
56 hi_bits ? (cpp::countl_zero(hi_bits) - 15)
57 : (cpp::countl_zero(static_cast<uint64_t>(mantissa)) + 49);
58 exponent -= shift;
59 mantissa <<= shift;
60}
61#endif
62
63} // namespace internal
64
65// Correctly rounded IEEE 754 SQRT for all rounding modes.
66// Shift-and-add algorithm.
67template <typename T>
68LIBC_INLINE cpp::enable_if_t<cpp::is_floating_point_v<T>, T> sqrt(T x) {
69
70 if constexpr (internal::SpecialLongDouble<T>::VALUE) {
71 // Special 80-bit long double.
72 return x86::sqrt(x);
73 } else {
74 // IEEE floating points formats.
75 using FPBits_t = typename fputil::FPBits<T>;
76 using StorageType = typename FPBits_t::StorageType;
77 constexpr StorageType ONE = StorageType(1) << FPBits_t::FRACTION_LEN;
78 constexpr auto FLT_NAN = FPBits_t::quiet_nan().get_val();
79
80 FPBits_t bits(x);
81
82 if (bits == FPBits_t::inf(Sign::POS) || bits.is_zero() || bits.is_nan()) {
83 // sqrt(+Inf) = +Inf
84 // sqrt(+0) = +0
85 // sqrt(-0) = -0
86 // sqrt(NaN) = NaN
87 // sqrt(-NaN) = -NaN
88 return x;
89 } else if (bits.is_neg()) {
90 // sqrt(-Inf) = NaN
91 // sqrt(-x) = NaN
92 return FLT_NAN;
93 } else {
94 int x_exp = bits.get_exponent();
95 StorageType x_mant = bits.get_mantissa();
96
97 // Step 1a: Normalize denormal input and append hidden bit to the mantissa
98 if (bits.is_subnormal()) {
99 ++x_exp; // let x_exp be the correct exponent of ONE bit.
100 internal::normalize<T>(x_exp, x_mant);
101 } else {
102 x_mant |= ONE;
103 }
104
105 // Step 1b: Make sure the exponent is even.
106 if (x_exp & 1) {
107 --x_exp;
108 x_mant <<= 1;
109 }
110
111 // After step 1b, x = 2^(x_exp) * x_mant, where x_exp is even, and
112 // 1 <= x_mant < 4. So sqrt(x) = 2^(x_exp / 2) * y, with 1 <= y < 2.
113 // Notice that the output of sqrt is always in the normal range.
114 // To perform shift-and-add algorithm to find y, let denote:
115 // y(n) = 1.y_1 y_2 ... y_n, we can define the nth residue to be:
116 // r(n) = 2^n ( x_mant - y(n)^2 ).
117 // That leads to the following recurrence formula:
118 // r(n) = 2*r(n-1) - y_n*[ 2*y(n-1) + 2^(-n-1) ]
119 // with the initial conditions: y(0) = 1, and r(0) = x - 1.
120 // So the nth digit y_n of the mantissa of sqrt(x) can be found by:
121 // y_n = 1 if 2*r(n-1) >= 2*y(n - 1) + 2^(-n-1)
122 // 0 otherwise.
123 StorageType y = ONE;
124 StorageType r = x_mant - ONE;
125
126 for (StorageType current_bit = ONE >> 1; current_bit; current_bit >>= 1) {
127 r <<= 1;
128 StorageType tmp = (y << 1) + current_bit; // 2*y(n - 1) + 2^(-n-1)
129 if (r >= tmp) {
130 r -= tmp;
131 y += current_bit;
132 }
133 }
134
135 // We compute one more iteration in order to round correctly.
136 bool lsb = static_cast<bool>(y & 1); // Least significant bit
137 bool rb = false; // Round bit
138 r <<= 2;
139 StorageType tmp = (y << 2) + 1;
140 if (r >= tmp) {
141 r -= tmp;
142 rb = true;
143 }
144
145 // Remove hidden bit and append the exponent field.
146 x_exp = ((x_exp >> 1) + FPBits_t::EXP_BIAS);
147
148 y = (y - ONE) |
149 (static_cast<StorageType>(x_exp) << FPBits_t::FRACTION_LEN);
150
151 switch (quick_get_round()) {
152 case FE_TONEAREST:
153 // Round to nearest, ties to even
154 if (rb && (lsb || (r != 0)))
155 ++y;
156 break;
157 case FE_UPWARD:
158 if (rb || (r != 0))
159 ++y;
160 break;
161 }
162
163 return cpp::bit_cast<T>(y);
164 }
165 }
166}
167
168} // namespace fputil
169} // namespace LIBC_NAMESPACE
170
171#endif // LLVM_LIBC_SRC___SUPPORT_FPUTIL_GENERIC_SQRT_H
172

source code of libc/src/__support/FPUtil/generic/sqrt.h