1 | use crate::Integer; |
2 | use core::mem; |
3 | use num_traits::{checked_pow, PrimInt}; |
4 | |
5 | /// Provides methods to compute an integer's square root, cube root, |
6 | /// and arbitrary `n`th root. |
7 | pub trait Roots: Integer { |
8 | /// Returns the truncated principal `n`th root of an integer |
9 | /// -- `if x >= 0 { ⌊ⁿ√x⌋ } else { ⌈ⁿ√x⌉ }` |
10 | /// |
11 | /// This is solving for `r` in `rⁿ = x`, rounding toward zero. |
12 | /// If `x` is positive, the result will satisfy `rⁿ ≤ x < (r+1)ⁿ`. |
13 | /// If `x` is negative and `n` is odd, then `(r-1)ⁿ < x ≤ rⁿ`. |
14 | /// |
15 | /// # Panics |
16 | /// |
17 | /// Panics if `n` is zero: |
18 | /// |
19 | /// ```should_panic |
20 | /// # use num_integer::Roots; |
21 | /// println!("can't compute ⁰√x : {}" , 123.nth_root(0)); |
22 | /// ``` |
23 | /// |
24 | /// or if `n` is even and `self` is negative: |
25 | /// |
26 | /// ```should_panic |
27 | /// # use num_integer::Roots; |
28 | /// println!("no imaginary numbers... {}" , (-1).nth_root(10)); |
29 | /// ``` |
30 | /// |
31 | /// # Examples |
32 | /// |
33 | /// ``` |
34 | /// use num_integer::Roots; |
35 | /// |
36 | /// let x: i32 = 12345; |
37 | /// assert_eq!(x.nth_root(1), x); |
38 | /// assert_eq!(x.nth_root(2), x.sqrt()); |
39 | /// assert_eq!(x.nth_root(3), x.cbrt()); |
40 | /// assert_eq!(x.nth_root(4), 10); |
41 | /// assert_eq!(x.nth_root(13), 2); |
42 | /// assert_eq!(x.nth_root(14), 1); |
43 | /// assert_eq!(x.nth_root(std::u32::MAX), 1); |
44 | /// |
45 | /// assert_eq!(std::i32::MAX.nth_root(30), 2); |
46 | /// assert_eq!(std::i32::MAX.nth_root(31), 1); |
47 | /// assert_eq!(std::i32::MIN.nth_root(31), -2); |
48 | /// assert_eq!((std::i32::MIN + 1).nth_root(31), -1); |
49 | /// |
50 | /// assert_eq!(std::u32::MAX.nth_root(31), 2); |
51 | /// assert_eq!(std::u32::MAX.nth_root(32), 1); |
52 | /// ``` |
53 | fn nth_root(&self, n: u32) -> Self; |
54 | |
55 | /// Returns the truncated principal square root of an integer -- `⌊√x⌋` |
56 | /// |
57 | /// This is solving for `r` in `r² = x`, rounding toward zero. |
58 | /// The result will satisfy `r² ≤ x < (r+1)²`. |
59 | /// |
60 | /// # Panics |
61 | /// |
62 | /// Panics if `self` is less than zero: |
63 | /// |
64 | /// ```should_panic |
65 | /// # use num_integer::Roots; |
66 | /// println!("no imaginary numbers... {}" , (-1).sqrt()); |
67 | /// ``` |
68 | /// |
69 | /// # Examples |
70 | /// |
71 | /// ``` |
72 | /// use num_integer::Roots; |
73 | /// |
74 | /// let x: i32 = 12345; |
75 | /// assert_eq!((x * x).sqrt(), x); |
76 | /// assert_eq!((x * x + 1).sqrt(), x); |
77 | /// assert_eq!((x * x - 1).sqrt(), x - 1); |
78 | /// ``` |
79 | #[inline ] |
80 | fn sqrt(&self) -> Self { |
81 | self.nth_root(2) |
82 | } |
83 | |
84 | /// Returns the truncated principal cube root of an integer -- |
85 | /// `if x >= 0 { ⌊∛x⌋ } else { ⌈∛x⌉ }` |
86 | /// |
87 | /// This is solving for `r` in `r³ = x`, rounding toward zero. |
88 | /// If `x` is positive, the result will satisfy `r³ ≤ x < (r+1)³`. |
89 | /// If `x` is negative, then `(r-1)³ < x ≤ r³`. |
90 | /// |
91 | /// # Examples |
92 | /// |
93 | /// ``` |
94 | /// use num_integer::Roots; |
95 | /// |
96 | /// let x: i32 = 1234; |
97 | /// assert_eq!((x * x * x).cbrt(), x); |
98 | /// assert_eq!((x * x * x + 1).cbrt(), x); |
99 | /// assert_eq!((x * x * x - 1).cbrt(), x - 1); |
100 | /// |
101 | /// assert_eq!((-(x * x * x)).cbrt(), -x); |
102 | /// assert_eq!((-(x * x * x + 1)).cbrt(), -x); |
103 | /// assert_eq!((-(x * x * x - 1)).cbrt(), -(x - 1)); |
104 | /// ``` |
105 | #[inline ] |
106 | fn cbrt(&self) -> Self { |
107 | self.nth_root(3) |
108 | } |
109 | } |
110 | |
111 | /// Returns the truncated principal square root of an integer -- |
112 | /// see [Roots::sqrt](trait.Roots.html#method.sqrt). |
113 | #[inline ] |
114 | pub fn sqrt<T: Roots>(x: T) -> T { |
115 | x.sqrt() |
116 | } |
117 | |
118 | /// Returns the truncated principal cube root of an integer -- |
119 | /// see [Roots::cbrt](trait.Roots.html#method.cbrt). |
120 | #[inline ] |
121 | pub fn cbrt<T: Roots>(x: T) -> T { |
122 | x.cbrt() |
123 | } |
124 | |
125 | /// Returns the truncated principal `n`th root of an integer -- |
126 | /// see [Roots::nth_root](trait.Roots.html#tymethod.nth_root). |
127 | #[inline ] |
128 | pub fn nth_root<T: Roots>(x: T, n: u32) -> T { |
129 | x.nth_root(n) |
130 | } |
131 | |
132 | macro_rules! signed_roots { |
133 | ($T:ty, $U:ty) => { |
134 | impl Roots for $T { |
135 | #[inline] |
136 | fn nth_root(&self, n: u32) -> Self { |
137 | if *self >= 0 { |
138 | (*self as $U).nth_root(n) as Self |
139 | } else { |
140 | assert!(n.is_odd(), "even roots of a negative are imaginary" ); |
141 | -((self.wrapping_neg() as $U).nth_root(n) as Self) |
142 | } |
143 | } |
144 | |
145 | #[inline] |
146 | fn sqrt(&self) -> Self { |
147 | assert!(*self >= 0, "the square root of a negative is imaginary" ); |
148 | (*self as $U).sqrt() as Self |
149 | } |
150 | |
151 | #[inline] |
152 | fn cbrt(&self) -> Self { |
153 | if *self >= 0 { |
154 | (*self as $U).cbrt() as Self |
155 | } else { |
156 | -((self.wrapping_neg() as $U).cbrt() as Self) |
157 | } |
158 | } |
159 | } |
160 | }; |
161 | } |
162 | |
163 | signed_roots!(i8, u8); |
164 | signed_roots!(i16, u16); |
165 | signed_roots!(i32, u32); |
166 | signed_roots!(i64, u64); |
167 | signed_roots!(i128, u128); |
168 | signed_roots!(isize, usize); |
169 | |
170 | #[inline ] |
171 | fn fixpoint<T, F>(mut x: T, f: F) -> T |
172 | where |
173 | T: Integer + Copy, |
174 | F: Fn(T) -> T, |
175 | { |
176 | let mut xn: T = f(x); |
177 | while x < xn { |
178 | x = xn; |
179 | xn = f(x); |
180 | } |
181 | while x > xn { |
182 | x = xn; |
183 | xn = f(x); |
184 | } |
185 | x |
186 | } |
187 | |
/// Width of `T` in bits, derived from its size in bytes.
#[inline]
fn bits<T>() -> u32 {
    let bytes = mem::size_of::<T>() as u32;
    bytes * 8
}
192 | |
193 | #[inline ] |
194 | fn log2<T: PrimInt>(x: T) -> u32 { |
195 | debug_assert!(x > T::zero()); |
196 | bits::<T>() - 1 - x.leading_zeros() |
197 | } |
198 | |
// Implements `Roots` for an unsigned primitive integer type `$T`.
//
// Each method wraps an inner `go` function that works on the dereferenced
// value. For types wider than 64 bits, each step strips one bit off the result
// (by shifting the operand right) until the operand fits in `u64`, then
// delegates to the `u64` implementation. Narrower types start from an initial
// `guess` and refine it with `fixpoint` (Newton-style iteration). The exception
// is `cbrt` on types of at most 32 bits, which uses a bitwise method.
macro_rules! unsigned_roots {
    ($T:ident) => {
        impl Roots for $T {
            #[inline]
            fn nth_root(&self, n: u32) -> Self {
                fn go(a: $T, n: u32) -> $T {
                    // Specialize small roots
                    match n {
                        0 => panic!("can't find a root of degree 0!"),
                        1 => return a,
                        2 => return a.sqrt(),
                        3 => return a.cbrt(),
                        _ => (),
                    }

                    // The root of values less than 2ⁿ can only be 0 or 1.
                    // The `bits <= n` test comes first so that `1 << n` is only
                    // evaluated when `n` is smaller than the width of `$T`.
                    if bits::<$T>() <= n || a < (1 << n) {
                        return (a > 0) as $T;
                    }

                    if bits::<$T>() > 64 {
                        // 128-bit division is slow, so do a bitwise `nth_root` until it's small enough.
                        // With r = nth_root(a >> n), the root of `a` is either 2r or
                        // 2r + 1, so only `hi = 2r + 1` has to be tested.
                        return if a <= core::u64::MAX as $T {
                            (a as u64).nth_root(n) as $T
                        } else {
                            let lo = (a >> n).nth_root(n) << 1;
                            let hi = lo + 1;
                            // 128-bit `checked_mul` also involves division, but we can't always
                            // compute `hiⁿ` without risking overflow. Try to avoid it though...
                            // With 2^k = hi.next_power_of_two() we have hi ≤ 2^k, so when
                            // k * n < bits, hiⁿ ≤ 2^(k*n) fits and plain `pow` is safe.
                            if hi.next_power_of_two().trailing_zeros() * n >= bits::<$T>() {
                                match checked_pow(hi, n as usize) {
                                    Some(x) if x <= a => hi,
                                    _ => lo,
                                }
                            } else {
                                if hi.pow(n) <= a {
                                    hi
                                } else {
                                    lo
                                }
                            }
                        };
                    }

                    // Initial estimate for the iteration below. Here a ≥ 2ⁿ, so
                    // `log2` receives a positive value.
                    #[cfg(feature = "std")]
                    #[inline]
                    fn guess(x: $T, n: u32) -> $T {
                        // for smaller inputs, `f64` doesn't justify its cost.
                        // Bit-length estimate 2^⌈⌊log2 x⌋ / n⌉; otherwise x^(1/n)
                        // computed as exp(ln(x) / n) in floating point.
                        if bits::<$T>() <= 32 || x <= core::u32::MAX as $T {
                            1 << ((log2(x) + n - 1) / n)
                        } else {
                            ((x as f64).ln() / f64::from(n)).exp() as $T
                        }
                    }

                    // Without `std` (no `f64::ln`/`exp`), always use the
                    // bit-length estimate 2^⌈⌊log2 x⌋ / n⌉.
                    #[cfg(not(feature = "std"))]
                    #[inline]
                    fn guess(x: $T, n: u32) -> $T {
                        1 << ((log2(x) + n - 1) / n)
                    }

                    // https://en.wikipedia.org/wiki/Nth_root_algorithm
                    // Integer Newton step: x' = (a / xⁿ⁻¹ + (n - 1)·x) / n.
                    // If xⁿ⁻¹ overflows `$T` it certainly exceeds `a`, so the
                    // quotient term is taken as 0 instead.
                    let n1 = n - 1;
                    let next = |x: $T| {
                        let y = match checked_pow(x, n1 as usize) {
                            Some(ax) => a / ax,
                            None => 0,
                        };
                        (y + x * n1 as $T) / n as $T
                    };
                    fixpoint(guess(a, n), next)
                }
                go(*self, n)
            }

            #[inline]
            fn sqrt(&self) -> Self {
                fn go(a: $T) -> $T {
                    if bits::<$T>() > 64 {
                        // 128-bit division is slow, so do a bitwise `sqrt` until it's small enough.
                        // With r = sqrt(a >> 2), the root of `a` is either 2r or 2r + 1.
                        // Here `hi` is at most 2⁶⁴ - 1, so `hi * hi` cannot overflow.
                        return if a <= core::u64::MAX as $T {
                            (a as u64).sqrt() as $T
                        } else {
                            let lo = (a >> 2u32).sqrt() << 1;
                            let hi = lo + 1;
                            if hi * hi <= a {
                                hi
                            } else {
                                lo
                            }
                        };
                    }

                    // 0 is its own root and 1..=3 have root 1. Returning early also
                    // keeps `a == 0` away from `log2` (which expects a positive
                    // input) and from the `a / x` division in `next`.
                    if a < 4 {
                        return (a > 0) as $T;
                    }

                    // Initial estimate: floating-point square root when `std` is
                    // available, otherwise 2^⌈⌊log2 x⌋ / 2⌉ from the bit length.
                    #[cfg(feature = "std")]
                    #[inline]
                    fn guess(x: $T) -> $T {
                        (x as f64).sqrt() as $T
                    }

                    #[cfg(not(feature = "std"))]
                    #[inline]
                    fn guess(x: $T) -> $T {
                        1 << ((log2(x) + 1) / 2)
                    }

                    // https://en.wikipedia.org/wiki/Methods_of_computing_square_roots#Babylonian_method
                    // x' = (a / x + x) / 2, refined by `fixpoint` until it settles.
                    let next = |x: $T| (a / x + x) >> 1;
                    fixpoint(guess(a), next)
                }
                go(*self)
            }

            #[inline]
            fn cbrt(&self) -> Self {
                fn go(a: $T) -> $T {
                    if bits::<$T>() > 64 {
                        // 128-bit division is slow, so do a bitwise `cbrt` until it's small enough.
                        // With r = cbrt(a >> 3), the root of `a` is either 2r or 2r + 1.
                        return if a <= core::u64::MAX as $T {
                            (a as u64).cbrt() as $T
                        } else {
                            let lo = (a >> 3u32).cbrt() << 1;
                            let hi = lo + 1;
                            if hi * hi * hi <= a {
                                hi
                            } else {
                                lo
                            }
                        };
                    }

                    if bits::<$T>() <= 32 {
                        // Implementation based on Hacker's Delight `icbrt2`
                        // Scans the input three bits at a time from the top and
                        // decides one bit of the result `y` per step; `y2` tracks
                        // `y * y` and `x` holds the not-yet-accounted remainder.
                        let mut x = a;
                        let mut y2 = 0;
                        let mut y = 0;
                        let smax = bits::<$T>() / 3;
                        for s in (0..smax + 1).rev() {
                            let s = s * 3;
                            // Shift the partial result left by one bit.
                            y2 *= 4;
                            y *= 2;
                            // (y + 1)³ - y³ = 3y² + 3y + 1
                            let b = 3 * (y2 + y) + 1;
                            if x >> s >= b {
                                x -= b << s;
                                // (y + 1)² = y² + 2y + 1
                                y2 += 2 * y + 1;
                                y += 1;
                            }
                        }
                        return y;
                    }

                    // From here on `$T` is wider than 32 bits.
                    // 0 is its own root and 1..=7 have root 1.
                    if a < 8 {
                        return (a > 0) as $T;
                    }
                    // Values that fit in `u32` reuse its bitwise implementation.
                    if a <= core::u32::MAX as $T {
                        return (a as u32).cbrt() as $T;
                    }

                    // Initial estimate: floating-point cube root when `std` is
                    // available, otherwise 2^⌈⌊log2 x⌋ / 3⌉ from the bit length.
                    #[cfg(feature = "std")]
                    #[inline]
                    fn guess(x: $T) -> $T {
                        (x as f64).cbrt() as $T
                    }

                    #[cfg(not(feature = "std"))]
                    #[inline]
                    fn guess(x: $T) -> $T {
                        1 << ((log2(x) + 2) / 3)
                    }

                    // https://en.wikipedia.org/wiki/Cube_root#Numerical_methods
                    // Newton step for the cube root: x' = (a / x² + 2x) / 3.
                    let next = |x: $T| (a / (x * x) + x * 2) / 3;
                    fixpoint(guess(a), next)
                }
                go(*self)
            }
        }
    };
}

unsigned_roots!(u8);
unsigned_roots!(u16);
unsigned_roots!(u32);
unsigned_roots!(u64);
unsigned_roots!(u128);
unsigned_roots!(usize);
388 | |