1/* SPDX-License-Identifier: MIT OR Apache-2.0 */
2use crate::support::{CastInto, DInt, HInt, Int, MinInt, u256};
3
4/// Trait for unsigned division of a double-wide integer
5/// when the quotient doesn't overflow.
6///
7/// This is the inverse of widening multiplication:
8/// - for any `x` and nonzero `y`: `x.widen_mul(y).checked_narrowing_div_rem(y) == Some((x, 0))`,
9/// - and for any `r in 0..y`: `x.carrying_mul(y, r).checked_narrowing_div_rem(y) == Some((x, r))`,
10#[allow(dead_code)]
11pub trait NarrowingDiv: DInt + MinInt<Unsigned = Self> {
12 /// Computes `(self / n, self % n))`
13 ///
14 /// # Safety
15 /// The caller must ensure that `self.hi() < n`, or equivalently,
16 /// that the quotient does not overflow.
17 unsafe fn unchecked_narrowing_div_rem(self, n: Self::H) -> (Self::H, Self::H);
18
19 /// Returns `Some((self / n, self % n))` when `self.hi() < n`.
20 fn checked_narrowing_div_rem(self, n: Self::H) -> Option<(Self::H, Self::H)> {
21 if self.hi() < n {
22 Some(unsafe { self.unchecked_narrowing_div_rem(n) })
23 } else {
24 None
25 }
26 }
27}
28
// Primitive double-wide types have native `/` and `%`, so the narrowing
// division widens the divisor, divides in the double-wide type, and casts
// both results back down to the half-width type.
macro_rules! impl_narrowing_div_primitive {
    ($D:ident) => {
        impl NarrowingDiv for $D {
            unsafe fn unchecked_narrowing_div_rem(self, n: Self::H) -> (Self::H, Self::H) {
                if self.hi() >= n {
                    // SAFETY: the caller guarantees `self.hi() < n`, so this
                    // branch is never taken.
                    unsafe { core::hint::unreachable_unchecked() }
                }

                let divisor = n.widen();
                let quotient = self / divisor;
                let remainder = self % divisor;
                // Both values fit in `Self::H`: the quotient because
                // `self.hi() < n`, the remainder because it is below `n`.
                (quotient.cast(), remainder.cast())
            }
        }
    };
}
43
// Build `u4N / u2N` division out of `u2N / uN` division: the dividend is
// treated as four `N`-bit digits and reduced by two successive
// three-digits-by-two division steps. Simple rather than fast.
macro_rules! impl_narrowing_div_recurse {
    ($D:ident) => {
        impl NarrowingDiv for $D {
            unsafe fn unchecked_narrowing_div_rem(self, n: Self::H) -> (Self::H, Self::H) {
                if self.hi() >= n {
                    // SAFETY: the caller guarantees `self.hi() < n`.
                    unsafe { core::hint::unreachable_unchecked() }
                }

                // `self.hi() < n` implies `n != 0`, so shifting by the leading
                // zero count puts the divisor's top set bit in the top position.
                let shift = n.leading_zeros();
                let divisor = n << shift;
                let dividend = self << shift;

                let top = dividend.hi();
                let (digit0, digit1) = dividend.lo().lo_hi();

                // SAFETY: `divisor.leading_zeros() == 0` by the shift above, and
                // `top < divisor` follows from `self.hi() < n` (the shift applies
                // to both sides).
                let (q_hi, rem) = unsafe { div_three_digits_by_two(digit1, top, divisor) };
                // SAFETY: same normalized divisor, and `rem < divisor` is
                // guaranteed by the previous step.
                let (q_lo, rem) = unsafe { div_three_digits_by_two(digit0, rem, divisor) };

                // The common shift leaves the quotient unchanged; only the
                // remainder has to be shifted back down.
                (Self::H::from_lo_hi(q_lo, q_hi), rem >> shift)
            }
        }
    };
}
75
76impl_narrowing_div_primitive!(u16);
77impl_narrowing_div_primitive!(u32);
78impl_narrowing_div_primitive!(u64);
79impl_narrowing_div_primitive!(u128);
80impl_narrowing_div_recurse!(u256);
81
/// Implement `u3N / u2N`-division on top of `u2N / uN`-division.
///
/// Returns the quotient and remainder of `(a * R + a0) / n`,
/// where `R = (1 << U::BITS)` is the digit size.
///
/// The quotient is first estimated by dividing the top two digits of the
/// dividend by the top digit of `n`. That estimate is never too small and,
/// because `n` is normalized, at most two too large. It is corrected by
/// decrementing it and adding `n` back to the remainder at most twice.
///
/// # Safety
/// Requires that `n.leading_zeros() == 0` and `a < n`.
unsafe fn div_three_digits_by_two<U>(a0: U, a: U::D, n: U::D) -> (U, U::D)
where
    U: HInt,
    U::D: Int + NarrowingDiv,
{
    if n.leading_zeros() > 0 || a >= n {
        // SAFETY: both conditions are ruled out by the `# Safety` contract.
        unsafe { core::hint::unreachable_unchecked() }
    }

    // n = n1R + n0
    let (n0, n1) = n.lo_hi();
    // a = a2R + a1
    let (a1, a2) = a.lo_hi();

    let mut q;
    let mut r;
    // Carry/borrow out of the latest `U::D` add/sub on `r`. While the true
    // remainder is negative, `r` holds it plus `RR`. Adding `n` then carries
    // exactly when the true remainder becomes non-negative.
    let mut wrap;
    // `a < n` is guaranteed by the caller, but `a2 == n1 && a1 < n0` is possible
    if let Some((q0, r1)) = a.checked_narrowing_div_rem(n1) {
        q = q0;
        // a = qn1 + r1, where 0 <= r1 < n1

        // Include the remainder with the low bits:
        // r = a0 + r1R
        r = U::D::from_lo_hi(a0, r1);

        // Subtract the contribution of the divisor low bits with the estimated quotient
        let d = q.widen_mul(n0);
        (r, wrap) = r.overflowing_sub(d);

        // Since `q` is the quotient of dividing with a slightly smaller divisor,
        // it may be an overapproximation, but is never too small, and similarly,
        // `r` is now either the correct remainder ...
        if !wrap {
            return (q, r);
        }
        // ... or the remainder went "negative" (by as much as `d = qn0 < RR`)
        // and we have to adjust.
        q -= U::ONE;
    } else {
        debug_assert!(a2 == n1 && a1 < n0);
        // Otherwise, `a2 == n1`, and the estimated quotient `a / n1` would be
        // `R + a1 / n1`, which does not fit in `U`. The correct quotient can't
        // overflow, though: it is below `R` because `a < n`.
        // We'll start from `q = R = (1 << U::BITS)`,
        // so `r = aR + a0 - qn = (a - n)R + a0`.
        // With `a2 == n1`, `a - n == a1 - n0 < 0`. The wrapping subtraction
        // therefore leaves `r` in the same wrapped ("negative") state as a
        // borrowing `overflowing_sub` in the branch above.
        r = U::D::from_lo_hi(a0, a1.wrapping_sub(n0));
        // Since `a < n`, the first decrement is always needed:
        q = U::MAX; /* R - 1 */
    }

    // First correction: the decrement of `q` is matched by adding `n` back.
    (r, wrap) = r.overflowing_add(n);
    if wrap {
        return (q, r);
    }

    // If the remainder still didn't wrap, we need another step.
    q -= U::ONE;
    (r, wrap) = r.overflowing_add(n);
    // Since `n >= RR/2`, at least one of the two `r += n` must have wrapped.
    debug_assert!(wrap, "estimated quotient should be off by at most two");
    (q, r)
}
151
#[cfg(test)]
mod test {
    use super::{DInt, HInt, NarrowingDiv, u256};

    /// Exhaustively checks the primitive `u16 / u8` implementation against
    /// widening multiplication, with several remainders.
    #[test]
    fn inverse_mul() {
        for x in 0..=u8::MAX {
            for y in 1..=u8::MAX {
                let xy = x.widen_mul(y);
                assert_eq!(xy.checked_narrowing_div_rem(y), Some((x, 0)));
                assert_eq!(
                    (xy + (y - 1) as u16).checked_narrowing_div_rem(y),
                    Some((x, y - 1))
                );
                if y > 1 {
                    assert_eq!((xy + 1).checked_narrowing_div_rem(y), Some((x, 1)));
                    assert_eq!(
                        (xy + (y - 2) as u16).checked_narrowing_div_rem(y),
                        Some((x, y - 2))
                    );
                }
            }
        }
    }

    /// `checked_narrowing_div_rem` must return `None` whenever the quotient
    /// would overflow, including for a zero divisor.
    #[test]
    fn checked_rejects_overflow() {
        assert_eq!(0u16.checked_narrowing_div_rem(0), None);
        assert_eq!(0x1200u16.checked_narrowing_div_rem(0x12), None);
        assert_eq!(0xff00u16.checked_narrowing_div_rem(0x12), None);
        assert_eq!(0x1100u16.checked_narrowing_div_rem(0x12), Some((241, 14)));

        assert_eq!(u256::from_lo_hi(u128::MAX, 0).checked_narrowing_div_rem(0), None);
        assert_eq!(u256::from_lo_hi(0, 7).checked_narrowing_div_rem(7), None);
        assert_eq!(u256::from_lo_hi(100, 0).checked_narrowing_div_rem(7), Some((14, 2)));
    }

    /// Checks the recursive `u256 / u128` implementation (and therefore
    /// `div_three_digits_by_two`) against widening multiplication.
    #[test]
    fn inverse_mul_u128() {
        // Divisors with a variety of leading-zero counts and digit patterns.
        const SAMPLES: &[u128] = &[
            1,
            2,
            3,
            0xff,
            0x1234_5678,
            (1 << 63) - 1,
            1 << 63,
            (1 << 64) - 1,
            1 << 64,
            (1 << 64) + 1,
            0x0123_4567_89ab_cdef_fedc_ba98_7654_3210,
            (1 << 127) - 1,
            1 << 127,
            (1 << 127) + 1,
            u128::MAX - 1,
            u128::MAX,
        ];

        for &x in [0].iter().chain(SAMPLES) {
            for &y in SAMPLES {
                let xy = x.widen_mul(y);
                assert_eq!(xy.checked_narrowing_div_rem(y), Some((x, 0)));

                for r in [1, y / 2, y - 1] {
                    if r >= y {
                        continue;
                    }
                    // Build `x * y + r`; it is below `(x + 1) * y < 2^256`,
                    // so the carry into the high half cannot overflow.
                    let (lo, carry) = xy.lo().overflowing_add(r);
                    let num = u256::from_lo_hi(lo, xy.hi() + u128::from(carry));
                    assert_eq!(num.checked_narrowing_div_rem(y), Some((x, r)));
                }
            }
        }
    }
}
177