1// Copyright (C) 2016,2017 Sebastian Dröge <sebastian@centricular.com>
2//
3// Licensed under the MIT license, see the LICENSE file or <http://opensource.org/licenses/MIT>
4
5#![no_std]
6
7//! Provides a trait for numeric types to perform combined multiplication and division with
8//! overflow protection.
9//!
10//! The [`MulDiv`] trait provides functions for performing combined multiplication and division for
11//! numeric types and comes with implementations for all the primitive integer types. Three
12//! variants with different rounding characteristics are provided: [`mul_div_floor()`],
13//! [`mul_div_round()`] and [`mul_div_ceil()`].
14//!
15//! ## Example
16//!
17//! ```rust
18//! extern crate muldiv;
19//! use muldiv::MulDiv;
20//! # fn main() {
21//! // Calculates 127 * 23 / 42 rounded down
22//! let x = 127u8.mul_div_floor(23, 42);
23//! assert_eq!(x, Some(69));
24//! # }
25//! ```
26//! [`MulDiv`]: trait.MulDiv.html
27//! [`mul_div_floor()`]: trait.MulDiv.html#tymethod.mul_div_floor
28//! [`mul_div_round()`]: trait.MulDiv.html#tymethod.mul_div_round
29//! [`mul_div_ceil()`]: trait.MulDiv.html#tymethod.mul_div_ceil
30
31use core::u16;
32use core::u32;
33use core::u64;
34use core::u8;
35
36use core::i16;
37use core::i32;
38use core::i64;
39use core::i8;
40
/// Trait for calculating `val * num / denom` with different rounding modes and overflow
/// protection.
///
/// Implementations of this trait have to ensure that even if the result of the multiplication does
/// not fit into the type, as long as it would fit after the division the correct result has to be
/// returned instead of `None`. `None` should only be returned if the overall result does not fit
/// into the type.
///
/// This specifically means that e.g. the `u64` implementation must, depending on the arguments, be
/// able to do 128 bit integer multiplication.
///
/// ## Panics
///
/// The implementations generated by the macros in this file (`u8` to `u64` and `i8` to `i64`)
/// panic if `denom` is zero.
pub trait MulDiv<RHS = Self> {
    /// Output type for the methods of this trait.
    type Output;

    /// Calculates `floor(val * num / denom)`, i.e. the largest integer less than or equal to the
    /// result of the division.
    ///
    /// ## Example
    ///
    /// ```rust
    /// extern crate muldiv;
    /// use muldiv::MulDiv;
    ///
    /// # fn main() {
    /// let x = 3i8.mul_div_floor(4, 2);
    /// assert_eq!(x, Some(6));
    ///
    /// let x = 5i8.mul_div_floor(2, 3);
    /// assert_eq!(x, Some(3));
    ///
    /// let x = (-5i8).mul_div_floor(2, 3);
    /// assert_eq!(x, Some(-4));
    ///
    /// let x = 3i8.mul_div_floor(3, 2);
    /// assert_eq!(x, Some(4));
    ///
    /// let x = (-3i8).mul_div_floor(3, 2);
    /// assert_eq!(x, Some(-5));
    ///
    /// let x = 127i8.mul_div_floor(4, 3);
    /// assert_eq!(x, None);
    /// # }
    /// ```
    fn mul_div_floor(self, num: RHS, denom: RHS) -> Option<Self::Output>;

    /// Calculates `round(val * num / denom)`, i.e. the closest integer to the result of the
    /// division. If both surrounding integers are the same distance (`x.5`), the one with the bigger
    /// absolute value is returned (round away from 0.0).
    ///
    /// ## Example
    ///
    /// ```rust
    /// extern crate muldiv;
    /// use muldiv::MulDiv;
    ///
    /// # fn main() {
    /// let x = 3i8.mul_div_round(4, 2);
    /// assert_eq!(x, Some(6));
    ///
    /// let x = 5i8.mul_div_round(2, 3);
    /// assert_eq!(x, Some(3));
    ///
    /// let x = (-5i8).mul_div_round(2, 3);
    /// assert_eq!(x, Some(-3));
    ///
    /// let x = 3i8.mul_div_round(3, 2);
    /// assert_eq!(x, Some(5));
    ///
    /// let x = (-3i8).mul_div_round(3, 2);
    /// assert_eq!(x, Some(-5));
    ///
    /// let x = 127i8.mul_div_round(4, 3);
    /// assert_eq!(x, None);
    /// # }
    /// ```
    fn mul_div_round(self, num: RHS, denom: RHS) -> Option<Self::Output>;

    /// Calculates `ceil(val * num / denom)`, i.e. the smallest integer greater than or equal to
    /// the result of the division.
    ///
    /// ## Example
    ///
    /// ```rust
    /// extern crate muldiv;
    /// use muldiv::MulDiv;
    ///
    /// # fn main() {
    /// let x = 3i8.mul_div_ceil(4, 2);
    /// assert_eq!(x, Some(6));
    ///
    /// let x = 5i8.mul_div_ceil(2, 3);
    /// assert_eq!(x, Some(4));
    ///
    /// let x = (-5i8).mul_div_ceil(2, 3);
    /// assert_eq!(x, Some(-3));
    ///
    /// let x = 3i8.mul_div_ceil(3, 2);
    /// assert_eq!(x, Some(5));
    ///
    /// let x = (-3i8).mul_div_ceil(3, 2);
    /// assert_eq!(x, Some(-4));
    ///
    /// let x = (127i8).mul_div_ceil(4, 3);
    /// assert_eq!(x, None);
    /// # }
    /// ```
    fn mul_div_ceil(self, num: RHS, denom: RHS) -> Option<Self::Output>;
}
149
// Implements `MulDiv` for the unsigned type `$t`.
//
// All intermediate arithmetic is carried out in the unsigned type `$u`, which is expected to be
// wide enough to hold `self * num` plus the rounding bias. Every method panics if `denom` is
// zero and yields `None` when the final quotient does not fit back into `$t`.
macro_rules! mul_div_impl_unsigned {
    ($t:ident, $u:ident) => {
        impl MulDiv for $t {
            type Output = $t;

            // Truncating division of the widened product.
            fn mul_div_floor(self, num: $t, denom: $t) -> Option<$t> {
                assert_ne!(denom, 0);
                let divisor = denom as $u;
                let quotient = (self as $u) * (num as $u) / divisor;
                if quotient <= $t::MAX as $u {
                    Some(quotient as $t)
                } else {
                    None
                }
            }

            // Adding half of the divisor before dividing rounds exact halves upwards.
            fn mul_div_round(self, num: $t, denom: $t) -> Option<$t> {
                assert_ne!(denom, 0);
                let divisor = denom as $u;
                let half = (denom / 2) as $u;
                let quotient = ((self as $u) * (num as $u) + half) / divisor;
                if quotient <= $t::MAX as $u {
                    Some(quotient as $t)
                } else {
                    None
                }
            }

            // Adding `denom - 1` before dividing bumps every non-exact quotient up by one.
            fn mul_div_ceil(self, num: $t, denom: $t) -> Option<$t> {
                assert_ne!(denom, 0);
                let divisor = denom as $u;
                let bias = (denom - 1) as $u;
                let quotient = ((self as $u) * (num as $u) + bias) / divisor;
                if quotient <= $t::MAX as $u {
                    Some(quotient as $t)
                } else {
                    None
                }
            }
        }
    };
}
187
// Generates quickcheck property tests for the `MulDiv` implementation of the unsigned type `$t`.
//
// Reference results are computed directly in the wider type `$u`. The expansion contains `use`
// items, a helper type and the test functions, so it is invoked inside a dedicated test module.
#[cfg(test)]
macro_rules! mul_div_impl_unsigned_tests {
    ($t:ident, $u:ident) => {
        use super::*;

        use quickcheck::{quickcheck, Arbitrary, Gen};

        // Denominator wrapper whose generator never produces zero.
        #[derive(Debug, Clone, Copy, PartialEq, Eq)]
        struct NonZero($t);

        impl Arbitrary for NonZero {
            fn arbitrary(g: &mut Gen) -> Self {
                // Keep drawing until a non-zero value turns up.
                loop {
                    let candidate = $t::arbitrary(g);
                    if candidate != 0 {
                        break NonZero(candidate);
                    }
                }
            }
        }

        quickcheck! {
            fn scale_floor(val: $t, num: $t, den: NonZero) -> bool {
                let actual = val.mul_div_floor(num, den.0);

                let wide_product = (val as $u) * (num as $u);
                let reference = wide_product / (den.0 as $u);

                if reference <= $t::MAX as $u {
                    actual == Some(reference as $t)
                } else {
                    actual.is_none()
                }
            }
        }

        quickcheck! {
            fn scale_round(val: $t, num: $t, den: NonZero) -> bool {
                let actual = val.mul_div_round(num, den.0);

                let wide_product = (val as $u) * (num as $u);
                let divisor = den.0 as $u;
                let truncated = wide_product / divisor;
                let remainder = wide_product % divisor;

                // A remainder of at least half the divisor (rounded up) rounds to the next integer.
                let reference = if remainder >= (divisor + 1) >> 1 {
                    truncated + 1
                } else {
                    truncated
                };

                if reference <= $t::MAX as $u {
                    actual == Some(reference as $t)
                } else {
                    actual.is_none()
                }
            }
        }

        quickcheck! {
            fn scale_ceil(val: $t, num: $t, den: NonZero) -> bool {
                let actual = val.mul_div_ceil(num, den.0);

                let wide_product = (val as $u) * (num as $u);
                let divisor = den.0 as $u;
                let truncated = wide_product / divisor;

                // Any non-zero remainder pushes the result up by one.
                let reference = if wide_product % divisor != 0 {
                    truncated + 1
                } else {
                    truncated
                };

                if reference <= $t::MAX as $u {
                    actual == Some(reference as $t)
                } else {
                    actual.is_none()
                }
            }
        }
    };
}
262
// Unsigned implementations. Each type is paired with the unsigned type of twice its width, which
// the macro uses for the intermediate product.
mul_div_impl_unsigned!(u64, u128);
mul_div_impl_unsigned!(u32, u64);
mul_div_impl_unsigned!(u16, u32);
mul_div_impl_unsigned!(u8, u16);
267
268// FIXME: https://github.com/rust-lang/rust/issues/12249
// Property tests for the `u64` implementation; reference values are computed in `u128`.
#[cfg(test)]
mod muldiv_u64_tests {
    mul_div_impl_unsigned_tests!(u64, u128);
}
273
// Property tests for the `u32` implementation; reference values are computed in `u64`.
#[cfg(test)]
mod muldiv_u32_tests {
    mul_div_impl_unsigned_tests!(u32, u64);
}
278
// Property tests for the `u16` implementation; reference values are computed in `u32`.
#[cfg(test)]
mod muldiv_u16_tests {
    mul_div_impl_unsigned_tests!(u16, u32);
}
283
// Property tests for the `u8` implementation; reference values are computed in `u16`.
#[cfg(test)]
mod muldiv_u8_tests {
    mul_div_impl_unsigned_tests!(u8, u16);
}
288
// Implements `MulDiv` for the signed integer type `$t`.
//
// Macro arguments:
//   * `$t` - signed type receiving the implementation
//   * `$u` - unsigned type of the same width; its `MulDiv` implementation does the arithmetic
//   * `$v` - wider unsigned type; still accepted so existing invocations keep working, but the
//            expansion no longer needs it because `$u`'s `MulDiv` implementation already
//            widens the intermediate product
//   * `$b` - number of bits in `$t`
//
// Every method splits its operands into an overall sign and unsigned magnitudes, delegates to
// the unsigned implementation with the rounding direction adjusted for the sign, and then
// re-applies the sign. Every method panics if `denom` is zero and yields `None` when the result
// is not representable in `$t`.
macro_rules! mul_div_impl_signed {
    ($t:ident, $u:ident, $v:ident, $b:expr) => {
        impl MulDiv for $t {
            type Output = $t;

            fn mul_div_floor(self, num: $t, denom: $t) -> Option<$t> {
                assert_ne!(denom, 0);

                // -1, 0 or 1: sign of the exact quotient.
                let sgn = self.signum() * num.signum() * denom.signum();

                // Magnitude of `$t::MIN`, which `abs()` cannot produce without overflowing.
                let min_val: $u = 1 << ($b - 1);
                let abs = |x: $t| if x != $t::MIN { x.abs() as $u } else { min_val };

                let (self_u, num_u, denom_u) = (abs(self), abs(num), abs(denom));

                // Flooring a negative quotient means rounding its magnitude up.
                let magnitude = if sgn < 0 {
                    self_u.mul_div_ceil(num_u, denom_u)
                } else {
                    self_u.mul_div_floor(num_u, denom_u)
                };

                magnitude.and_then(|r| {
                    if r <= $t::MAX as $u {
                        Some(sgn * (r as $t))
                    } else if sgn < 0 && r == min_val {
                        // Exactly `$t::MIN`: representable although its magnitude is not.
                        Some($t::MIN)
                    } else {
                        None
                    }
                })
            }

            fn mul_div_round(self, num: $t, denom: $t) -> Option<$t> {
                assert_ne!(denom, 0);

                // -1, 0 or 1: sign of the exact quotient.
                let sgn = self.signum() * num.signum() * denom.signum();

                // Magnitude of `$t::MIN`, which `abs()` cannot produce without overflowing.
                let min_val: $u = 1 << ($b - 1);
                let abs = |x: $t| if x != $t::MIN { x.abs() as $u } else { min_val };

                // Ties round away from zero, which is the same as rounding the magnitude's ties
                // upwards for either sign, so the unsigned rounding serves both cases.
                abs(self)
                    .mul_div_round(abs(num), abs(denom))
                    .and_then(|r| {
                        if r <= $t::MAX as $u {
                            Some(sgn * (r as $t))
                        } else if sgn < 0 && r == min_val {
                            // Exactly `$t::MIN`: representable although its magnitude is not.
                            Some($t::MIN)
                        } else {
                            None
                        }
                    })
            }

            fn mul_div_ceil(self, num: $t, denom: $t) -> Option<$t> {
                assert_ne!(denom, 0);

                // -1, 0 or 1: sign of the exact quotient.
                let sgn = self.signum() * num.signum() * denom.signum();

                // Magnitude of `$t::MIN`, which `abs()` cannot produce without overflowing.
                let min_val: $u = 1 << ($b - 1);
                let abs = |x: $t| if x != $t::MIN { x.abs() as $u } else { min_val };

                let (self_u, num_u, denom_u) = (abs(self), abs(num), abs(denom));

                // The ceiling of a negative quotient means rounding its magnitude down.
                let magnitude = if sgn < 0 {
                    self_u.mul_div_floor(num_u, denom_u)
                } else {
                    self_u.mul_div_ceil(num_u, denom_u)
                };

                magnitude.and_then(|r| {
                    if r <= $t::MAX as $u {
                        Some(sgn * (r as $t))
                    } else if sgn < 0 && r == min_val {
                        // Exactly `$t::MIN`: representable although its magnitude is not.
                        Some($t::MIN)
                    } else {
                        None
                    }
                })
            }
        }
    };
}
386
// Signed implementations. Arguments: the signed type, the unsigned type of the same width, the
// unsigned type of twice that width, and the bit width of the signed type.
mul_div_impl_signed!(i64, u64, u128, 64);
mul_div_impl_signed!(i32, u32, u64, 32);
mul_div_impl_signed!(i16, u16, u32, 16);
mul_div_impl_signed!(i8, u8, u16, 8);
391
// Generates quickcheck property tests for the `MulDiv` implementation of the signed type `$t`.
//
// Reference results are computed directly in the wider signed type `$u`. The expansion contains
// `use` items, a helper type and the test functions, so it is invoked inside a dedicated test
// module.
#[cfg(test)]
macro_rules! mul_div_impl_signed_tests {
    ($t:ident, $u:ident) => {
        use super::*;

        use quickcheck::{quickcheck, Arbitrary, Gen};

        // Denominator wrapper whose generator never produces zero.
        #[derive(Debug, Clone, Copy, PartialEq, Eq)]
        struct NonZero($t);

        impl Arbitrary for NonZero {
            fn arbitrary(g: &mut Gen) -> Self {
                // Keep drawing until a non-zero value turns up.
                loop {
                    let candidate = $t::arbitrary(g);
                    if candidate != 0 {
                        break NonZero(candidate);
                    }
                }
            }
        }

        quickcheck! {
            fn scale_floor(val: $t, num: $t, den: NonZero) -> bool {
                let actual = val.mul_div_floor(num, den.0);

                let sgn = val.signum() * num.signum() * den.0.signum();
                let wide_product = (val as $u) * (num as $u);
                let divisor = den.0 as $u;
                let truncated = wide_product / divisor;
                let remainder = wide_product % divisor;

                // Integer division truncates towards zero; a negative inexact quotient must
                // therefore be moved one step further down.
                let expected = if sgn < 0 && remainder.abs() != 0 {
                    truncated - 1
                } else {
                    truncated
                };

                let lowest = $t::MIN as $u;
                let highest = $t::MAX as $u;
                if expected < lowest || expected > highest {
                    actual.is_none()
                } else {
                    actual == Some(expected as $t)
                }
            }
        }

        quickcheck! {
            fn scale_round(val: $t, num: $t, den: NonZero) -> bool {
                let actual = val.mul_div_round(num, den.0);

                let sgn = val.signum() * num.signum() * den.0.signum();
                let wide_product = (val as $u) * (num as $u);
                let divisor = den.0 as $u;
                let truncated = wide_product / divisor;
                let remainder = wide_product % divisor;

                // Remainders of at least half the divisor (rounded up) move the result one step
                // away from zero.
                let half = (divisor.abs() + 1) >> 1;
                let expected = if remainder.abs() < half {
                    truncated
                } else if sgn < 0 {
                    truncated - 1
                } else if sgn > 0 {
                    truncated + 1
                } else {
                    truncated
                };

                let lowest = $t::MIN as $u;
                let highest = $t::MAX as $u;
                if expected < lowest || expected > highest {
                    actual.is_none()
                } else {
                    actual == Some(expected as $t)
                }
            }
        }

        quickcheck! {
            fn scale_ceil(val: $t, num: $t, den: NonZero) -> bool {
                let actual = val.mul_div_ceil(num, den.0);

                let sgn = val.signum() * num.signum() * den.0.signum();
                let wide_product = (val as $u) * (num as $u);
                let divisor = den.0 as $u;
                let truncated = wide_product / divisor;
                let remainder = wide_product % divisor;

                // Integer division truncates towards zero; a positive inexact quotient must
                // therefore be moved one step further up.
                let expected = if sgn > 0 && remainder.abs() != 0 {
                    truncated + 1
                } else {
                    truncated
                };

                let lowest = $t::MIN as $u;
                let highest = $t::MAX as $u;
                if expected < lowest || expected > highest {
                    actual.is_none()
                } else {
                    actual == Some(expected as $t)
                }
            }
        }
    };
}
476
477// FIXME: https://github.com/rust-lang/rust/issues/12249
// Property tests for the `i64` implementation; reference values are computed in `i128`.
#[cfg(test)]
mod muldiv_i64_tests {
    mul_div_impl_signed_tests!(i64, i128);
}
482
// Property tests for the `i32` implementation; reference values are computed in `i64`.
#[cfg(test)]
mod muldiv_i32_tests {
    mul_div_impl_signed_tests!(i32, i64);
}
487
// Property tests for the `i16` implementation; reference values are computed in `i32`.
#[cfg(test)]
mod muldiv_i16_tests {
    mul_div_impl_signed_tests!(i16, i32);
}
492
// Property tests for the `i8` implementation; reference values are computed in `i16`.
#[cfg(test)]
mod muldiv_i8_tests {
    mul_div_impl_signed_tests!(i8, i16);
}
497