1 | //! These functions use the [Karatsuba square root algorithm][1] to compute the |
2 | //! [integer square root](https://en.wikipedia.org/wiki/Integer_square_root) |
3 | //! for the primitive integer types. |
4 | //! |
//! The signed integer functions can only handle **nonnegative** inputs, so
//! callers must check the sign before calling them.
7 | //! |
8 | //! [1]: <https://web.archive.org/web/20230511212802/https://inria.hal.science/inria-00072854v1/file/RR-3805.pdf> |
9 | //! "Paul Zimmermann. Karatsuba Square Root. \[Research Report\] RR-3805, |
10 | //! INRIA. 1999, pp.8. (inria-00072854)" |
11 | |
/// This array stores the [integer square roots](
/// https://en.wikipedia.org/wiki/Integer_square_root) and remainders of each
/// [`u8`](prim@u8) value. For example, `U8_ISQRT_WITH_REMAINDER[17]` will be
/// `(4, 1)` because the integer square root of 17 is 4 and because 17 is 1
/// higher than 4 squared.
const U8_ISQRT_WITH_REMAINDER: [(u8, u8); 256] = {
    let mut result: [(u8, u8); 256] = [(0, 0); 256];

    let mut n: usize = 0;
    let mut isqrt_n: usize = 0;
    while n < result.len() {
        // The remainder `n - isqrt_n²` is at most `2 * isqrt_n` (≤ 30 here),
        // and `isqrt_n` is at most 15, so neither cast can truncate.
        result[n] = (isqrt_n as u8, (n - isqrt_n.pow(2)) as u8);

        n += 1;
        // Step the root forward as soon as `n` reaches the next perfect square.
        if n == (isqrt_n + 1).pow(2) {
            isqrt_n += 1;
        }
    }

    result
};
33 | |
34 | /// Returns the [integer square root]( |
35 | /// https://en.wikipedia.org/wiki/Integer_square_root) of any [`u8`](prim@u8) |
36 | /// input. |
37 | #[must_use = "this returns the result of the operation, \ |
38 | without modifying the original" ] |
39 | #[inline ] |
40 | pub(super) const fn u8(n: u8) -> u8 { |
41 | U8_ISQRT_WITH_REMAINDER[n as usize].0 |
42 | } |
43 | |
/// Generates an `i*` function that returns the [integer square root](
/// https://en.wikipedia.org/wiki/Integer_square_root) of any **nonnegative**
/// input of a specific signed integer type.
///
/// `$SignedT` names both the signed type and the generated function;
/// `$UnsignedT` is the same-width unsigned type, whose function (named after
/// that type) performs the actual computation.
macro_rules! signed_fn {
    ($SignedT:ident, $UnsignedT:ident) => {
        /// Returns the [integer square root](
        /// https://en.wikipedia.org/wiki/Integer_square_root) of any
        /// **nonnegative**
        #[doc = concat!("[`", stringify!($SignedT), "`](prim@", stringify!($SignedT), ")")]
        /// input.
        ///
        /// # Safety
        ///
        /// This results in undefined behavior when the input is negative.
        #[must_use = "this returns the result of the operation, \
                      without modifying the original"]
        #[inline]
        pub(super) const unsafe fn $SignedT(n: $SignedT) -> $SignedT {
            debug_assert!(n >= 0, "Negative input inside `isqrt`.");
            // A nonnegative value keeps its numeric value when reinterpreted as
            // the same-width unsigned type, and the square root of any N-bit
            // unsigned value fits in N/2 bits, so converting back is lossless.
            let as_unsigned = n as $UnsignedT;
            let root = $UnsignedT(as_unsigned);
            root as $SignedT
        }
    };
}
67 | |
68 | signed_fn!(i8, u8); |
69 | signed_fn!(i16, u16); |
70 | signed_fn!(i32, u32); |
71 | signed_fn!(i64, u64); |
72 | signed_fn!(i128, u128); |
73 | |
/// Generates a `u*` function that returns the [integer square root](
/// https://en.wikipedia.org/wiki/Integer_square_root) of any input of
/// a specific unsigned integer type.
///
/// `$UnsignedT` names both the type and the generated function, `$HalfBitsT`
/// is the half-width unsigned type whose function handles small inputs, and
/// `$stages` computes the square root of a normalized `$UnsignedT`.
macro_rules! unsigned_fn {
    ($UnsignedT:ident, $HalfBitsT:ident, $stages:ident) => {
        /// Returns the [integer square root](
        /// https://en.wikipedia.org/wiki/Integer_square_root) of any
        #[doc = concat!("[`", stringify!($UnsignedT), "`](prim@", stringify!($UnsignedT), ")")]
        /// input.
        #[must_use = "this returns the result of the operation, \
                      without modifying the original"]
        #[inline]
        pub(super) const fn $UnsignedT(n: $UnsignedT) -> $UnsignedT {
            // Inputs that fit in the half-width type (zero included) are
            // handed to that type's function.
            if n <= <$HalfBitsT>::MAX as $UnsignedT {
                return $HalfBitsT(n as $HalfBitsT) as $UnsignedT;
            }

            // Normalize: shift `n` left so that its top bit or the bit below
            // it is a 1. This meets the Karatsuba square root precondition
            // "a₃ ≥ b/4", where a₃ is the most significant quarter of `n`'s bits
            // and b is the number of values that quarter can represent. In
            // binary b/4 is 010...0, so a₃ ≥ b/4 exactly when one of a₃'s two
            // most significant bits (which are `n`'s two most significant bits)
            // is a 1.
            //
            // The shift is rounded down to an even amount `2 * p` because
            //
            //     sqrt(n << (2 * p)) = sqrt(2.pow(2 * p) * n) = sqrt(n) << p,
            //
            // whereas an odd shift would leave an unwanted factor of sqrt(2):
            //
            //     sqrt(n << (2 * p + 1)) = sqrt(2) * (sqrt(n) << p).
            let shift = n.leading_zeros() & !1;
            let normalized_root = $stages(n << shift);

            // Undo the normalization: shifting the input by `2 * p` bits
            // shifted the root by `p` bits.
            normalized_root >> (shift / 2)
        }
    };
}
132 | |
/// Generates the first stage of the computation after normalization.
///
/// Looks up the eight most-significant bits of the `$original_bits`-bit
/// normalized input `$n` in `U8_ISQRT_WITH_REMAINDER` and evaluates to the
/// resulting `(u8, u8)` square root and remainder.
///
/// # Safety
///
/// `$n` must be normalized: at least one of its two most-significant bits must
/// be a 1. Being nonzero is not enough. If the top eight bits of `$n` are all
/// zero, the looked-up root is zero, and the `assert_unchecked` below is then
/// undefined behavior.
macro_rules! first_stage {
    ($original_bits:literal, $n:ident) => {{
        debug_assert!($n != 0, "`$n` is zero in `first_stage!`.");

        // Keep only the eight most-significant bits of `$n`.
        const N_SHIFT: u32 = $original_bits - 8;
        let n = $n >> N_SHIFT;
        // This is the condition the `assert_unchecked` below actually relies on.
        debug_assert!(n != 0, "the top 8 bits of `$n` are zero in `first_stage!`.");

        let (s, r) = U8_ISQRT_WITH_REMAINDER[n as usize];

        // Inform the optimizer that `s` is nonzero. This will allow it to
        // avoid generating code to handle division-by-zero panics in the next
        // stage.
        //
        // SAFETY: The `unsigned_fn` macro sends every input that fits in the
        // half-width type (zero included) to the half-width function instead
        // of continuing to this point, so the original `$n` is nonzero here.
        //
        // The `unsigned_fn` macro then normalizes `$n` so that at least one of
        // its two most-significant bits is a 1.
        //
        // This stage puts the eight most-significant bits of `$n` into `n`, so
        // `n` also has a 1 in one of its two most-significant bits and is
        // therefore nonzero.
        //
        // `U8_ISQRT_WITH_REMAINDER[n as usize]` gives a nonzero `s` for any
        // nonzero `n`.
        unsafe { crate::hint::assert_unchecked(s != 0) };
        (s, r)
    }};
}
168 | |
/// Generates a middle stage of the computation.
///
/// The previous stage's square root `$s` and remainder `$r` cover the upper
/// half of the `<$ty>::BITS` most-significant bits of the normalized input
/// `$n`. This stage extends them to cover all of those bits. It evaluates to
/// the new `($ty, $ty)` square root and remainder, which the next (wider)
/// stage consumes.
///
/// # Safety
///
/// `$s` must be nonzero.
macro_rules! middle_stage {
    ($original_bits:literal, $ty:ty, $n:ident, $s:ident, $r:ident) => {{
        debug_assert!($s != 0, "`$s` is zero in `middle_stage!`.");

        // Work on the `<$ty>::BITS` most-significant bits of the normalized input.
        const N_SHIFT: u32 = $original_bits - <$ty>::BITS;
        let n = ($n >> N_SHIFT) as $ty;

        const HALF_BITS: u32 = <$ty>::BITS >> 1;
        const QUARTER_BITS: u32 = <$ty>::BITS >> 2;
        const LOWER_HALF_1_BITS: $ty = (1 << HALF_BITS) - 1;
        const LOWEST_QUARTER_1_BITS: $ty = (1 << QUARTER_BITS) - 1;

        // `lo` is the lower half of `n`, i.e. the bits the previous stage did
        // not see. Extend the previous remainder with the upper quarter of
        // `lo`, then divide by `2 * $s`. The quotient `q` is the next quarter
        // of the root and `u` is what is left over. `$s` is nonzero, so the
        // division cannot panic.
        let lo = n & LOWER_HALF_1_BITS;
        let numerator = (($r as $ty) << QUARTER_BITS) | (lo >> QUARTER_BITS);
        let denominator = ($s as $ty) << 1;
        let q = numerator / denominator;
        let u = numerator % denominator;

        // The candidate root is `$s` with `q` appended. `$s << QUARTER_BITS`
        // is evaluated in `$s`'s narrower type before the cast; it fits there
        // because `$s` is the root of a value with half as many bits.
        //
        // The candidate remainder is `u` extended with the lowest quarter of
        // `lo`, minus `q²`. If that subtraction underflows, the candidate root
        // is one too large. Since `(s - 1)² = s² - (2s - 1)`, add `2s - 1`
        // back to the remainder and decrement the root.
        let mut s = ($s << QUARTER_BITS) as $ty + q;
        let (mut r, overflow) =
            ((u << QUARTER_BITS) | (lo & LOWEST_QUARTER_1_BITS)).overflowing_sub(q * q);
        if overflow {
            r = r.wrapping_add(2 * s - 1);
            s -= 1;
        }

        // Inform the optimizer that `s` is nonzero. This will allow it to
        // avoid generating code to handle division-by-zero panics in the next
        // stage.
        //
        // SAFETY: The `unsigned_fn` macro sends every input that fits in the
        // half-width type (zero included) to the half-width function instead
        // of continuing to this point, so the original `$n` is nonzero here.
        //
        // The `unsigned_fn` macro then normalizes `$n` so that at least one of
        // its two most-significant bits is a 1.
        //
        // Each stage takes as many of the most-significant bits of `$n` as will
        // fit in that stage's type. For example, the stage that handles `u32`
        // deals with the 32 most-significant bits of `$n`. So every stage's `n`
        // has a 1 in one of its two most-significant bits, which makes `n`
        // nonzero.
        //
        // This stage produces the correct integer square root for that `n`.
        // Since `n` is nonzero, `s` is nonzero as well.
        unsafe { crate::hint::assert_unchecked(s != 0) };
        (s, r)
    }};
}
223 | |
/// Generates the last stage of the computation before denormalization.
///
/// Like a middle stage, this appends one more quarter of the bits to the
/// previous stage's square root `$s`, using the remainder `$r`. It works
/// directly on the full normalized input `$n` of type `$ty`. It evaluates to
/// the square root only; no final remainder is computed.
///
/// # Safety
///
/// `$s` must be nonzero.
macro_rules! last_stage {
    ($ty:ty, $n:ident, $s:ident, $r:ident) => {{
        debug_assert!($s != 0, "`$s` is zero in `last_stage!`.");

        const HALF_BITS: u32 = <$ty>::BITS >> 1;
        const QUARTER_BITS: u32 = <$ty>::BITS >> 2;
        const LOWER_HALF_1_BITS: $ty = (1 << HALF_BITS) - 1;

        // Same quotient step as in a middle stage: extend the previous
        // remainder with the upper quarter of the lower half of `$n`, then
        // divide by `2 * $s`.
        let lo = $n & LOWER_HALF_1_BITS;
        let numerator = (($r as $ty) << QUARTER_BITS) | (lo >> QUARTER_BITS);
        let denominator = ($s as $ty) << 1;

        // `$s` is nonzero, so this division cannot panic.
        let q = numerator / denominator;
        let mut s = ($s << QUARTER_BITS) as $ty + q;
        // Instead of tracking a remainder, check the candidate directly. If
        // `s²` overflows `$ty` or exceeds `$n`, the candidate is too large.
        // A single decrement fixes it: the code relies on the Karatsuba square
        // root algorithm's estimate being at most one too large.
        let (s_squared, overflow) = s.overflowing_mul(s);
        if overflow || s_squared > $n {
            s -= 1;
        }
        s
    }};
}
250 | |
251 | /// Takes the normalized [`u16`](prim@u16) input and gets its normalized |
252 | /// [integer square root](https://en.wikipedia.org/wiki/Integer_square_root). |
253 | /// |
254 | /// # Safety |
255 | /// |
256 | /// `n` must be nonzero. |
257 | #[inline ] |
258 | const fn u16_stages(n: u16) -> u16 { |
259 | let (s: u8, r: u8) = first_stage!(16, n); |
260 | last_stage!(u16, n, s, r) |
261 | } |
262 | |
263 | /// Takes the normalized [`u32`](prim@u32) input and gets its normalized |
264 | /// [integer square root](https://en.wikipedia.org/wiki/Integer_square_root). |
265 | /// |
266 | /// # Safety |
267 | /// |
268 | /// `n` must be nonzero. |
269 | #[inline ] |
270 | const fn u32_stages(n: u32) -> u32 { |
271 | let (s: u8, r: u8) = first_stage!(32, n); |
272 | let (s: u16, r: u16) = middle_stage!(32, u16, n, s, r); |
273 | last_stage!(u32, n, s, r) |
274 | } |
275 | |
276 | /// Takes the normalized [`u64`](prim@u64) input and gets its normalized |
277 | /// [integer square root](https://en.wikipedia.org/wiki/Integer_square_root). |
278 | /// |
279 | /// # Safety |
280 | /// |
281 | /// `n` must be nonzero. |
282 | #[inline ] |
283 | const fn u64_stages(n: u64) -> u64 { |
284 | let (s: u8, r: u8) = first_stage!(64, n); |
285 | let (s: u16, r: u16) = middle_stage!(64, u16, n, s, r); |
286 | let (s: u32, r: u32) = middle_stage!(64, u32, n, s, r); |
287 | last_stage!(u64, n, s, r) |
288 | } |
289 | |
290 | /// Takes the normalized [`u128`](prim@u128) input and gets its normalized |
291 | /// [integer square root](https://en.wikipedia.org/wiki/Integer_square_root). |
292 | /// |
293 | /// # Safety |
294 | /// |
295 | /// `n` must be nonzero. |
296 | #[inline ] |
297 | const fn u128_stages(n: u128) -> u128 { |
298 | let (s: u8, r: u8) = first_stage!(128, n); |
299 | let (s: u16, r: u16) = middle_stage!(128, u16, n, s, r); |
300 | let (s: u32, r: u32) = middle_stage!(128, u32, n, s, r); |
301 | let (s: u64, r: u64) = middle_stage!(128, u64, n, s, r); |
302 | last_stage!(u128, n, s, r) |
303 | } |
304 | |
305 | unsigned_fn!(u16, u8, u16_stages); |
306 | unsigned_fn!(u32, u16, u32_stages); |
307 | unsigned_fn!(u64, u32, u64_stages); |
308 | unsigned_fn!(u128, u64, u128_stages); |
309 | |
310 | /// Instantiate this panic logic once, rather than for all the isqrt methods |
311 | /// on every single primitive type. |
312 | #[cold ] |
313 | #[track_caller ] |
314 | pub(super) const fn panic_for_negative_argument() -> ! { |
315 | panic!("argument of integer square root cannot be negative" ) |
316 | } |
317 | |