| 1 | /* SPDX-License-Identifier: MIT OR Apache-2.0 */ |
| 2 | use crate::support::{CastInto, DInt, HInt, Int, MinInt, u256}; |
| 3 | |
| 4 | /// Trait for unsigned division of a double-wide integer |
| 5 | /// when the quotient doesn't overflow. |
| 6 | /// |
| 7 | /// This is the inverse of widening multiplication: |
| 8 | /// - for any `x` and nonzero `y`: `x.widen_mul(y).checked_narrowing_div_rem(y) == Some((x, 0))`, |
| 9 | /// - and for any `r in 0..y`: `x.carrying_mul(y, r).checked_narrowing_div_rem(y) == Some((x, r))`, |
| 10 | #[allow (dead_code)] |
| 11 | pub trait NarrowingDiv: DInt + MinInt<Unsigned = Self> { |
| 12 | /// Computes `(self / n, self % n))` |
| 13 | /// |
| 14 | /// # Safety |
| 15 | /// The caller must ensure that `self.hi() < n`, or equivalently, |
| 16 | /// that the quotient does not overflow. |
| 17 | unsafe fn unchecked_narrowing_div_rem(self, n: Self::H) -> (Self::H, Self::H); |
| 18 | |
| 19 | /// Returns `Some((self / n, self % n))` when `self.hi() < n`. |
| 20 | fn checked_narrowing_div_rem(self, n: Self::H) -> Option<(Self::H, Self::H)> { |
| 21 | if self.hi() < n { |
| 22 | Some(unsafe { self.unchecked_narrowing_div_rem(n) }) |
| 23 | } else { |
| 24 | None |
| 25 | } |
| 26 | } |
| 27 | } |
| 28 | |
// Primitive double-wide types already support native `/` and `%`, so the
// narrowing division is an ordinary division in the wide type, followed by
// casting the results (which are known to fit) back down to the half type.
macro_rules! impl_narrowing_div_primitive {
    ($D:ident) => {
        impl NarrowingDiv for $D {
            unsafe fn unchecked_narrowing_div_rem(self, n: Self::H) -> (Self::H, Self::H) {
                if n <= self.hi() {
                    // SAFETY: ruled out by the caller's `self.hi() < n` contract;
                    // marking it unreachable may let the compiler assume the
                    // quotient fits in `Self::H`.
                    unsafe { core::hint::unreachable_unchecked() }
                }
                let divisor = n.widen();
                let quotient = self / divisor;
                let remainder = self % divisor;
                (quotient.cast(), remainder.cast())
            }
        }
    };
}
| 43 | |
// Build `u4N / u2N` narrowing division out of `u2N / uN` division: normalize,
// then run two steps of schoolbook long division, each one a three-digit by
// two-digit division (`div_three_digits_by_two`). Simple rather than fastest.
macro_rules! impl_narrowing_div_recurse {
    ($D:ident) => {
        impl NarrowingDiv for $D {
            unsafe fn unchecked_narrowing_div_rem(self, n: Self::H) -> (Self::H, Self::H) {
                if n <= self.hi() {
                    // SAFETY: ruled out by the caller's `self.hi() < n` contract.
                    unsafe { core::hint::unreachable_unchecked() }
                }

                // Shift so that the divisor's most significant bit is set. `n` is
                // nonzero because `self.hi() < n`, and for the same reason shifting
                // the dividend by the same amount cannot push bits out of its top.
                let shift = n.leading_zeros();
                let dividend = self << shift;
                let divisor = n << shift;

                // Digits of the normalized dividend: `top` is its high half, and
                // `mid`/`low` are the high/low digits of its low half.
                let top = dividend.hi();
                let (low, mid) = dividend.lo().lo_hi();

                // SAFETY: `divisor` has its top bit set (see the shift above), and
                // `top < divisor` because `self.hi() < n` survives the common shift.
                let (q_hi, rem) = unsafe { div_three_digits_by_two(mid, top, divisor) };
                // SAFETY: `divisor` is still normalized, and the previous call
                // guarantees `rem < divisor`.
                let (q_lo, rem) = unsafe { div_three_digits_by_two(low, rem, divisor) };

                // Normalization leaves the quotient unchanged but scales the
                // remainder, so shift the remainder back down.
                (Self::H::from_lo_hi(q_lo, q_hi), rem >> shift)
            }
        }
    };
}
| 75 | |
| 76 | impl_narrowing_div_primitive!(u16); |
| 77 | impl_narrowing_div_primitive!(u32); |
| 78 | impl_narrowing_div_primitive!(u64); |
| 79 | impl_narrowing_div_primitive!(u128); |
| 80 | impl_narrowing_div_recurse!(u256); |
| 81 | |
/// Implement `u3N / u2N`-division on top of `u2N / uN`-division.
///
/// Returns the quotient and remainder of `(a * R + a0) / n`,
/// where `R = (1 << U::BITS)` is the digit size.
///
/// The quotient is first estimated by dividing the top two digits of the
/// dividend (`a`) by the top digit `n1` of the divisor. Because `n` is
/// normalized (`n1 >= R/2`), that estimate is never too small and at most
/// two too large, so at most two correction steps (`q -= 1; r += n`) follow.
///
/// # Safety
/// Requires that `n.leading_zeros() == 0` and `a < n`.
/// On return, the remainder is strictly less than `n`.
unsafe fn div_three_digits_by_two<U>(a0: U, a: U::D, n: U::D) -> (U, U::D)
where
    U: HInt,
    U::D: Int + NarrowingDiv,
{
    // SAFETY: both conditions are excluded by the caller's documented preconditions.
    if n.leading_zeros() > 0 || a >= n {
        unsafe { core::hint::unreachable_unchecked() }
    }

    // n = n1R + n0
    let (n0, n1) = n.lo_hi();
    // a = a2R + a1
    let (a1, a2) = a.lo_hi();

    let mut q;
    let mut r;
    // Assigned on every path before it is read.
    let mut wrap;
    // `a < n` is guaranteed by the caller, but `a2 == n1 && a1 < n0` is possible,
    // in which case `a / n1` does not fit in one digit and the checked division fails.
    if let Some((q0, r1)) = a.checked_narrowing_div_rem(n1) {
        q = q0;
        // a = qn1 + r1, where 0 <= r1 < n1

        // Include the remainder with the low bits:
        // r = a0 + r1R
        r = U::D::from_lo_hi(a0, r1);

        // Subtract the contribution of the divisor low bits with the estimated quotient,
        // so that r = (aR + a0) - qn (modulo RR).
        let d = q.widen_mul(n0);
        (r, wrap) = r.overflowing_sub(d);

        // Since `q` is the quotient of dividing with a slightly smaller divisor,
        // it may be an overapproximation, but is never too small, and similarly,
        // `r` is now either the correct remainder ...
        if !wrap {
            return (q, r);
        }
        // ... or the remainder went "negative" (by as much as `d = qn0 < RR`)
        // and we have to adjust.
        q -= U::ONE;
    } else {
        debug_assert!(a2 == n1 && a1 < n0);
        // Otherwise, `a2 == n1`, and the estimated quotient would be
        // `R + a1 / n1` (at least `R`), but the correct quotient can't overflow.
        // We'll start from `q = R = (1 << U::BITS)`,
        // so `r = aR + a0 - qn = (a - n)R + a0`, which is "negative" since `a < n`.
        r = U::D::from_lo_hi(a0, a1.wrapping_sub(n0));
        // Since `a < n`, the first decrement is always needed:
        q = U::MAX; /* R - 1 */
    }

    // Here `r` holds a "negative" remainder for `q + 1`. Adding `n` gives the
    // remainder for `q`; the addition carries out of the top exactly when that
    // remainder is non-negative, i.e. when `q` is the correct quotient.
    (r, wrap) = r.overflowing_add(n);
    if wrap {
        return (q, r);
    }

    // If the remainder still didn't wrap, we need another step.
    q -= U::ONE;
    (r, wrap) = r.overflowing_add(n);
    // Since `n >= RR/2`, the estimate is off by at most two, so at least one of
    // the two `r += n` must have wrapped.
    debug_assert!(wrap, "estimated quotient should be off by at most two");
    (q, r)
}
| 151 | |
#[cfg(test)]
mod test {
    use super::{HInt, NarrowingDiv, div_three_digits_by_two};

    /// Exhaustively checks the primitive `u16 / u8` impl against widening
    /// multiplication, including remainders at both ends of `0..y`.
    #[test]
    fn inverse_mul() {
        for x in 0..=u8::MAX {
            for y in 1..=u8::MAX {
                let xy = x.widen_mul(y);
                assert_eq!(xy.checked_narrowing_div_rem(y), Some((x, 0)));
                assert_eq!(
                    (xy + (y - 1) as u16).checked_narrowing_div_rem(y),
                    Some((x, y - 1))
                );
                if y > 1 {
                    assert_eq!((xy + 1).checked_narrowing_div_rem(y), Some((x, 1)));
                    assert_eq!(
                        (xy + (y - 2) as u16).checked_narrowing_div_rem(y),
                        Some((x, y - 2))
                    );
                }
            }
        }
    }

    /// Checks the `u3N / u2N` helper (instantiated with `u8` digits) against
    /// plain `u32` division. This helper is the core of the recursive
    /// (non-primitive) impls, which `inverse_mul` does not exercise.
    #[test]
    fn three_digits_by_two() {
        // Normalized divisors (top bit set): extremes of both digits plus an
        // evenly spread sample. `0x80FF` with dividend `0x8000` exercises the
        // path where the quotient estimate must be corrected twice.
        let edge_divisors = [
            0x8000u16, 0x8001, 0x807F, 0x80FF, 0x9234, 0xBFFF, 0xC000, 0xFF00, 0xFFFE, 0xFFFF,
        ];
        let divisors = edge_divisors
            .iter()
            .copied()
            .chain((0x8000..=u16::MAX).step_by(0x0FED));
        for n in divisors {
            // A dividend whose top digit equals the divisor's (`hi`, and `n - 1`
            // when the low digit of `n` is nonzero) takes the branch where the
            // top-digit estimate does not fit in one digit.
            let hi = n & 0xFF00;
            let edge_dividends = [0, 1, n / 2, hi - 1, hi, n - 2, n - 1];
            let dividends = edge_dividends
                .iter()
                .copied()
                .chain((0..n).step_by(0x0777));
            for a in dividends.filter(|&a| a < n) {
                for a0 in 0..=u8::MAX {
                    let x = (u32::from(a) << 8) | u32::from(a0);
                    let d = u32::from(n);
                    let expected = ((x / d) as u8, (x % d) as u16);
                    // SAFETY: every `n` has its top bit set and `a < n` is
                    // enforced by the filter above.
                    let got = unsafe { div_three_digits_by_two::<u8>(a0, a, n) };
                    assert_eq!(got, expected, "a0={a0:#x} a={a:#x} n={n:#x}");
                }
            }
        }
    }
}
| 177 | |