| 1 | //!
|
| 2 | //! This module contains the single trait [`IntegerSquareRoot`] and implements it for primitive
|
| 3 | //! integer types.
|
| 4 | //!
|
| 5 | //! # Example
|
| 6 | //!
|
| 7 | //! ```
|
| 8 | //! extern crate integer_sqrt;
|
| 9 | //! // `use` trait to get functionality
|
| 10 | //! use integer_sqrt::IntegerSquareRoot;
|
| 11 | //!
|
| 12 | //! # fn main() {
|
| 13 | //! assert_eq!(4u8.integer_sqrt(), 2);
|
| 14 | //! # }
|
| 15 | //! ```
|
| 16 | //!
|
| 17 | //! [`IntegerSquareRoot`]: ./trait.IntegerSquareRoot.html
|
| 18 | #![no_std ]
|
| 19 |
|
/// A trait providing the integer square root operation.
///
/// Implementors only have to supply `integer_sqrt_checked`; `integer_sqrt` is a provided
/// method layered on top of it.
pub trait IntegerSquareRoot {
    /// Compute the integer square root of `self`.
    ///
    /// The algorithm is taken from the
    /// [Integer_square_root article on wikipedia][wiki_article], which also has more
    /// background information.
    ///
    /// # Panics
    ///
    /// Panics when `integer_sqrt_checked` yields `None`, which happens for negative values
    /// of the signed (`i` family) types.
    ///
    /// [wiki_article]: https://en.wikipedia.org/wiki/Integer_square_root
    fn integer_sqrt(&self) -> Self
    where
        Self: Sized,
    {
        let checked_root = self.integer_sqrt_checked();
        checked_root.expect("cannot calculate square root of negative number")
    }

    /// Compute the integer square root, or `None` when the number is negative.
    ///
    /// Unsigned types can never be negative, so for them the result is always `Some`.
    fn integer_sqrt_checked(&self) -> Option<Self>
    where
        Self: Sized;
}
|
| 46 |
|
| 47 | impl<T: num_traits::PrimInt> IntegerSquareRoot for T {
|
| 48 | fn integer_sqrt_checked(&self) -> Option<Self> {
|
| 49 | use core::cmp::Ordering;
|
| 50 | match self.cmp(&T::zero()) {
|
| 51 | // Hopefully this will be stripped for unsigned numbers (impossible condition)
|
| 52 | Ordering::Less => return None,
|
| 53 | Ordering::Equal => return Some(T::zero()),
|
| 54 | _ => {}
|
| 55 | }
|
| 56 |
|
| 57 | // Compute bit, the largest power of 4 <= n
|
| 58 | let max_shift: u32 = T::zero().leading_zeros() - 1;
|
| 59 | let shift: u32 = (max_shift - self.leading_zeros()) & !1;
|
| 60 | let mut bit = T::one().unsigned_shl(shift);
|
| 61 |
|
| 62 | // Algorithm based on the implementation in:
|
| 63 | // https://en.wikipedia.org/wiki/Methods_of_computing_square_roots#Binary_numeral_system_(base_2)
|
| 64 | // Note that result/bit are logically unsigned (even if T is signed).
|
| 65 | let mut n = *self;
|
| 66 | let mut result = T::zero();
|
| 67 | while bit != T::zero() {
|
| 68 | if n >= (result + bit) {
|
| 69 | n = n - (result + bit);
|
| 70 | result = result.unsigned_shr(1) + bit;
|
| 71 | } else {
|
| 72 | result = result.unsigned_shr(1);
|
| 73 | }
|
| 74 | bit = bit.unsigned_shr(2);
|
| 75 | }
|
| 76 | Some(result)
|
| 77 | }
|
| 78 | }
|
| 79 |
|
#[cfg(test)]
mod tests {
    use super::IntegerSquareRoot;

    // Generates one `#[test]` function per `type => name` pair. Each generated test checks a
    // handful of small inputs plus the two largest values of the type. The expected root of
    // the maximum value is derived from a floating point estimate that is refined with two
    // Newton-Raphson steps and then corrected downwards if squaring it would overflow.
    macro_rules! gen_tests {
        ($($type:ty => $fn_name:ident),*) => {
            $(
                #[test]
                fn $fn_name() {
                    let newton_raphson = |val, square| 0.5 * (val + (square / val as $type) as f64);
                    let max_sqrt = {
                        let square = <$type>::max_value();
                        let mut value = (square as f64).sqrt();
                        for _ in 0..2 {
                            value = newton_raphson(value, square);
                        }
                        let mut value = value as $type;
                        // make sure we are below the max value (this is how integer square
                        // root works)
                        if value.checked_mul(value).is_none() {
                            value -= 1;
                        }
                        value
                    };
                    let tests: [($type, $type); 9] = [
                        (0, 0),
                        (1, 1),
                        (2, 1),
                        (3, 1),
                        (4, 2),
                        (81, 9),
                        (80, 8),
                        (<$type>::max_value(), max_sqrt),
                        (<$type>::max_value() - 1, max_sqrt),
                    ];
                    for &(in_, out) in tests.iter() {
                        assert_eq!(in_.integer_sqrt(), out, "in {}", in_);
                    }
                }
            )*
        };
    }

    gen_tests! {
        i8 => i8_test,
        u8 => u8_test,
        i16 => i16_test,
        u16 => u16_test,
        i32 => i32_test,
        u32 => u32_test,
        i64 => i64_test,
        u64 => u64_test,
        u128 => u128_test,
        isize => isize_test,
        usize => usize_test
    }

    // `i128` is tested by hand with a hard-coded expected root for the maximum value instead
    // of going through the floating point estimate used by `gen_tests!`.
    #[test]
    fn i128_test() {
        let tests: [(i128, i128); 8] = [
            (0, 0),
            (1, 1),
            (2, 1),
            (3, 1),
            (4, 2),
            (81, 9),
            (80, 8),
            (i128::max_value(), 13_043_817_825_332_782_212),
        ];
        for &(in_, out) in tests.iter() {
            assert_eq!(in_.integer_sqrt(), out, "in {}", in_);
        }
    }

    // Negative input has no integer square root: the checked variant must report `None`.
    #[test]
    fn negative_input_is_none() {
        assert_eq!((-1i8).integer_sqrt_checked(), None);
        assert_eq!(i8::min_value().integer_sqrt_checked(), None);
        assert_eq!((-4i32).integer_sqrt_checked(), None);
        assert_eq!(i64::min_value().integer_sqrt_checked(), None);
        assert_eq!(i128::min_value().integer_sqrt_checked(), None);
        assert_eq!((-1isize).integer_sqrt_checked(), None);
    }

    // The unchecked variant documents a panic for negative input.
    #[test]
    #[should_panic(expected = "cannot calculate square root of negative number")]
    fn negative_input_panics() {
        let _ = (-1i32).integer_sqrt();
    }
}
|
| 155 | |