1/* SPDX-License-Identifier: MIT */
2/* origin: musl src/math/sqrt.c. Ported to generic Rust algorithm in 2025, TG. */
3
4//! Generic square root algorithm.
5//!
//! This routine operates around `m_u2`, a U2 (fixed point with two integral bits) mantissa
//! within the range [1, 4). A table lookup provides an initial estimate, then Goldschmidt
//! iterations at various widths are used to converge on the true values.
9//!
10//! For the iterations, `r` is a U0 number that approaches `1/sqrt(m_u2)`, and `s` is a U2 number
11//! that approaches `sqrt(m_u2)`. Recall that m_u2 ∈ [1, 4).
12//!
//! With Newton-Raphson iterations, this would be:
//!
//! - `w = r * r           w ~ 1 / m`
//! - `u = 3 - m * w       u ~ 3 - m * w = 3 - m / m = 2`
//! - `r = r * u / 2       r ~ r`
//!
//! (Note that the right-hand column does not show anything analytically meaningful (i.e. r ~ r),
//! since the value of performing one iteration is in reducing the error represented by `~`).
21//!
//! Instead of Newton-Raphson iterations, Goldschmidt iterations are used to calculate
//! `s = m * r`:
//!
//! - `s = m * r           s ~ m / sqrt(m)`
//! - `u = 3 - s * r       u ~ 3 - (m / sqrt(m)) * (1 / sqrt(m)) = 3 - m / m = 2`
//! - `r = r * u / 2       r ~ r`
//! - `s = s * u / 2       s ~ s`
//!
//! The above is precise because it uses the original value `m`. There is also a faster version
//! that performs fewer steps but does not use `m`:
//!
//! - `u = 3 - s * r       u ~ 3 - 1 = 2`
//! - `r = r * u / 2       r ~ r`
//! - `s = s * u / 2       s ~ s`
36//!
37//! Rounding errors accumulate faster with the second version, so it is only used for subsequent
38//! iterations within the same width integer. The first version is always used for the first
39//! iteration at a new width in order to avoid this accumulation.
40//!
41//! Goldschmidt has the advantage over Newton-Raphson that `sqrt(x)` and `1/sqrt(x)` are
42//! computed at the same time, i.e. there is no need to calculate `1/sqrt(x)` and invert it.
43
44use crate::support::{
45 CastFrom, CastInto, DInt, Float, FpResult, HInt, Int, IntTy, MinInt, Round, Status, cold_path,
46};
47
48#[inline]
49pub fn sqrt<F>(x: F) -> F
50where
51 F: Float + SqrtHelper,
52 F::Int: HInt,
53 F::Int: From<u8>,
54 F::Int: From<F::ISet2>,
55 F::Int: CastInto<F::ISet1>,
56 F::Int: CastInto<F::ISet2>,
57 u32: CastInto<F::Int>,
58{
59 sqrt_round(x, Round::Nearest).val
60}
61
62#[inline]
63pub fn sqrt_round<F>(x: F, _round: Round) -> FpResult<F>
64where
65 F: Float + SqrtHelper,
66 F::Int: HInt,
67 F::Int: From<u8>,
68 F::Int: From<F::ISet2>,
69 F::Int: CastInto<F::ISet1>,
70 F::Int: CastInto<F::ISet2>,
71 u32: CastInto<F::Int>,
72{
73 let zero = IntTy::<F>::ZERO;
74 let one = IntTy::<F>::ONE;
75
76 let mut ix = x.to_bits();
77
78 // Top is the exponent and sign, which may or may not be shifted. If the float fits into a
79 // `u32`, we can get by without paying shifting costs.
80 let noshift = F::BITS <= u32::BITS;
81 let (mut top, special_case) = if noshift {
82 let exp_lsb = one << F::SIG_BITS;
83 let special_case = ix.wrapping_sub(exp_lsb) >= F::EXP_MASK - exp_lsb;
84 (Exp::NoShift(()), special_case)
85 } else {
86 let top = u32::cast_from(ix >> F::SIG_BITS);
87 let special_case = top.wrapping_sub(1) >= F::EXP_SAT - 1;
88 (Exp::Shifted(top), special_case)
89 };
90
91 // Handle NaN, zero, and out of domain (<= 0)
92 if special_case {
93 cold_path();
94
95 // +/-0
96 if ix << 1 == zero {
97 return FpResult::ok(x);
98 }
99
100 // Positive infinity
101 if ix == F::EXP_MASK {
102 return FpResult::ok(x);
103 }
104
105 // NaN or negative
106 if ix > F::EXP_MASK {
107 return FpResult::new(F::NAN, Status::INVALID);
108 }
109
110 // Normalize subnormals by multiplying by 1.0 << SIG_BITS (e.g. 0x1p52 for doubles).
111 let scaled = x * F::from_parts(false, F::SIG_BITS + F::EXP_BIAS, zero);
112 ix = scaled.to_bits();
113 match top {
114 Exp::Shifted(ref mut v) => {
115 *v = scaled.ex();
116 *v = (*v).wrapping_sub(F::SIG_BITS);
117 }
118 Exp::NoShift(()) => {
119 ix = ix.wrapping_sub((F::SIG_BITS << F::SIG_BITS).cast());
120 }
121 }
122 }
123
124 // Reduce arguments such that `x = 4^e * m`:
125 //
126 // - m_u2 ∈ [1, 4), a fixed point U2.BITS number
127 // - 2^e is the exponent part of the result
128 let (m_u2, exp) = match top {
129 Exp::Shifted(top) => {
130 // We now know `x` is positive, so `top` is just its (biased) exponent
131 let mut e = top;
132 // Construct a fixed point representation of the mantissa.
133 let mut m_u2 = (ix | F::IMPLICIT_BIT) << F::EXP_BITS;
134 let even = (e & 1) != 0;
135 if even {
136 m_u2 >>= 1;
137 }
138 e = (e.wrapping_add(F::EXP_SAT >> 1)) >> 1;
139 (m_u2, Exp::Shifted(e))
140 }
141 Exp::NoShift(()) => {
142 let even = ix & (one << F::SIG_BITS) != zero;
143
144 // Exponent part of the return value
145 let mut e_noshift = ix >> 1;
146 // ey &= (F::EXP_MASK << 2) >> 2; // clear the top exponent bit (result = 1.0)
147 e_noshift += (F::EXP_MASK ^ (F::SIGN_MASK >> 1)) >> 1;
148 e_noshift &= F::EXP_MASK;
149
150 let m1 = (ix << F::EXP_BITS) | F::SIGN_MASK;
151 let m0 = (ix << (F::EXP_BITS - 1)) & !F::SIGN_MASK;
152 let m_u2 = if even { m0 } else { m1 };
153
154 (m_u2, Exp::NoShift(e_noshift))
155 }
156 };
157
158 // Extract the top 6 bits of the significand with the lowest bit of the exponent.
159 let i = usize::cast_from(ix >> (F::SIG_BITS - 6)) & 0b1111111;
160
161 // Start with an initial guess for `r = 1 / sqrt(m)` from the table, and shift `m` as an
162 // initial value for `s = sqrt(m)`. See the module documentation for details.
163 let r1_u0: F::ISet1 = F::ISet1::cast_from(RSQRT_TAB[i]) << (F::ISet1::BITS - 16);
164 let s1_u2: F::ISet1 = ((m_u2) >> (F::BITS - F::ISet1::BITS)).cast();
165
166 // Perform iterations, if any, at quarter width (used for `f128`).
167 let (r1_u0, _s1_u2) = goldschmidt::<F, F::ISet1>(r1_u0, s1_u2, F::SET1_ROUNDS, false);
168
169 // Widen values and perform iterations at half width (used for `f64` and `f128`).
170 let r2_u0: F::ISet2 = F::ISet2::from(r1_u0) << (F::ISet2::BITS - F::ISet1::BITS);
171 let s2_u2: F::ISet2 = ((m_u2) >> (F::BITS - F::ISet2::BITS)).cast();
172 let (r2_u0, _s2_u2) = goldschmidt::<F, F::ISet2>(r2_u0, s2_u2, F::SET2_ROUNDS, false);
173
174 // Perform final iterations at full width (used for all float types).
175 let r_u0: F::Int = F::Int::from(r2_u0) << (F::BITS - F::ISet2::BITS);
176 let s_u2: F::Int = m_u2;
177 let (_r_u0, s_u2) = goldschmidt::<F, F::Int>(r_u0, s_u2, F::FINAL_ROUNDS, true);
178
179 // Shift back to mantissa position.
180 let mut m = s_u2 >> (F::EXP_BITS - 2);
181
182 // The musl source includes the following comment (with literals replaced):
183 //
184 // > s < sqrt(m) < s + 0x1.09p-SIG_BITS
185 // > compute nearest rounded result: the nearest result to SIG_BITS bits is either s or
186 // > s+0x1p-SIG_BITS, we can decide by comparing (2^SIG_BITS s + 0.5)^2 to 2^(2*SIG_BITS) m.
187 //
188 // Expanding this with , with `SIG_BITS = p` and adjusting based on the operations done to
189 // `d0` and `d1`:
190 //
191 // - `2^(2p)m ≟ ((2^p)m + 0.5)^2`
192 // - `2^(2p)m ≟ 2^(2p)m^2 + (2^p)m + 0.25`
193 // - `2^(2p)m - m^2 ≟ (2^(2p) - 1)m^2 + (2^p)m + 0.25`
194 // - `(1 - 2^(2p))m + m^2 ≟ (1 - 2^(2p))m^2 + (1 - 2^p)m + 0.25` (?)
195 //
196 // I do not follow how the rounding bit is extracted from this comparison with the below
197 // operations. In any case, the algorithm is well tested.
198
199 // The value needed to shift `m_u2` by to create `m*2^(2p)`. `2p = 2 * F::SIG_BITS`,
200 // `F::BITS - 2` accounts for the offset that `m_u2` already has.
201 let shift = 2 * F::SIG_BITS - (F::BITS - 2);
202
203 // `2^(2p)m - m^2`
204 let d0 = (m_u2 << shift).wrapping_sub(m.wrapping_mul(m));
205 // `m - 2^(2p)m + m^2`
206 let d1 = m.wrapping_sub(d0);
207 m += d1 >> (F::BITS - 1);
208 m &= F::SIG_MASK;
209
210 match exp {
211 Exp::Shifted(e) => m |= IntTy::<F>::cast_from(e) << F::SIG_BITS,
212 Exp::NoShift(e) => m |= e,
213 };
214
215 let mut y = F::from_bits(m);
216
217 // FIXME(f16): the fenv math does not work for `f16`
218 if F::BITS > 16 {
219 // Handle rounding and inexact. `(m + 1)^2 == 2^shift m` is exact; for all other cases, add
220 // a tiny value to cause fenv effects.
221 let d2 = d1.wrapping_add(m).wrapping_add(one);
222 let mut tiny = if d2 == zero {
223 cold_path();
224 zero
225 } else {
226 F::IMPLICIT_BIT
227 };
228
229 tiny |= (d1 ^ d2) & F::SIGN_MASK;
230 let t = F::from_bits(tiny);
231 y = y + t;
232 }
233
234 FpResult::ok(y)
235}
236
237/// Multiply at the wider integer size, returning the high half.
238fn wmulh<I: HInt>(a: I, b: I) -> I {
239 a.widen_mul(b).hi()
240}
241
242/// Perform `count` goldschmidt iterations, returning `(r_u0, s_u?)`.
243///
244/// - `r_u0` is the reciprocal `r ~ 1 / sqrt(m)`, as U0.
245/// - `s_u2` is the square root, `s ~ sqrt(m)`, as U2.
246/// - `count` is the number of iterations to perform.
247/// - `final_set` should be true if this is the last round (same-sized integer). If so, the
248/// returned `s` will be U3, for later shifting. Otherwise, the returned `s` is U2.
249///
250/// Note that performance relies on the optimizer being able to unroll these loops (reasonably
251/// trivial, `count` is a constant when called).
252#[inline]
253fn goldschmidt<F, I>(mut r_u0: I, mut s_u2: I, count: u32, final_set: bool) -> (I, I)
254where
255 F: SqrtHelper,
256 I: HInt + From<u8>,
257{
258 let three_u2 = I::from(0b11u8) << (I::BITS - 2);
259 let mut u_u0 = r_u0;
260
261 for i in 0..count {
262 // First iteration: `s = m*r` (`u_u0 = r_u0` set above)
263 // Subsequent iterations: `s=s*u/2`
264 s_u2 = wmulh(s_u2, u_u0);
265
266 // Perform `s /= 2` if:
267 //
268 // 1. This is not the first iteration (the first iteration is `s = m*r`)...
269 // 2. ... and this is not the last set of iterations
270 // 3. ... or, if this is the last set, it is not the last iteration
271 //
272 // This step is not performed for the final iteration because the shift is combined with
273 // a later shift (moving `s` into the mantissa).
274 if i > 0 && (!final_set || i + 1 < count) {
275 s_u2 <<= 1;
276 }
277
278 // u = 3 - s*r
279 let d_u2 = wmulh(s_u2, r_u0);
280 u_u0 = three_u2.wrapping_sub(d_u2);
281
282 // r = r*u/2
283 r_u0 = wmulh(r_u0, u_u0) << 1;
284 }
285
286 (r_u0, s_u2)
287}
288
/// How the result exponent is carried through `sqrt_round`.
///
/// Floats that fit into a `u32` keep the exponent in place inside their integer
/// representation, which saves shift operations. Wider floats extract it into a `u32`.
enum Exp<T> {
    /// The exponent stays at its natural bit position within the integer representation.
    NoShift(T),
    /// The exponent has been moved into a `u32` with its least significant bit at bit 0.
    Shifted(u32),
}
297
298/// Size-specific constants related to the square root routine.
299pub trait SqrtHelper: Float {
300 /// Integer for the first set of rounds. If unused, set to the same type as the next set.
301 type ISet1: HInt + Into<Self::ISet2> + CastFrom<Self::Int> + From<u8>;
302 /// Integer for the second set of rounds. If unused, set to the same type as the next set.
303 type ISet2: HInt + From<Self::ISet1> + From<u8>;
304
305 /// Number of rounds at `ISet1`.
306 const SET1_ROUNDS: u32 = 0;
307 /// Number of rounds at `ISet2`.
308 const SET2_ROUNDS: u32 = 0;
309 /// Number of rounds at `Self::Int`.
310 const FINAL_ROUNDS: u32;
311}
312
#[cfg(f16_enabled)]
impl SqrtHelper for f16 {
    // Two rounds, all at full width; the narrower sets do nothing.
    const SET1_ROUNDS: u32 = 0;
    const SET2_ROUNDS: u32 = 0;
    const FINAL_ROUNDS: u32 = 2;

    // Placeholders: no rounds run at these widths, so they reuse `u16`.
    type ISet1 = u16;
    type ISet2 = u16;
}
320
321impl SqrtHelper for f32 {
322 type ISet1 = u32; // unused
323 type ISet2 = u32; // unused
324
325 const FINAL_ROUNDS: u32 = 3;
326}
327
328impl SqrtHelper for f64 {
329 type ISet1 = u32; // unused
330 type ISet2 = u32;
331
332 const SET2_ROUNDS: u32 = 2;
333 const FINAL_ROUNDS: u32 = 2;
334}
335
#[cfg(f128_enabled)]
impl SqrtHelper for f128 {
    // One round at quarter width (`u32`), two at half width (`u64`), two at full width (`u128`).
    const SET1_ROUNDS: u32 = 1;
    const SET2_ROUNDS: u32 = 2;
    const FINAL_ROUNDS: u32 = 2;

    type ISet1 = u32;
    type ISet2 = u64;
}
345
/// A U0.16 representation of `1/sqrt(m)`, used as the initial estimate for `r`.
///
/// Each entry is the estimate scaled by `2^16` (e.g. `0x8000` is 0.5).
///
/// The index is a 7-bit number consisting of a single exponent bit (the lowest bit of the
/// input's exponent) followed by the top 6 bits of the significand:
///
/// - indices `0..64` (exponent bit clear) serve `m ∈ [2, 4)`; entries run from about 0.70 down
///   to about 0.50;
/// - indices `64..128` (exponent bit set) serve `m ∈ [1, 2)`; entries run from about 1.0 down
///   to about 0.71.
#[rustfmt::skip]
static RSQRT_TAB: [u16; 128] = [
    0xb451, 0xb2f0, 0xb196, 0xb044, 0xaef9, 0xadb6, 0xac79, 0xab43,
    0xaa14, 0xa8eb, 0xa7c8, 0xa6aa, 0xa592, 0xa480, 0xa373, 0xa26b,
    0xa168, 0xa06a, 0x9f70, 0x9e7b, 0x9d8a, 0x9c9d, 0x9bb5, 0x9ad1,
    0x99f0, 0x9913, 0x983a, 0x9765, 0x9693, 0x95c4, 0x94f8, 0x9430,
    0x936b, 0x92a9, 0x91ea, 0x912e, 0x9075, 0x8fbe, 0x8f0a, 0x8e59,
    0x8daa, 0x8cfe, 0x8c54, 0x8bac, 0x8b07, 0x8a64, 0x89c4, 0x8925,
    0x8889, 0x87ee, 0x8756, 0x86c0, 0x862b, 0x8599, 0x8508, 0x8479,
    0x83ec, 0x8361, 0x82d8, 0x8250, 0x81c9, 0x8145, 0x80c2, 0x8040,
    0xff02, 0xfd0e, 0xfb25, 0xf947, 0xf773, 0xf5aa, 0xf3ea, 0xf234,
    0xf087, 0xeee3, 0xed47, 0xebb3, 0xea27, 0xe8a3, 0xe727, 0xe5b2,
    0xe443, 0xe2dc, 0xe17a, 0xe020, 0xdecb, 0xdd7d, 0xdc34, 0xdaf1,
    0xd9b3, 0xd87b, 0xd748, 0xd61a, 0xd4f1, 0xd3cd, 0xd2ad, 0xd192,
    0xd07b, 0xcf69, 0xce5b, 0xcd51, 0xcc4a, 0xcb48, 0xca4a, 0xc94f,
    0xc858, 0xc764, 0xc674, 0xc587, 0xc49d, 0xc3b7, 0xc2d4, 0xc1f4,
    0xc116, 0xc03c, 0xbf65, 0xbe90, 0xbdbe, 0xbcef, 0xbc23, 0xbb59,
    0xba91, 0xb9cc, 0xb90a, 0xb84a, 0xb78c, 0xb6d0, 0xb617, 0xb560,
];
368
#[cfg(test)]
mod tests {
    use super::*;

    /// Check the inputs whose results IEEE 754 `squareRoot` defines exactly. NaN and negative
    /// inputs must give NaN and signal invalid. Signed zeros and `+inf` must come back
    /// unchanged with an OK status.
    fn spec_test<F>()
    where
        F: Float + SqrtHelper,
        F::Int: HInt,
        F::Int: From<u8>,
        F::Int: From<F::ISet2>,
        F::Int: CastInto<F::ISet1>,
        F::Int: CastInto<F::ISet2>,
        u32: CastInto<F::Int>,
    {
        let invalid_inputs = [F::NEG_INFINITY, F::NEG_ONE, F::NAN, F::MIN];
        let unchanged_inputs = [F::ZERO, F::NEG_ZERO, F::INFINITY];

        for arg in invalid_inputs {
            let FpResult { val: root, status: st } = sqrt_round(arg, Round::Nearest);
            assert!(root.is_nan());
            assert!(st == Status::INVALID);
        }

        for arg in unchanged_inputs {
            let FpResult { val: root, status: st } = sqrt_round(arg, Round::Nearest);
            assert_biteq!(root, arg);
            assert!(st == Status::OK);
        }
    }

    #[test]
    #[cfg(f16_enabled)]
    fn sanity_check_f16() {
        for (arg, expected) in [(100.0f16, 10.0f16), (4.0f16, 2.0f16)] {
            assert_biteq!(sqrt(arg), expected);
        }
    }

    #[test]
    #[cfg(f16_enabled)]
    fn spec_tests_f16() {
        spec_test::<f16>();
    }

    #[test]
    #[cfg(f16_enabled)]
    #[allow(clippy::approx_constant)]
    fn conformance_tests_f16() {
        let cases = [
            (f16::PI, 0x3f17_u16),
            // 10_000.0, using a hex literal for MSRV hack (Rust < 1.67 checks literal widths as
            // part of the AST, so the `cfg` is irrelevant here).
            (f16::from_bits(0x70e2), 0x5640_u16),
            (f16::from_bits(0x0000000f), 0x13bf_u16),
            (f16::INFINITY, f16::INFINITY.to_bits()),
        ];

        for (input, expected_bits) in cases {
            let expected = f16::from_bits(expected_bits);
            let actual = sqrt(input);
            assert_biteq!(actual, expected, "input: {input:?} ({:#018x})", input.to_bits());
        }
    }

    #[test]
    fn sanity_check_f32() {
        for (arg, expected) in [(100.0f32, 10.0f32), (4.0f32, 2.0f32)] {
            assert_biteq!(sqrt(arg), expected);
        }
    }

    #[test]
    fn spec_tests_f32() {
        spec_test::<f32>();
    }

    #[test]
    #[allow(clippy::approx_constant)]
    fn conformance_tests_f32() {
        let cases = [
            (f32::PI, 0x3fe2dfc5_u32),
            (10000.0f32, 0x42c80000_u32),
            (f32::from_bits(0x0000000f), 0x1b2f456f_u32),
            (f32::INFINITY, f32::INFINITY.to_bits()),
        ];

        for (input, expected_bits) in cases {
            let expected = f32::from_bits(expected_bits);
            let actual = sqrt(input);
            assert_biteq!(actual, expected, "input: {input:?} ({:#018x})", input.to_bits());
        }
    }

    #[test]
    fn sanity_check_f64() {
        for (arg, expected) in [(100.0f64, 10.0f64), (4.0f64, 2.0f64)] {
            assert_biteq!(sqrt(arg), expected);
        }
    }

    #[test]
    fn spec_tests_f64() {
        spec_test::<f64>();
    }

    #[test]
    #[allow(clippy::approx_constant)]
    fn conformance_tests_f64() {
        let cases = [
            (f64::PI, 0x3ffc5bf891b4ef6a_u64),
            (10000.0, 0x4059000000000000_u64),
            (f64::from_bits(0x0000000f), 0x1e7efbdeb14f4eda_u64),
            (f64::INFINITY, f64::INFINITY.to_bits()),
        ];

        for (input, expected_bits) in cases {
            let expected = f64::from_bits(expected_bits);
            let actual = sqrt(input);
            assert_biteq!(actual, expected, "input: {input:?} ({:#018x})", input.to_bits());
        }
    }

    #[test]
    #[cfg(f128_enabled)]
    fn sanity_check_f128() {
        for (arg, expected) in [(100.0f128, 10.0f128), (4.0f128, 2.0f128)] {
            assert_biteq!(sqrt(arg), expected);
        }
    }

    #[test]
    #[cfg(f128_enabled)]
    fn spec_tests_f128() {
        spec_test::<f128>();
    }

    #[test]
    #[cfg(f128_enabled)]
    #[allow(clippy::approx_constant)]
    fn conformance_tests_f128() {
        let cases = [
            (f128::PI, 0x3fffc5bf891b4ef6aa79c3b0520d5db9_u128),
            // 10_000.0, see `f16` for reasoning.
            (
                f128::from_bits(0x400c3880000000000000000000000000),
                0x40059000000000000000000000000000_u128,
            ),
            (
                f128::from_bits(0x0000000f),
                0x1fc9efbdeb14f4ed9b17ae807907e1e9_u128,
            ),
            (f128::INFINITY, f128::INFINITY.to_bits()),
        ];

        for (input, expected_bits) in cases {
            let expected = f128::from_bits(expected_bits);
            let actual = sqrt(input);
            assert_biteq!(actual, expected, "input: {input:?} ({:#018x})", input.to_bits());
        }
    }
}
542