1/* SPDX-License-Identifier: MIT */
2/* origin: musl src/math/sqrt.c. Ported to generic Rust algorithm in 2025, TG. */
3
4//! Generic square root algorithm.
5//!
//! This routine operates around `m_u2`, a U2 (fixed point with two integral bits) mantissa
//! within the range [1, 4). A table lookup provides an initial estimate, then Goldschmidt
//! iterations at various widths are used to approach the real values.
9//!
10//! For the iterations, `r` is a U0 number that approaches `1/sqrt(m_u2)`, and `s` is a U2 number
11//! that approaches `sqrt(m_u2)`. Recall that m_u2 ∈ [1, 4).
12//!
13//! With Newton-Raphson iterations, this would be:
14//!
//! - `w = r * r           w ~ 1 / m`
//! - `u = 3 - m * w       u ~ 3 - m * w = 3 - m / m = 2`
//! - `r = r * u / 2       r ~ r`
18//!
19//! (Note that the righthand column does not show anything analytically meaningful (i.e. r ~ r),
20//! since the value of performing one iteration is in reducing the error representable by `~`).
21//!
22//! Instead of Newton-Raphson iterations, Goldschmidt iterations are used to calculate
23//! `s = m * r`:
24//!
//! - `s = m * r           s ~ m / sqrt(m)`
//! - `u = 3 - s * r       u ~ 3 - (m / sqrt(m)) * (1 / sqrt(m)) = 3 - m / m = 2`
//! - `r = r * u / 2       r ~ r`
//! - `s = s * u / 2       s ~ s`
29//!
30//! The above is precise because it uses the original value `m`. There is also a faster version
31//! that performs fewer steps but does not use `m`:
32//!
//! - `u = 3 - s * r       u ~ 3 - 1`
//! - `r = r * u / 2       r ~ r`
//! - `s = s * u / 2       s ~ s`
36//!
37//! Rounding errors accumulate faster with the second version, so it is only used for subsequent
38//! iterations within the same width integer. The first version is always used for the first
39//! iteration at a new width in order to avoid this accumulation.
40//!
41//! Goldschmidt has the advantage over Newton-Raphson that `sqrt(x)` and `1/sqrt(x)` are
42//! computed at the same time, i.e. there is no need to calculate `1/sqrt(x)` and invert it.
43
44use super::super::support::{FpResult, IntTy, Round, Status, cold_path};
45use super::super::{CastFrom, CastInto, DInt, Float, HInt, Int, MinInt};
46
47pub fn sqrt<F>(x: F) -> F
48where
49 F: Float + SqrtHelper,
50 F::Int: HInt,
51 F::Int: From<u8>,
52 F::Int: From<F::ISet2>,
53 F::Int: CastInto<F::ISet1>,
54 F::Int: CastInto<F::ISet2>,
55 u32: CastInto<F::Int>,
56{
57 sqrt_round(x, Round::Nearest).val
58}
59
/// Compute the square root of `x`, returning the value together with a status.
///
/// `_round` is currently unused: the result is always produced by the round-to-nearest logic at
/// the end of this function.
///
/// Special inputs are handled up front:
///
/// - `+0.0`, `-0.0` and `+inf` are returned unchanged via `FpResult::ok`.
/// - NaN and negative inputs return `F::NAN` with `Status::INVALID`.
/// - Subnormals are scaled up by `2^SIG_BITS` and the exponent is adjusted to compensate, after
///   which they follow the regular path.
///
/// Every other result is returned via `FpResult::ok`. For types wider than 16 bits a tiny value
/// is added to inexact results so that floating point environment effects occur.
pub fn sqrt_round<F>(x: F, _round: Round) -> FpResult<F>
where
    F: Float + SqrtHelper,
    F::Int: HInt,
    F::Int: From<u8>,
    F::Int: From<F::ISet2>,
    F::Int: CastInto<F::ISet1>,
    F::Int: CastInto<F::ISet2>,
    u32: CastInto<F::Int>,
{
    let zero = IntTy::<F>::ZERO;
    let one = IntTy::<F>::ONE;

    let mut ix = x.to_bits();

    // Top is the exponent and sign, which may or may not be shifted. If the float fits into a
    // `u32`, we can get by without paying shifting costs.
    //
    // In both branches `special_case` is true for anything that is not a positive normal number:
    // zero and subnormals wrap around to a huge value in the subtraction, while infinity, NaN
    // and negative inputs are already at or above the upper bound.
    let noshift = F::BITS <= u32::BITS;
    let (mut top, special_case) = if noshift {
        let exp_lsb = one << F::SIG_BITS;
        let special_case = ix.wrapping_sub(exp_lsb) >= F::EXP_MASK - exp_lsb;
        (Exp::NoShift(()), special_case)
    } else {
        let top = u32::cast_from(ix >> F::SIG_BITS);
        let special_case = top.wrapping_sub(1) >= F::EXP_SAT - 1;
        (Exp::Shifted(top), special_case)
    };

    // Handle NaN, zero, and out of domain (<= 0)
    if special_case {
        cold_path();

        // +/-0 (shifting out the sign bit leaves nothing)
        if ix << 1 == zero {
            return FpResult::ok(x);
        }

        // Positive infinity
        if ix == F::EXP_MASK {
            return FpResult::ok(x);
        }

        // NaN or negative
        if ix > F::EXP_MASK {
            return FpResult::new(F::NAN, Status::INVALID);
        }

        // Normalize subnormals by multiplying by 1.0 << SIG_BITS (e.g. 0x1p52 for doubles).
        let scaled = x * F::from_parts(false, F::SIG_BITS + F::EXP_BIAS, zero);
        ix = scaled.to_bits();
        // Undo the scaling in the exponent, wherever the exponent is currently tracked.
        match top {
            Exp::Shifted(ref mut v) => {
                *v = scaled.ex();
                *v = (*v).wrapping_sub(F::SIG_BITS);
            }
            Exp::NoShift(()) => {
                ix = ix.wrapping_sub((F::SIG_BITS << F::SIG_BITS).cast());
            }
        }
    }

    // Reduce arguments such that `x = 4^e * m`:
    //
    // - m_u2 ∈ [1, 4), a fixed point U2.BITS number
    // - 2^e is the exponent part of the result
    let (m_u2, exp) = match top {
        Exp::Shifted(top) => {
            // We now know `x` is positive, so `top` is just its (biased) exponent
            let mut e = top;
            // Construct a fixed point representation of the mantissa.
            let mut m_u2 = (ix | F::IMPLICIT_BIT) << F::EXP_BITS;
            // `even` is set when the lowest bit of the biased exponent is set.
            let even = (e & 1) != 0;
            if even {
                m_u2 >>= 1;
            }
            // Halve the exponent, re-adding half of the bias that the halving removes.
            e = (e.wrapping_add(F::EXP_SAT >> 1)) >> 1;
            (m_u2, Exp::Shifted(e))
        }
        Exp::NoShift(()) => {
            // Same parity test as the shifted branch, done on the in-place exponent.
            let even = ix & (one << F::SIG_BITS) != zero;

            // Exponent part of the return value
            let mut e_noshift = ix >> 1;
            // ey &= (F::EXP_MASK << 2) >> 2; // clear the top exponent bit (result = 1.0)
            e_noshift += (F::EXP_MASK ^ (F::SIGN_MASK >> 1)) >> 1;
            e_noshift &= F::EXP_MASK;

            // Two candidate mantissas: `m1` has the top bit forced on, `m0` is shifted one bit
            // less and has the top bit cleared.
            let m1 = (ix << F::EXP_BITS) | F::SIGN_MASK;
            let m0 = (ix << (F::EXP_BITS - 1)) & !F::SIGN_MASK;
            let m_u2 = if even { m0 } else { m1 };

            (m_u2, Exp::NoShift(e_noshift))
        }
    };

    // Extract the top 6 bits of the significand with the lowest bit of the exponent. The mask
    // keeps 7 bits, which is the index range of `RSQRT_TAB`.
    let i = usize::cast_from(ix >> (F::SIG_BITS - 6)) & 0b1111111;

    // Start with an initial guess for `r = 1 / sqrt(m)` from the table, and shift `m` as an
    // initial value for `s = sqrt(m)`. See the module documentation for details.
    let r1_u0: F::ISet1 = F::ISet1::cast_from(RSQRT_TAB[i]) << (F::ISet1::BITS - 16);
    let s1_u2: F::ISet1 = ((m_u2) >> (F::BITS - F::ISet1::BITS)).cast();

    // Perform iterations, if any, at quarter width (used for `f128`).
    let (r1_u0, _s1_u2) = goldschmidt::<F, F::ISet1>(r1_u0, s1_u2, F::SET1_ROUNDS, false);

    // Widen values and perform iterations at half width (used for `f64` and `f128`).
    let r2_u0: F::ISet2 = F::ISet2::from(r1_u0) << (F::ISet2::BITS - F::ISet1::BITS);
    let s2_u2: F::ISet2 = ((m_u2) >> (F::BITS - F::ISet2::BITS)).cast();
    let (r2_u0, _s2_u2) = goldschmidt::<F, F::ISet2>(r2_u0, s2_u2, F::SET2_ROUNDS, false);

    // Perform final iterations at full width (used for all float types).
    let r_u0: F::Int = F::Int::from(r2_u0) << (F::BITS - F::ISet2::BITS);
    let s_u2: F::Int = m_u2;
    let (_r_u0, s_u2) = goldschmidt::<F, F::Int>(r_u0, s_u2, F::FINAL_ROUNDS, true);

    // Shift back to mantissa position.
    let mut m = s_u2 >> (F::EXP_BITS - 2);

    // The musl source includes the following comment (with literals replaced):
    //
    // > s < sqrt(m) < s + 0x1.09p-SIG_BITS
    // > compute nearest rounded result: the nearest result to SIG_BITS bits is either s or
    // > s+0x1p-SIG_BITS, we can decide by comparing (2^SIG_BITS s + 0.5)^2 to 2^(2*SIG_BITS) m.
    //
    // Expanding this, with `SIG_BITS = p`, and adjusting based on the operations done to
    // `d0` and `d1`:
    //
    // - `2^(2p)m ≟ ((2^p)m + 0.5)^2`
    // - `2^(2p)m ≟ 2^(2p)m^2 + (2^p)m + 0.25`
    // - `2^(2p)m - m^2 ≟ (2^(2p) - 1)m^2 + (2^p)m + 0.25`
    // - `(1 - 2^(2p))m + m^2 ≟ (1 - 2^(2p))m^2 + (1 - 2^p)m + 0.25` (?)
    //
    // I do not follow how the rounding bit is extracted from this comparison with the below
    // operations. In any case, the algorithm is well tested.

    // The value needed to shift `m_u2` by to create `m*2^(2p)`. `2p = 2 * F::SIG_BITS`,
    // `F::BITS - 2` accounts for the offset that `m_u2` already has.
    let shift = 2 * F::SIG_BITS - (F::BITS - 2);

    // `2^(2p)m - m^2`
    let d0 = (m_u2 << shift).wrapping_sub(m.wrapping_mul(m));
    // `m - 2^(2p)m + m^2`
    let d1 = m.wrapping_sub(d0);
    // The top bit of `d1` is the rounding increment (0 or 1).
    m += d1 >> (F::BITS - 1);
    // Drop the implicit bit so only the stored significand remains.
    m &= F::SIG_MASK;

    // Merge the result exponent back in, from whichever representation it was kept in.
    match exp {
        Exp::Shifted(e) => m |= IntTy::<F>::cast_from(e) << F::SIG_BITS,
        Exp::NoShift(e) => m |= e,
    };

    let mut y = F::from_bits(m);

    // FIXME(f16): the fenv math does not work for `f16`
    if F::BITS > 16 {
        // Handle rounding and inexact. `(m + 1)^2 == 2^shift m` is exact; for all other cases, add
        // a tiny value to cause fenv effects.
        let d2 = d1.wrapping_add(m).wrapping_add(one);
        let mut tiny = if d2 == zero {
            cold_path();
            zero
        } else {
            F::IMPLICIT_BIT
        };

        // The sign of the tiny addend is taken from the top bit of `d1 ^ d2`.
        tiny |= (d1 ^ d2) & F::SIGN_MASK;
        let t = F::from_bits(tiny);
        y = y + t;
    }

    FpResult::ok(y)
}
233
234/// Multiply at the wider integer size, returning the high half.
235fn wmulh<I: HInt>(a: I, b: I) -> I {
236 a.widen_mul(b).hi()
237}
238
239/// Perform `count` goldschmidt iterations, returning `(r_u0, s_u?)`.
240///
241/// - `r_u0` is the reciprocal `r ~ 1 / sqrt(m)`, as U0.
242/// - `s_u2` is the square root, `s ~ sqrt(m)`, as U2.
243/// - `count` is the number of iterations to perform.
244/// - `final_set` should be true if this is the last round (same-sized integer). If so, the
245/// returned `s` will be U3, for later shifting. Otherwise, the returned `s` is U2.
246///
247/// Note that performance relies on the optimizer being able to unroll these loops (reasonably
248/// trivial, `count` is a constant when called).
249#[inline]
250fn goldschmidt<F, I>(mut r_u0: I, mut s_u2: I, count: u32, final_set: bool) -> (I, I)
251where
252 F: SqrtHelper,
253 I: HInt + From<u8>,
254{
255 let three_u2 = I::from(0b11u8) << (I::BITS - 2);
256 let mut u_u0 = r_u0;
257
258 for i in 0..count {
259 // First iteration: `s = m*r` (`u_u0 = r_u0` set above)
260 // Subsequent iterations: `s=s*u/2`
261 s_u2 = wmulh(s_u2, u_u0);
262
263 // Perform `s /= 2` if:
264 //
265 // 1. This is not the first iteration (the first iteration is `s = m*r`)...
266 // 2. ... and this is not the last set of iterations
267 // 3. ... or, if this is the last set, it is not the last iteration
268 //
269 // This step is not performed for the final iteration because the shift is combined with
270 // a later shift (moving `s` into the mantissa).
271 if i > 0 && (!final_set || i + 1 < count) {
272 s_u2 <<= 1;
273 }
274
275 // u = 3 - s*r
276 let d_u2 = wmulh(s_u2, r_u0);
277 u_u0 = three_u2.wrapping_sub(d_u2);
278
279 // r = r*u/2
280 r_u0 = wmulh(r_u0, u_u0) << 1;
281 }
282
283 (r_u0, s_u2)
284}
285
/// Representation of whether we shift the exponent into a `u32`, or modify it in place to save
/// the shift operations.
///
/// `T` is the payload of the unshifted form. `sqrt_round` uses `()` while the exponent still
/// lives inside the input bits, and the float's integer type once it carries the exponent of the
/// result.
enum Exp<T> {
    /// The exponent has been shifted to a `u32` and is LSB-aligned.
    Shifted(u32),
    /// The exponent is in its natural position in integer repr.
    NoShift(T),
}
294
/// Size-specific constants related to the square root routine.
///
/// `sqrt_round` iterates at up to three integer widths in sequence: `ISet1`, then `ISet2`, then
/// `Self::Int`. An implementation picks the two narrower integer types and the number of rounds
/// to run at each width. The round counts for the narrower sets default to zero, so only the
/// widths that are actually used need to be overridden.
pub trait SqrtHelper: Float {
    /// Integer for the first set of rounds. If unused, set to the same type as the next set.
    type ISet1: HInt + Into<Self::ISet2> + CastFrom<Self::Int> + From<u8>;
    /// Integer for the second set of rounds. If unused, set to the same type as the next set.
    type ISet2: HInt + From<Self::ISet1> + From<u8>;

    /// Number of rounds at `ISet1`.
    const SET1_ROUNDS: u32 = 0;
    /// Number of rounds at `ISet2`.
    const SET2_ROUNDS: u32 = 0;
    /// Number of rounds at `Self::Int`.
    const FINAL_ROUNDS: u32;
}
309
// `f16`: the narrower sets keep their default of zero rounds; two rounds run at full width.
#[cfg(f16_enabled)]
impl SqrtHelper for f16 {
    type ISet1 = u16; // unused
    type ISet2 = u16; // unused

    const FINAL_ROUNDS: u32 = 2;
}
317
// `f32`: the narrower sets keep their default of zero rounds; three rounds run at full width.
impl SqrtHelper for f32 {
    type ISet1 = u32; // unused
    type ISet2 = u32; // unused

    const FINAL_ROUNDS: u32 = 3;
}
324
// `f64`: no rounds at `ISet1` (default), two rounds at `u32`, then two rounds at full width.
impl SqrtHelper for f64 {
    type ISet1 = u32; // unused
    type ISet2 = u32;

    const SET2_ROUNDS: u32 = 2;
    const FINAL_ROUNDS: u32 = 2;
}
332
// `f128`: all three widths are used. One round at `u32`, two at `u64`, two at full width.
#[cfg(f128_enabled)]
impl SqrtHelper for f128 {
    type ISet1 = u32;
    type ISet2 = u64;

    const SET1_ROUNDS: u32 = 1;
    const SET2_ROUNDS: u32 = 2;
    const FINAL_ROUNDS: u32 = 2;
}
342
/// A U0.16 representation of `1/sqrt(x)`.
///
/// The index is a 7-bit number consisting of a single exponent bit and 6 bits of significand.
///
/// `sqrt_round` places the looked-up entry in the top 16 bits of its first working integer as
/// the initial estimate for `r`. The first 64 entries (exponent bit clear) descend from just
/// above `1/sqrt(2)` towards `1/2`, and the last 64 entries (exponent bit set) descend from just
/// below `1` towards `1/sqrt(2)`.
#[rustfmt::skip]
static RSQRT_TAB: [u16; 128] = [
    0xb451, 0xb2f0, 0xb196, 0xb044, 0xaef9, 0xadb6, 0xac79, 0xab43,
    0xaa14, 0xa8eb, 0xa7c8, 0xa6aa, 0xa592, 0xa480, 0xa373, 0xa26b,
    0xa168, 0xa06a, 0x9f70, 0x9e7b, 0x9d8a, 0x9c9d, 0x9bb5, 0x9ad1,
    0x99f0, 0x9913, 0x983a, 0x9765, 0x9693, 0x95c4, 0x94f8, 0x9430,
    0x936b, 0x92a9, 0x91ea, 0x912e, 0x9075, 0x8fbe, 0x8f0a, 0x8e59,
    0x8daa, 0x8cfe, 0x8c54, 0x8bac, 0x8b07, 0x8a64, 0x89c4, 0x8925,
    0x8889, 0x87ee, 0x8756, 0x86c0, 0x862b, 0x8599, 0x8508, 0x8479,
    0x83ec, 0x8361, 0x82d8, 0x8250, 0x81c9, 0x8145, 0x80c2, 0x8040,
    0xff02, 0xfd0e, 0xfb25, 0xf947, 0xf773, 0xf5aa, 0xf3ea, 0xf234,
    0xf087, 0xeee3, 0xed47, 0xebb3, 0xea27, 0xe8a3, 0xe727, 0xe5b2,
    0xe443, 0xe2dc, 0xe17a, 0xe020, 0xdecb, 0xdd7d, 0xdc34, 0xdaf1,
    0xd9b3, 0xd87b, 0xd748, 0xd61a, 0xd4f1, 0xd3cd, 0xd2ad, 0xd192,
    0xd07b, 0xcf69, 0xce5b, 0xcd51, 0xcc4a, 0xcb48, 0xca4a, 0xc94f,
    0xc858, 0xc764, 0xc674, 0xc587, 0xc49d, 0xc3b7, 0xc2d4, 0xc1f4,
    0xc116, 0xc03c, 0xbf65, 0xbe90, 0xbdbe, 0xbcef, 0xbc23, 0xbb59,
    0xba91, 0xb9cc, 0xb90a, 0xb84a, 0xb78c, 0xb6d0, 0xb617, 0xb560,
];
365
// Unit tests. Each float type gets the same three checks: a sanity check on exact squares, the
// generic special-value test, and a table of inputs with their expected result bit patterns.
#[cfg(test)]
mod tests {
    use super::*;

    /// Test behavior specified in IEEE 754 `squareRoot`.
    ///
    /// Generic over the float type so that every supported width runs the same checks through
    /// `sqrt_round`, which exposes the status alongside the value.
    fn spec_test<F>()
    where
        F: Float + SqrtHelper,
        F::Int: HInt,
        F::Int: From<u8>,
        F::Int: From<F::ISet2>,
        F::Int: CastInto<F::ISet1>,
        F::Int: CastInto<F::ISet2>,
        u32: CastInto<F::Int>,
    {
        // Values that should return a NaN and raise invalid
        let nan = [F::NEG_INFINITY, F::NEG_ONE, F::NAN, F::MIN];

        // Values that return unaltered
        let roundtrip = [F::ZERO, F::NEG_ZERO, F::INFINITY];

        for x in nan {
            let FpResult { val, status } = sqrt_round(x, Round::Nearest);
            assert!(val.is_nan());
            assert!(status == Status::INVALID);
        }

        for x in roundtrip {
            let FpResult { val, status } = sqrt_round(x, Round::Nearest);
            // Bitwise comparison so that the sign of zero is checked as well.
            assert_biteq!(val, x);
            assert!(status == Status::OK);
        }
    }

    // Exact squares must produce exact roots.
    #[test]
    #[cfg(f16_enabled)]
    fn sanity_check_f16() {
        assert_biteq!(sqrt(100.0f16), 10.0);
        assert_biteq!(sqrt(4.0f16), 2.0);
    }

    #[test]
    #[cfg(f16_enabled)]
    fn spec_tests_f16() {
        spec_test::<f16>();
    }

    // Each case pairs an input with the expected bit pattern of its square root.
    #[test]
    #[cfg(f16_enabled)]
    #[allow(clippy::approx_constant)]
    fn conformance_tests_f16() {
        let cases = [
            (f16::PI, 0x3f17_u16),
            // 10_000.0, using a hex literal for MSRV hack (Rust < 1.67 checks literal widths as
            // part of the AST, so the `cfg` is irrelevant here).
            (f16::from_bits(0x70e2), 0x5640_u16),
            // Subnormal input
            (f16::from_bits(0x0000000f), 0x13bf_u16),
            (f16::INFINITY, f16::INFINITY.to_bits()),
        ];

        for (input, output) in cases {
            assert_biteq!(
                sqrt(input),
                f16::from_bits(output),
                "input: {input:?} ({:#018x})",
                input.to_bits()
            );
        }
    }

    // Exact squares must produce exact roots.
    #[test]
    fn sanity_check_f32() {
        assert_biteq!(sqrt(100.0f32), 10.0);
        assert_biteq!(sqrt(4.0f32), 2.0);
    }

    #[test]
    fn spec_tests_f32() {
        spec_test::<f32>();
    }

    // Each case pairs an input with the expected bit pattern of its square root.
    #[test]
    #[allow(clippy::approx_constant)]
    fn conformance_tests_f32() {
        let cases = [
            (f32::PI, 0x3fe2dfc5_u32),
            (10000.0f32, 0x42c80000_u32),
            // Subnormal input
            (f32::from_bits(0x0000000f), 0x1b2f456f_u32),
            (f32::INFINITY, f32::INFINITY.to_bits()),
        ];

        for (input, output) in cases {
            assert_biteq!(
                sqrt(input),
                f32::from_bits(output),
                "input: {input:?} ({:#018x})",
                input.to_bits()
            );
        }
    }

    // Exact squares must produce exact roots.
    #[test]
    fn sanity_check_f64() {
        assert_biteq!(sqrt(100.0f64), 10.0);
        assert_biteq!(sqrt(4.0f64), 2.0);
    }

    #[test]
    fn spec_tests_f64() {
        spec_test::<f64>();
    }

    // Each case pairs an input with the expected bit pattern of its square root.
    #[test]
    #[allow(clippy::approx_constant)]
    fn conformance_tests_f64() {
        let cases = [
            (f64::PI, 0x3ffc5bf891b4ef6a_u64),
            (10000.0, 0x4059000000000000_u64),
            // Subnormal input
            (f64::from_bits(0x0000000f), 0x1e7efbdeb14f4eda_u64),
            (f64::INFINITY, f64::INFINITY.to_bits()),
        ];

        for (input, output) in cases {
            assert_biteq!(
                sqrt(input),
                f64::from_bits(output),
                "input: {input:?} ({:#018x})",
                input.to_bits()
            );
        }
    }

    // Exact squares must produce exact roots.
    #[test]
    #[cfg(f128_enabled)]
    fn sanity_check_f128() {
        assert_biteq!(sqrt(100.0f128), 10.0);
        assert_biteq!(sqrt(4.0f128), 2.0);
    }

    #[test]
    #[cfg(f128_enabled)]
    fn spec_tests_f128() {
        spec_test::<f128>();
    }

    // Each case pairs an input with the expected bit pattern of its square root.
    #[test]
    #[cfg(f128_enabled)]
    #[allow(clippy::approx_constant)]
    fn conformance_tests_f128() {
        let cases = [
            (f128::PI, 0x3fffc5bf891b4ef6aa79c3b0520d5db9_u128),
            // 10_000.0, see `f16` for reasoning.
            (
                f128::from_bits(0x400c3880000000000000000000000000),
                0x40059000000000000000000000000000_u128,
            ),
            // Subnormal input
            (f128::from_bits(0x0000000f), 0x1fc9efbdeb14f4ed9b17ae807907e1e9_u128),
            (f128::INFINITY, f128::INFINITY.to_bits()),
        ];

        for (input, output) in cases {
            assert_biteq!(
                sqrt(input),
                f128::from_bits(output),
                "input: {input:?} ({:#018x})",
                input.to_bits()
            );
        }
    }
}
536