1/* SPDX-License-Identifier: MIT */
2/* origin: musl src/math/sqrt.c. Ported to generic Rust algorithm in 2025, TG. */
3
4//! Generic square root algorithm.
5//!
//! This routine operates around `m_u2`, a U2 (fixed point with two integral bits) mantissa
//! within the range [1, 4). A table lookup provides an initial estimate, then Goldschmidt
//! iterations at one or more integer widths (narrowest first) refine it toward the exact values.
9//!
10//! For the iterations, `r` is a U0 number that approaches `1/sqrt(m_u2)`, and `s` is a U2 number
11//! that approaches `sqrt(m_u2)`. Recall that m_u2 ∈ [1, 4).
12//!
//! With Newton-Raphson iterations, this would be:
//!
//! - `w = r * r` (`w ~ 1 / m`)
//! - `u = 3 - m * w` (`u ~ 3 - m * w = 3 - m / m = 2`)
//! - `r = r * u / 2` (`r ~ r`)
//!
//! (The parenthesized approximations are not analytically meaningful on their own, e.g. `r ~ r`:
//! the value of performing one iteration lies in reducing the error hidden by `~`.)
21//!
//! Instead of Newton-Raphson iterations, Goldschmidt iterations are used; they also track
//! `s = m * r` directly:
//!
//! - `s = m * r` (`s ~ m / sqrt(m)`)
//! - `u = 3 - s * r` (`u ~ 3 - (m / sqrt(m)) * (1 / sqrt(m)) = 3 - m / m = 2`)
//! - `r = r * u / 2` (`r ~ r`)
//! - `s = s * u / 2` (`s ~ s`)
29//!
//! The above is more accurate because it uses the original value `m`. There is also a faster
//! version that performs fewer steps but does not use `m`:
//!
//! - `u = 3 - s * r` (`u ~ 3 - 1 = 2`)
//! - `r = r * u / 2` (`r ~ r`)
//! - `s = s * u / 2` (`s ~ s`)
//! Rounding errors accumulate faster with the second version, so it is only used for subsequent
//! iterations at the same integer width. The first version is always used for the first
//! iteration at a new width in order to avoid this accumulation.
40//!
41//! Goldschmidt has the advantage over Newton-Raphson that `sqrt(x)` and `1/sqrt(x)` are
42//! computed at the same time, i.e. there is no need to calculate `1/sqrt(x)` and invert it.
43
44use super::super::support::{FpResult, IntTy, Round, Status, cold_path};
45use super::super::{CastFrom, CastInto, DInt, Float, HInt, Int, MinInt};
46
47#[inline]
48pub fn sqrt<F>(x: F) -> F
49where
50 F: Float + SqrtHelper,
51 F::Int: HInt,
52 F::Int: From<u8>,
53 F::Int: From<F::ISet2>,
54 F::Int: CastInto<F::ISet1>,
55 F::Int: CastInto<F::ISet2>,
56 u32: CastInto<F::Int>,
57{
58 sqrt_round(x, Round::Nearest).val
59}
60
61#[inline]
62pub fn sqrt_round<F>(x: F, _round: Round) -> FpResult<F>
63where
64 F: Float + SqrtHelper,
65 F::Int: HInt,
66 F::Int: From<u8>,
67 F::Int: From<F::ISet2>,
68 F::Int: CastInto<F::ISet1>,
69 F::Int: CastInto<F::ISet2>,
70 u32: CastInto<F::Int>,
71{
72 let zero = IntTy::<F>::ZERO;
73 let one = IntTy::<F>::ONE;
74
75 let mut ix = x.to_bits();
76
77 // Top is the exponent and sign, which may or may not be shifted. If the float fits into a
78 // `u32`, we can get by without paying shifting costs.
79 let noshift = F::BITS <= u32::BITS;
80 let (mut top, special_case) = if noshift {
81 let exp_lsb = one << F::SIG_BITS;
82 let special_case = ix.wrapping_sub(exp_lsb) >= F::EXP_MASK - exp_lsb;
83 (Exp::NoShift(()), special_case)
84 } else {
85 let top = u32::cast_from(ix >> F::SIG_BITS);
86 let special_case = top.wrapping_sub(1) >= F::EXP_SAT - 1;
87 (Exp::Shifted(top), special_case)
88 };
89
90 // Handle NaN, zero, and out of domain (<= 0)
91 if special_case {
92 cold_path();
93
94 // +/-0
95 if ix << 1 == zero {
96 return FpResult::ok(x);
97 }
98
99 // Positive infinity
100 if ix == F::EXP_MASK {
101 return FpResult::ok(x);
102 }
103
104 // NaN or negative
105 if ix > F::EXP_MASK {
106 return FpResult::new(F::NAN, Status::INVALID);
107 }
108
109 // Normalize subnormals by multiplying by 1.0 << SIG_BITS (e.g. 0x1p52 for doubles).
110 let scaled = x * F::from_parts(false, F::SIG_BITS + F::EXP_BIAS, zero);
111 ix = scaled.to_bits();
112 match top {
113 Exp::Shifted(ref mut v) => {
114 *v = scaled.ex();
115 *v = (*v).wrapping_sub(F::SIG_BITS);
116 }
117 Exp::NoShift(()) => {
118 ix = ix.wrapping_sub((F::SIG_BITS << F::SIG_BITS).cast());
119 }
120 }
121 }
122
123 // Reduce arguments such that `x = 4^e * m`:
124 //
125 // - m_u2 ∈ [1, 4), a fixed point U2.BITS number
126 // - 2^e is the exponent part of the result
127 let (m_u2, exp) = match top {
128 Exp::Shifted(top) => {
129 // We now know `x` is positive, so `top` is just its (biased) exponent
130 let mut e = top;
131 // Construct a fixed point representation of the mantissa.
132 let mut m_u2 = (ix | F::IMPLICIT_BIT) << F::EXP_BITS;
133 let even = (e & 1) != 0;
134 if even {
135 m_u2 >>= 1;
136 }
137 e = (e.wrapping_add(F::EXP_SAT >> 1)) >> 1;
138 (m_u2, Exp::Shifted(e))
139 }
140 Exp::NoShift(()) => {
141 let even = ix & (one << F::SIG_BITS) != zero;
142
143 // Exponent part of the return value
144 let mut e_noshift = ix >> 1;
145 // ey &= (F::EXP_MASK << 2) >> 2; // clear the top exponent bit (result = 1.0)
146 e_noshift += (F::EXP_MASK ^ (F::SIGN_MASK >> 1)) >> 1;
147 e_noshift &= F::EXP_MASK;
148
149 let m1 = (ix << F::EXP_BITS) | F::SIGN_MASK;
150 let m0 = (ix << (F::EXP_BITS - 1)) & !F::SIGN_MASK;
151 let m_u2 = if even { m0 } else { m1 };
152
153 (m_u2, Exp::NoShift(e_noshift))
154 }
155 };
156
157 // Extract the top 6 bits of the significand with the lowest bit of the exponent.
158 let i = usize::cast_from(ix >> (F::SIG_BITS - 6)) & 0b1111111;
159
160 // Start with an initial guess for `r = 1 / sqrt(m)` from the table, and shift `m` as an
161 // initial value for `s = sqrt(m)`. See the module documentation for details.
162 let r1_u0: F::ISet1 = F::ISet1::cast_from(RSQRT_TAB[i]) << (F::ISet1::BITS - 16);
163 let s1_u2: F::ISet1 = ((m_u2) >> (F::BITS - F::ISet1::BITS)).cast();
164
165 // Perform iterations, if any, at quarter width (used for `f128`).
166 let (r1_u0, _s1_u2) = goldschmidt::<F, F::ISet1>(r1_u0, s1_u2, F::SET1_ROUNDS, false);
167
168 // Widen values and perform iterations at half width (used for `f64` and `f128`).
169 let r2_u0: F::ISet2 = F::ISet2::from(r1_u0) << (F::ISet2::BITS - F::ISet1::BITS);
170 let s2_u2: F::ISet2 = ((m_u2) >> (F::BITS - F::ISet2::BITS)).cast();
171 let (r2_u0, _s2_u2) = goldschmidt::<F, F::ISet2>(r2_u0, s2_u2, F::SET2_ROUNDS, false);
172
173 // Perform final iterations at full width (used for all float types).
174 let r_u0: F::Int = F::Int::from(r2_u0) << (F::BITS - F::ISet2::BITS);
175 let s_u2: F::Int = m_u2;
176 let (_r_u0, s_u2) = goldschmidt::<F, F::Int>(r_u0, s_u2, F::FINAL_ROUNDS, true);
177
178 // Shift back to mantissa position.
179 let mut m = s_u2 >> (F::EXP_BITS - 2);
180
181 // The musl source includes the following comment (with literals replaced):
182 //
183 // > s < sqrt(m) < s + 0x1.09p-SIG_BITS
184 // > compute nearest rounded result: the nearest result to SIG_BITS bits is either s or
185 // > s+0x1p-SIG_BITS, we can decide by comparing (2^SIG_BITS s + 0.5)^2 to 2^(2*SIG_BITS) m.
186 //
187 // Expanding this with , with `SIG_BITS = p` and adjusting based on the operations done to
188 // `d0` and `d1`:
189 //
190 // - `2^(2p)m ≟ ((2^p)m + 0.5)^2`
191 // - `2^(2p)m ≟ 2^(2p)m^2 + (2^p)m + 0.25`
192 // - `2^(2p)m - m^2 ≟ (2^(2p) - 1)m^2 + (2^p)m + 0.25`
193 // - `(1 - 2^(2p))m + m^2 ≟ (1 - 2^(2p))m^2 + (1 - 2^p)m + 0.25` (?)
194 //
195 // I do not follow how the rounding bit is extracted from this comparison with the below
196 // operations. In any case, the algorithm is well tested.
197
198 // The value needed to shift `m_u2` by to create `m*2^(2p)`. `2p = 2 * F::SIG_BITS`,
199 // `F::BITS - 2` accounts for the offset that `m_u2` already has.
200 let shift = 2 * F::SIG_BITS - (F::BITS - 2);
201
202 // `2^(2p)m - m^2`
203 let d0 = (m_u2 << shift).wrapping_sub(m.wrapping_mul(m));
204 // `m - 2^(2p)m + m^2`
205 let d1 = m.wrapping_sub(d0);
206 m += d1 >> (F::BITS - 1);
207 m &= F::SIG_MASK;
208
209 match exp {
210 Exp::Shifted(e) => m |= IntTy::<F>::cast_from(e) << F::SIG_BITS,
211 Exp::NoShift(e) => m |= e,
212 };
213
214 let mut y = F::from_bits(m);
215
216 // FIXME(f16): the fenv math does not work for `f16`
217 if F::BITS > 16 {
218 // Handle rounding and inexact. `(m + 1)^2 == 2^shift m` is exact; for all other cases, add
219 // a tiny value to cause fenv effects.
220 let d2 = d1.wrapping_add(m).wrapping_add(one);
221 let mut tiny = if d2 == zero {
222 cold_path();
223 zero
224 } else {
225 F::IMPLICIT_BIT
226 };
227
228 tiny |= (d1 ^ d2) & F::SIGN_MASK;
229 let t = F::from_bits(tiny);
230 y = y + t;
231 }
232
233 FpResult::ok(y)
234}
235
236/// Multiply at the wider integer size, returning the high half.
237fn wmulh<I: HInt>(a: I, b: I) -> I {
238 a.widen_mul(b).hi()
239}
240
241/// Perform `count` goldschmidt iterations, returning `(r_u0, s_u?)`.
242///
243/// - `r_u0` is the reciprocal `r ~ 1 / sqrt(m)`, as U0.
244/// - `s_u2` is the square root, `s ~ sqrt(m)`, as U2.
245/// - `count` is the number of iterations to perform.
246/// - `final_set` should be true if this is the last round (same-sized integer). If so, the
247/// returned `s` will be U3, for later shifting. Otherwise, the returned `s` is U2.
248///
249/// Note that performance relies on the optimizer being able to unroll these loops (reasonably
250/// trivial, `count` is a constant when called).
251#[inline]
252fn goldschmidt<F, I>(mut r_u0: I, mut s_u2: I, count: u32, final_set: bool) -> (I, I)
253where
254 F: SqrtHelper,
255 I: HInt + From<u8>,
256{
257 let three_u2 = I::from(0b11u8) << (I::BITS - 2);
258 let mut u_u0 = r_u0;
259
260 for i in 0..count {
261 // First iteration: `s = m*r` (`u_u0 = r_u0` set above)
262 // Subsequent iterations: `s=s*u/2`
263 s_u2 = wmulh(s_u2, u_u0);
264
265 // Perform `s /= 2` if:
266 //
267 // 1. This is not the first iteration (the first iteration is `s = m*r`)...
268 // 2. ... and this is not the last set of iterations
269 // 3. ... or, if this is the last set, it is not the last iteration
270 //
271 // This step is not performed for the final iteration because the shift is combined with
272 // a later shift (moving `s` into the mantissa).
273 if i > 0 && (!final_set || i + 1 < count) {
274 s_u2 <<= 1;
275 }
276
277 // u = 3 - s*r
278 let d_u2 = wmulh(s_u2, r_u0);
279 u_u0 = three_u2.wrapping_sub(d_u2);
280
281 // r = r*u/2
282 r_u0 = wmulh(r_u0, u_u0) << 1;
283 }
284
285 (r_u0, s_u2)
286}
287
/// How the input's exponent is carried through the square root computation.
///
/// Types that fit in a `u32` keep the exponent in place inside their integer representation,
/// which saves shift operations; wider types extract it into a `u32`.
enum Exp<T> {
    /// The exponent was extracted into a `u32`, aligned so its least significant bit is bit 0.
    Shifted(u32),
    /// The exponent stays in its natural position within the integer representation.
    NoShift(T),
}
296
297/// Size-specific constants related to the square root routine.
298pub trait SqrtHelper: Float {
299 /// Integer for the first set of rounds. If unused, set to the same type as the next set.
300 type ISet1: HInt + Into<Self::ISet2> + CastFrom<Self::Int> + From<u8>;
301 /// Integer for the second set of rounds. If unused, set to the same type as the next set.
302 type ISet2: HInt + From<Self::ISet1> + From<u8>;
303
304 /// Number of rounds at `ISet1`.
305 const SET1_ROUNDS: u32 = 0;
306 /// Number of rounds at `ISet2`.
307 const SET2_ROUNDS: u32 = 0;
308 /// Number of rounds at `Self::Int`.
309 const FINAL_ROUNDS: u32;
310}
311
#[cfg(f16_enabled)]
impl SqrtHelper for f16 {
    // All iterations happen at full width, so the narrower batch types just mirror `u16`.
    type ISet1 = u16;
    type ISet2 = u16;

    const SET1_ROUNDS: u32 = 0;
    const SET2_ROUNDS: u32 = 0;
    const FINAL_ROUNDS: u32 = 2;
}
319
320impl SqrtHelper for f32 {
321 type ISet1 = u32; // unused
322 type ISet2 = u32; // unused
323
324 const FINAL_ROUNDS: u32 = 3;
325}
326
327impl SqrtHelper for f64 {
328 type ISet1 = u32; // unused
329 type ISet2 = u32;
330
331 const SET2_ROUNDS: u32 = 2;
332 const FINAL_ROUNDS: u32 = 2;
333}
334
#[cfg(f128_enabled)]
impl SqrtHelper for f128 {
    // Iterations widen step by step: one round at `u32`, two at `u64`, then two at full width.
    type ISet1 = u32;
    type ISet2 = u64;

    const SET1_ROUNDS: u32 = 1;
    const SET2_ROUNDS: u32 = 2;
    const FINAL_ROUNDS: u32 = 2;
}
344
/// A U0.16 representation of `1/sqrt(x)`.
///
/// The index is a 7-bit number consisting of a single exponent bit and 6 bits of significand.
///
/// Bit 6 of the index is the lowest bit of the biased exponent field:
///
/// - When it is clear, the reduced argument `m_u2` lies in `[2, 4)`. Entries `0..=63` run from
///   about `1/sqrt(2)` down to about `1/2`.
/// - When it is set, `m_u2` was halved into `[1, 2)`. Entries `64..=127` run from about `1` down
///   to about `1/sqrt(2)`.
///
/// The low 6 bits (the top significand bits) select one of 64 equal sub-intervals. Each entry
/// appears to be `1/sqrt` evaluated near that sub-interval's midpoint. Callers shift the entry
/// to the top of a wider integer to get a U0 starting estimate.
#[rustfmt::skip]
static RSQRT_TAB: [u16; 128] = [
    0xb451, 0xb2f0, 0xb196, 0xb044, 0xaef9, 0xadb6, 0xac79, 0xab43,
    0xaa14, 0xa8eb, 0xa7c8, 0xa6aa, 0xa592, 0xa480, 0xa373, 0xa26b,
    0xa168, 0xa06a, 0x9f70, 0x9e7b, 0x9d8a, 0x9c9d, 0x9bb5, 0x9ad1,
    0x99f0, 0x9913, 0x983a, 0x9765, 0x9693, 0x95c4, 0x94f8, 0x9430,
    0x936b, 0x92a9, 0x91ea, 0x912e, 0x9075, 0x8fbe, 0x8f0a, 0x8e59,
    0x8daa, 0x8cfe, 0x8c54, 0x8bac, 0x8b07, 0x8a64, 0x89c4, 0x8925,
    0x8889, 0x87ee, 0x8756, 0x86c0, 0x862b, 0x8599, 0x8508, 0x8479,
    0x83ec, 0x8361, 0x82d8, 0x8250, 0x81c9, 0x8145, 0x80c2, 0x8040,
    0xff02, 0xfd0e, 0xfb25, 0xf947, 0xf773, 0xf5aa, 0xf3ea, 0xf234,
    0xf087, 0xeee3, 0xed47, 0xebb3, 0xea27, 0xe8a3, 0xe727, 0xe5b2,
    0xe443, 0xe2dc, 0xe17a, 0xe020, 0xdecb, 0xdd7d, 0xdc34, 0xdaf1,
    0xd9b3, 0xd87b, 0xd748, 0xd61a, 0xd4f1, 0xd3cd, 0xd2ad, 0xd192,
    0xd07b, 0xcf69, 0xce5b, 0xcd51, 0xcc4a, 0xcb48, 0xca4a, 0xc94f,
    0xc858, 0xc764, 0xc674, 0xc587, 0xc49d, 0xc3b7, 0xc2d4, 0xc1f4,
    0xc116, 0xc03c, 0xbf65, 0xbe90, 0xbdbe, 0xbcef, 0xbc23, 0xbb59,
    0xba91, 0xb9cc, 0xb90a, 0xb84a, 0xb78c, 0xb6d0, 0xb617, 0xb560,
];
367
#[cfg(test)]
mod tests {
    use super::*;

    /// Test behavior specified in IEEE 754 `squareRoot`.
    ///
    /// NaN and negative inputs (`-inf`, `-1`, `F::MIN`) must return NaN with `Status::INVALID`.
    /// `+0`, `-0`, and `+inf` must be returned unchanged with `Status::OK`.
    fn spec_test<F>()
    where
        F: Float + SqrtHelper,
        F::Int: HInt,
        F::Int: From<u8>,
        F::Int: From<F::ISet2>,
        F::Int: CastInto<F::ISet1>,
        F::Int: CastInto<F::ISet2>,
        u32: CastInto<F::Int>,
    {
        // Values that should return a NaN and raise invalid
        let nan = [F::NEG_INFINITY, F::NEG_ONE, F::NAN, F::MIN];

        // Values that return unaltered
        let roundtrip = [F::ZERO, F::NEG_ZERO, F::INFINITY];

        for x in nan {
            let FpResult { val, status } = sqrt_round(x, Round::Nearest);
            assert!(val.is_nan());
            assert!(status == Status::INVALID);
        }

        for x in roundtrip {
            let FpResult { val, status } = sqrt_round(x, Round::Nearest);
            assert_biteq!(val, x);
            assert!(status == Status::OK);
        }
    }

    /// Perfect squares must produce exact results.
    #[test]
    #[cfg(f16_enabled)]
    fn sanity_check_f16() {
        assert_biteq!(sqrt(100.0f16), 10.0);
        assert_biteq!(sqrt(4.0f16), 2.0);
    }

    #[test]
    #[cfg(f16_enabled)]
    fn spec_tests_f16() {
        spec_test::<f16>();
    }

    /// Expected results, as raw bit patterns, for pi, 10 000, a subnormal input (bits `0xf`, which
    /// exercises the normalization path), and `+inf`.
    #[test]
    #[cfg(f16_enabled)]
    #[allow(clippy::approx_constant)]
    fn conformance_tests_f16() {
        let cases = [
            (f16::PI, 0x3f17_u16),
            // 10_000.0, using a hex literal for MSRV hack (Rust < 1.67 checks literal widths as
            // part of the AST, so the `cfg` is irrelevant here).
            (f16::from_bits(0x70e2), 0x5640_u16),
            (f16::from_bits(0x0000000f), 0x13bf_u16),
            (f16::INFINITY, f16::INFINITY.to_bits()),
        ];

        for (input, output) in cases {
            assert_biteq!(
                sqrt(input),
                f16::from_bits(output),
                "input: {input:?} ({:#018x})",
                input.to_bits()
            );
        }
    }

    /// Perfect squares must produce exact results.
    #[test]
    fn sanity_check_f32() {
        assert_biteq!(sqrt(100.0f32), 10.0);
        assert_biteq!(sqrt(4.0f32), 2.0);
    }

    #[test]
    fn spec_tests_f32() {
        spec_test::<f32>();
    }

    /// Expected results, as raw bit patterns, for pi, 10 000, a subnormal input (bits `0xf`), and
    /// `+inf`.
    #[test]
    #[allow(clippy::approx_constant)]
    fn conformance_tests_f32() {
        let cases = [
            (f32::PI, 0x3fe2dfc5_u32),
            (10000.0f32, 0x42c80000_u32),
            (f32::from_bits(0x0000000f), 0x1b2f456f_u32),
            (f32::INFINITY, f32::INFINITY.to_bits()),
        ];

        for (input, output) in cases {
            assert_biteq!(
                sqrt(input),
                f32::from_bits(output),
                "input: {input:?} ({:#018x})",
                input.to_bits()
            );
        }
    }

    /// Perfect squares must produce exact results.
    #[test]
    fn sanity_check_f64() {
        assert_biteq!(sqrt(100.0f64), 10.0);
        assert_biteq!(sqrt(4.0f64), 2.0);
    }

    #[test]
    fn spec_tests_f64() {
        spec_test::<f64>();
    }

    /// Expected results, as raw bit patterns, for pi, 10 000, a subnormal input (bits `0xf`), and
    /// `+inf`.
    #[test]
    #[allow(clippy::approx_constant)]
    fn conformance_tests_f64() {
        let cases = [
            (f64::PI, 0x3ffc5bf891b4ef6a_u64),
            (10000.0, 0x4059000000000000_u64),
            (f64::from_bits(0x0000000f), 0x1e7efbdeb14f4eda_u64),
            (f64::INFINITY, f64::INFINITY.to_bits()),
        ];

        for (input, output) in cases {
            assert_biteq!(
                sqrt(input),
                f64::from_bits(output),
                "input: {input:?} ({:#018x})",
                input.to_bits()
            );
        }
    }

    /// Perfect squares must produce exact results.
    #[test]
    #[cfg(f128_enabled)]
    fn sanity_check_f128() {
        assert_biteq!(sqrt(100.0f128), 10.0);
        assert_biteq!(sqrt(4.0f128), 2.0);
    }

    #[test]
    #[cfg(f128_enabled)]
    fn spec_tests_f128() {
        spec_test::<f128>();
    }

    /// Expected results, as raw bit patterns, for pi, 10 000, a subnormal input (bits `0xf`), and
    /// `+inf`.
    #[test]
    #[cfg(f128_enabled)]
    #[allow(clippy::approx_constant)]
    fn conformance_tests_f128() {
        let cases = [
            (f128::PI, 0x3fffc5bf891b4ef6aa79c3b0520d5db9_u128),
            // 10_000.0, see `f16` for reasoning.
            (
                f128::from_bits(0x400c3880000000000000000000000000),
                0x40059000000000000000000000000000_u128,
            ),
            (
                f128::from_bits(0x0000000f),
                0x1fc9efbdeb14f4ed9b17ae807907e1e9_u128,
            ),
            (f128::INFINITY, f128::INFINITY.to_bits()),
        ];

        for (input, output) in cases {
            assert_biteq!(
                sqrt(input),
                f128::from_bits(output),
                "input: {input:?} ({:#018x})",
                input.to_bits()
            );
        }
    }
}
541