1 | // Copyright 2018-2020 Developers of the Rand project. |
2 | // Copyright 2017 The Rust Project Developers. |
3 | // |
4 | // Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or |
5 | // https://www.apache.org/licenses/LICENSE-2.0> or the MIT license |
6 | // <LICENSE-MIT or https://opensource.org/licenses/MIT>, at your |
7 | // option. This file may not be copied, modified, or distributed |
8 | // except according to those terms. |
9 | |
10 | //! `UniformInt` implementation |
11 | |
12 | use super::{Error, SampleBorrow, SampleUniform, UniformSampler}; |
13 | use crate::distr::utils::WideningMultiply; |
14 | #[cfg (feature = "simd_support" )] |
15 | use crate::distr::{Distribution, StandardUniform}; |
16 | use crate::Rng; |
17 | |
18 | #[cfg (feature = "simd_support" )] |
19 | use core::simd::prelude::*; |
20 | #[cfg (feature = "simd_support" )] |
21 | use core::simd::{LaneCount, SupportedLaneCount}; |
22 | |
23 | #[cfg (feature = "serde" )] |
24 | use serde::{Deserialize, Serialize}; |
25 | |
/// The back-end implementing [`UniformSampler`] for integer types.
///
/// Unless you are implementing [`UniformSampler`] for your own type, this type
/// should not be used directly, use [`Uniform`] instead.
///
/// # Implementation notes
///
/// For simplicity, we use the same generic struct `UniformInt<X>` for all
/// integer types `X`. This gives us only one field type, `X`; to store unsigned
/// values of this size, we use the fact that these conversions are no-ops.
///
/// For a closed range, the number of possible numbers we should generate is
/// `range = (high - low + 1)`. To avoid bias, we must ensure that the size of
/// our sample space, `zone`, is a multiple of `range`; other values must be
/// rejected (by replacing with a new random sample).
///
/// As a special case, we use `range = 0` to represent the full range of the
/// result type (i.e. for `new_inclusive($ty::MIN, $ty::MAX)`).
///
/// The optimum `zone` is the largest product of `range` which fits in our
/// (unsigned) target type. We calculate this by calculating how many numbers we
/// must reject: `reject = (MAX + 1) % range = (MAX - range + 1) % range`. Any (large)
/// product of `range` will suffice, thus in `sample_single` we multiply by a
/// power of 2 via bit-shifting (faster but may cause more rejections).
///
/// The smallest integer PRNGs generate is `u32`. For 8- and 16-bit outputs we
/// use `u32` for our `zone` and samples (because it's not slower and because
/// it reduces the chance of having to reject a sample). In this case we cannot
/// store `zone` in the target type since it is too large, however we know
/// `ints_to_reject < range <= $uty::MAX`.
///
/// An alternative to using a modulus is widening multiply: After a widening
/// multiply by `range`, the result is in the high word. Then comparing the low
/// word against `zone` makes sure our distribution is uniform.
///
/// # Bias
///
/// Unless the `unbiased` feature flag is used, outputs may have a small bias.
/// In the worst case, bias affects 1 in `2^n` samples where n is
/// 56 (`i8` and `u8`), 48 (`i16` and `u16`), 96 (`i32` and `u32`), 64 (`i64`
/// and `u64`), 128 (`i128` and `u128`).
///
/// [`Uniform`]: super::Uniform
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct UniformInt<X> {
    /// Inclusive lower bound of the range
    pub(super) low: X,
    /// `high - low + 1` computed in the bit-equal unsigned type and stored as
    /// `X`; the value 0 encodes the full range of `X` (see type-level docs)
    pub(super) range: X,
    /// Rejection threshold used by `sample`: effectively
    /// `2.pow(bits(sample_ty)) % range`, where `sample_ty` is the unsigned
    /// type drawn from the RNG (`u32`, `u64` or `u128`; see `new_inclusive`)
    thresh: X,
}
76 | |
macro_rules! uniform_int_impl {
    // $ty: the output type; $uty: its bit-equal unsigned twin (so casts
    // between the two are no-ops); $sample_ty: the unsigned type actually
    // drawn from the RNG — at least `u32`, since PRNGs output no less.
    ($ty:ty, $uty:ty, $sample_ty:ident) => {
        impl SampleUniform for $ty {
            type Sampler = UniformInt<$ty>;
        }

        impl UniformSampler for UniformInt<$ty> {
            // We play free and fast with unsigned vs signed here
            // (when $ty is signed), but that's fine, since the
            // contract of this macro is for $ty and $uty to be
            // "bit-equal", so casting between them is a no-op.

            type X = $ty;

            #[inline] // if the range is constant, this helps LLVM to do the
                      // calculations at compile-time.
            fn new<B1, B2>(low_b: B1, high_b: B2) -> Result<Self, Error>
            where
                B1: SampleBorrow<Self::X> + Sized,
                B2: SampleBorrow<Self::X> + Sized,
            {
                let low = *low_b.borrow();
                let high = *high_b.borrow();
                if !(low < high) {
                    return Err(Error::EmptyRange);
                }
                // [low, high) == [low, high - 1]; `high - 1` cannot underflow
                // because `low < high` implies `high > $ty::MIN`.
                UniformSampler::new_inclusive(low, high - 1)
            }

            #[inline] // if the range is constant, this helps LLVM to do the
                      // calculations at compile-time.
            fn new_inclusive<B1, B2>(low_b: B1, high_b: B2) -> Result<Self, Error>
            where
                B1: SampleBorrow<Self::X> + Sized,
                B2: SampleBorrow<Self::X> + Sized,
            {
                let low = *low_b.borrow();
                let high = *high_b.borrow();
                if !(low <= high) {
                    return Err(Error::EmptyRange);
                }

                // Wrapping arithmetic: for the full range
                // `$ty::MIN..=$ty::MAX` this yields `range == 0`, the
                // special "full range" encoding (see type-level docs).
                let range = high.wrapping_sub(low).wrapping_add(1) as $uty;
                let thresh = if range > 0 {
                    let range = $sample_ty::from(range);
                    // `2^bits($sample_ty) % range`: the number of low
                    // residues that must be rejected for uniformity.
                    (range.wrapping_neg() % range)
                } else {
                    0
                };

                Ok(UniformInt {
                    low,
                    range: range as $ty,           // type: $uty
                    thresh: thresh as $uty as $ty, // type: $sample_ty
                })
            }

            /// Sample from distribution, Lemire's method, unbiased
            #[inline]
            fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> Self::X {
                let range = self.range as $uty as $sample_ty;
                if range == 0 {
                    // `range == 0` encodes the full range of $ty: every raw
                    // sample is acceptable, no rejection needed.
                    return rng.random();
                }

                // Widening multiply maps a uniform $sample_ty to [0, range)
                // in the high word; samples whose low word falls below
                // `thresh` are rejected to remove bias.
                let thresh = self.thresh as $uty as $sample_ty;
                let hi = loop {
                    let (hi, lo) = rng.random::<$sample_ty>().wmul(range);
                    if lo >= thresh {
                        break hi;
                    }
                };
                self.low.wrapping_add(hi as $ty)
            }

            #[inline]
            fn sample_single<R: Rng + ?Sized, B1, B2>(
                low_b: B1,
                high_b: B2,
                rng: &mut R,
            ) -> Result<Self::X, Error>
            where
                B1: SampleBorrow<Self::X> + Sized,
                B2: SampleBorrow<Self::X> + Sized,
            {
                let low = *low_b.borrow();
                let high = *high_b.borrow();
                if !(low < high) {
                    return Err(Error::EmptyRange);
                }
                // Delegate to the inclusive variant on [low, high - 1].
                Self::sample_single_inclusive(low, high - 1, rng)
            }

            /// Sample single value, Canon's method, biased
            ///
            /// In the worst case, bias affects 1 in `2^n` samples where n is
            /// 56 (`i8`), 48 (`i16`), 96 (`i32`), 64 (`i64`), 128 (`i128`).
            #[cfg(not(feature = "unbiased"))]
            #[inline]
            fn sample_single_inclusive<R: Rng + ?Sized, B1, B2>(
                low_b: B1,
                high_b: B2,
                rng: &mut R,
            ) -> Result<Self::X, Error>
            where
                B1: SampleBorrow<Self::X> + Sized,
                B2: SampleBorrow<Self::X> + Sized,
            {
                let low = *low_b.borrow();
                let high = *high_b.borrow();
                if !(low <= high) {
                    return Err(Error::EmptyRange);
                }
                let range = high.wrapping_sub(low).wrapping_add(1) as $uty as $sample_ty;
                if range == 0 {
                    // Range is MAX+1 (unrepresentable), so we need a special case
                    return Ok(rng.random());
                }

                // generate a sample using a sensible integer type
                let (mut result, lo_order) = rng.random::<$sample_ty>().wmul(range);

                // if the sample is biased... (at most one extra RNG call is
                // made, which bounds — but does not eliminate — the bias)
                if lo_order > range.wrapping_neg() {
                    // ...generate a new sample to reduce bias...
                    let (new_hi_order, _) = (rng.random::<$sample_ty>()).wmul(range as $sample_ty);
                    // ... incrementing result on overflow
                    let is_overflow = lo_order.checked_add(new_hi_order as $sample_ty).is_none();
                    result += is_overflow as $sample_ty;
                }

                Ok(low.wrapping_add(result as $ty))
            }

            /// Sample single value, Canon's method, unbiased
            #[cfg(feature = "unbiased")]
            #[inline]
            fn sample_single_inclusive<R: Rng + ?Sized, B1, B2>(
                low_b: B1,
                high_b: B2,
                rng: &mut R,
            ) -> Result<Self::X, Error>
            where
                B1: SampleBorrow<$ty> + Sized,
                B2: SampleBorrow<$ty> + Sized,
            {
                let low = *low_b.borrow();
                let high = *high_b.borrow();
                if !(low <= high) {
                    return Err(Error::EmptyRange);
                }
                let range = high.wrapping_sub(low).wrapping_add(1) as $uty as $sample_ty;
                if range == 0 {
                    // Range is MAX+1 (unrepresentable), so we need a special case
                    return Ok(rng.random());
                }

                let (mut result, mut lo) = rng.random::<$sample_ty>().wmul(range);

                // In contrast to the biased sampler, we use a loop: each
                // iteration resolves the carry except in the (unlikely)
                // ambiguous case `lo + new_hi == MAX`, which needs another
                // sample. The loop terminates with probability 1.
                while lo > range.wrapping_neg() {
                    let (new_hi, new_lo) = (rng.random::<$sample_ty>()).wmul(range);
                    match lo.checked_add(new_hi) {
                        Some(x) if x < $sample_ty::MAX => {
                            // Anything less than MAX: last term is 0
                            break;
                        }
                        None => {
                            // Overflow: last term is 1
                            result += 1;
                            break;
                        }
                        _ => {
                            // Unlikely case: must check next sample
                            lo = new_lo;
                            continue;
                        }
                    }
                }

                Ok(low.wrapping_add(result as $ty))
            }
        }
    };
}
262 | |
// Instantiations: (output type, bit-equal unsigned type, RNG sample type).
// 8-, 16- and 32-bit outputs all sample via `u32`; see the "Implementation
// notes" on [`UniformInt`] for the rationale.
uniform_int_impl! { i8, u8, u32 }
uniform_int_impl! { i16, u16, u32 }
uniform_int_impl! { i32, u32, u32 }
uniform_int_impl! { i64, u64, u64 }
uniform_int_impl! { i128, u128, u128 }
uniform_int_impl! { u8, u8, u32 }
uniform_int_impl! { u16, u16, u32 }
uniform_int_impl! { u32, u32, u32 }
uniform_int_impl! { u64, u64, u64 }
uniform_int_impl! { u128, u128, u128 }
273 | |
#[cfg(feature = "simd_support")]
macro_rules! uniform_simd_int_impl {
    // $ty: the lane type of the output vector; $unsigned: its bit-equal
    // unsigned lane type, used for all internal arithmetic.
    ($ty:ident, $unsigned:ident) => {
        // The "pick the largest zone that can fit in an `u32`" optimization
        // is less useful here. Multiple lanes complicate things, we don't
        // know the PRNG's minimal output size, and casting to a larger vector
        // is generally a bad idea for SIMD performance. The user can still
        // implement it manually.

        #[cfg(feature = "simd_support")]
        impl<const LANES: usize> SampleUniform for Simd<$ty, LANES>
        where
            LaneCount<LANES>: SupportedLaneCount,
            Simd<$unsigned, LANES>:
                WideningMultiply<Output = (Simd<$unsigned, LANES>, Simd<$unsigned, LANES>)>,
            StandardUniform: Distribution<Simd<$unsigned, LANES>>,
        {
            type Sampler = UniformInt<Simd<$ty, LANES>>;
        }

        #[cfg(feature = "simd_support")]
        impl<const LANES: usize> UniformSampler for UniformInt<Simd<$ty, LANES>>
        where
            LaneCount<LANES>: SupportedLaneCount,
            Simd<$unsigned, LANES>:
                WideningMultiply<Output = (Simd<$unsigned, LANES>, Simd<$unsigned, LANES>)>,
            StandardUniform: Distribution<Simd<$unsigned, LANES>>,
        {
            type X = Simd<$ty, LANES>;

            #[inline] // if the range is constant, this helps LLVM to do the
                      // calculations at compile-time.
            fn new<B1, B2>(low_b: B1, high_b: B2) -> Result<Self, Error>
                where B1: SampleBorrow<Self::X> + Sized,
                      B2: SampleBorrow<Self::X> + Sized
            {
                let low = *low_b.borrow();
                let high = *high_b.borrow();
                // Every lane must satisfy `low < high`.
                if !(low.simd_lt(high).all()) {
                    return Err(Error::EmptyRange);
                }
                UniformSampler::new_inclusive(low, high - Simd::splat(1))
            }

            #[inline] // if the range is constant, this helps LLVM to do the
                      // calculations at compile-time.
            fn new_inclusive<B1, B2>(low_b: B1, high_b: B2) -> Result<Self, Error>
                where B1: SampleBorrow<Self::X> + Sized,
                      B2: SampleBorrow<Self::X> + Sized
            {
                let low = *low_b.borrow();
                let high = *high_b.borrow();
                // Every lane must satisfy `low <= high`.
                if !(low.simd_le(high).all()) {
                    return Err(Error::EmptyRange);
                }

                // NOTE: all `Simd` operations are inherently wrapping,
                // see https://doc.rust-lang.org/std/simd/struct.Simd.html
                let range: Simd<$unsigned, LANES> = ((high - low) + Simd::splat(1)).cast();

                // We must avoid divide-by-zero by using 0 % 1 == 0.
                // (A lane's range of 0 encodes the full lane-type range.)
                let not_full_range = range.simd_gt(Simd::splat(0));
                let modulo = not_full_range.select(range, Simd::splat(1));
                let ints_to_reject = range.wrapping_neg() % modulo;

                Ok(UniformInt {
                    low,
                    // These are really $unsigned values, but store as $ty:
                    range: range.cast(),
                    thresh: ints_to_reject.cast(),
                })
            }

            /// Sample a vector, rejecting (and re-drawing) only the lanes
            /// whose widening-multiply low word falls below the threshold.
            fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> Self::X {
                let range: Simd<$unsigned, LANES> = self.range.cast();
                let thresh: Simd<$unsigned, LANES> = self.thresh.cast();

                // This might seem very slow, generating a whole new
                // SIMD vector for every sample rejection. For most uses
                // though, the chance of rejection is small and provides good
                // general performance. With multiple lanes, that chance is
                // multiplied. To mitigate this, we replace only the lanes of
                // the vector which fail, iteratively reducing the chance of
                // rejection. The replacement method does however add a little
                // overhead. Benchmarking or calculating probabilities might
                // reveal contexts where this replacement method is slower.
                let mut v: Simd<$unsigned, LANES> = rng.random();
                loop {
                    let (hi, lo) = v.wmul(range);
                    let mask = lo.simd_ge(thresh);
                    if mask.all() {
                        let hi: Simd<$ty, LANES> = hi.cast();
                        // wrapping addition
                        let result = self.low + hi;
                        // `select` here compiles to a blend operation
                        // When `range.eq(0).none()` the compare and blend
                        // operations are avoided.
                        // Full-range lanes (range == 0) pass the raw sample
                        // `v` through unchanged.
                        let v: Simd<$ty, LANES> = v.cast();
                        return range.simd_gt(Simd::splat(0)).select(result, v);
                    }
                    // Replace only the failing lanes
                    v = mask.select(v, rng.random());
                }
            }
        }
    };

    // bulk implementation: each (unsigned, signed) pair instantiates both
    // samplers, with arithmetic done in the unsigned lane type.
    ($(($unsigned:ident, $signed:ident)),+) => {
        $(
            uniform_simd_int_impl!($unsigned, $unsigned);
            uniform_simd_int_impl!($signed, $unsigned);
        )+
    };
}
389 | |
// Instantiate the SIMD samplers for each (unsigned, signed) lane-type pair.
#[cfg(feature = "simd_support")]
uniform_simd_int_impl! { (u8, i8), (u16, i16), (u32, i32), (u64, i64) }
392 | |
/// The back-end implementing [`UniformSampler`] for `usize`.
///
/// # Implementation notes
///
/// Sampling a `usize` value is usually used in relation to the length of an
/// array or other memory structure, thus it is reasonable to assume that the
/// vast majority of use-cases will have a maximum size under [`u32::MAX`].
/// In part to optimise for this use-case, but mostly to ensure that results
/// are portable across 32-bit and 64-bit architectures (as far as is possible),
/// this implementation will use 32-bit sampling when possible.
#[cfg(any(target_pointer_width = "32", target_pointer_width = "64"))]
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct UniformUsize {
    /// Inclusive lower bound of the range
    low: usize,
    /// Number of values in the range; 0 encodes the full-range case
    range: usize,
    /// Rejection threshold: `2^32 % range` in 32-bit mode, `2^64 % range`
    /// otherwise (see `new_inclusive`)
    thresh: usize,
    /// True when `high > u32::MAX`, forcing 64-bit sampling
    #[cfg(target_pointer_width = "64")]
    mode64: bool,
}
412 | |
#[cfg(any(target_pointer_width = "32", target_pointer_width = "64"))]
impl SampleUniform for usize {
    // `usize` uses a dedicated sampler (not the generic `UniformInt`) so that
    // results are portable across 32- and 64-bit targets where possible.
    type Sampler = UniformUsize;
}
417 | |
#[cfg(any(target_pointer_width = "32", target_pointer_width = "64"))]
impl UniformSampler for UniformUsize {
    type X = usize;

    #[inline] // if the range is constant, this helps LLVM to do the
              // calculations at compile-time.
    fn new<B1, B2>(low_b: B1, high_b: B2) -> Result<Self, Error>
    where
        B1: SampleBorrow<Self::X> + Sized,
        B2: SampleBorrow<Self::X> + Sized,
    {
        let low = *low_b.borrow();
        let high = *high_b.borrow();
        if !(low < high) {
            return Err(Error::EmptyRange);
        }

        // [low, high) == [low, high - 1]; `high - 1` cannot underflow
        // because `low < high` implies `high >= 1`.
        UniformSampler::new_inclusive(low, high - 1)
    }

    #[inline] // if the range is constant, this helps LLVM to do the
              // calculations at compile-time.
    fn new_inclusive<B1, B2>(low_b: B1, high_b: B2) -> Result<Self, Error>
    where
        B1: SampleBorrow<Self::X> + Sized,
        B2: SampleBorrow<Self::X> + Sized,
    {
        let low = *low_b.borrow();
        let high = *high_b.borrow();
        if !(low <= high) {
            return Err(Error::EmptyRange);
        }

        // On 64-bit targets, use 64-bit arithmetic only when the bounds do
        // not fit in `u32`; otherwise sample 32 bits so results match 32-bit
        // targets (see type-level "Implementation notes").
        #[cfg(target_pointer_width = "64")]
        let mode64 = high > (u32::MAX as usize);
        #[cfg(target_pointer_width = "32")]
        let mode64 = false;

        let (range, thresh);
        if cfg!(target_pointer_width = "64") && !mode64 {
            // 32-bit mode: compute range and rejection threshold in `u32`.
            let range32 = (high as u32).wrapping_sub(low as u32).wrapping_add(1);
            range = range32 as usize;
            thresh = if range32 > 0 {
                // `2^32 % range`: count of low residues to reject.
                (range32.wrapping_neg() % range32) as usize
            } else {
                0 // range32 == 0 encodes the full `u32` range
            };
        } else {
            // Native-width mode (32-bit targets, or 64-bit-only ranges).
            range = high.wrapping_sub(low).wrapping_add(1);
            thresh = if range > 0 {
                range.wrapping_neg() % range
            } else {
                0 // range == 0 encodes the full range
            };
        }

        Ok(UniformUsize {
            low,
            range,
            thresh,
            #[cfg(target_pointer_width = "64")]
            mode64,
        })
    }

    /// Sample via Lemire's widening-multiply rejection method, using 32- or
    /// 64-bit arithmetic as decided at construction time.
    #[inline]
    fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> usize {
        #[cfg(target_pointer_width = "32")]
        let mode32 = true;
        #[cfg(target_pointer_width = "64")]
        let mode32 = !self.mode64;

        if mode32 {
            let range = self.range as u32;
            if range == 0 {
                // Full 32-bit range: every raw sample is acceptable.
                return rng.random::<u32>() as usize;
            }

            // Reject samples whose low multiply word falls below `thresh`.
            let thresh = self.thresh as u32;
            let hi = loop {
                let (hi, lo) = rng.random::<u32>().wmul(range);
                if lo >= thresh {
                    break hi;
                }
            };
            self.low.wrapping_add(hi as usize)
        } else {
            let range = self.range as u64;
            if range == 0 {
                // Full 64-bit range: every raw sample is acceptable.
                return rng.random::<u64>() as usize;
            }

            let thresh = self.thresh as u64;
            let hi = loop {
                let (hi, lo) = rng.random::<u64>().wmul(range);
                if lo >= thresh {
                    break hi;
                }
            };
            self.low.wrapping_add(hi as usize)
        }
    }

    /// Single-shot sampling: delegate to `UniformInt::<u64>` when the bounds
    /// exceed `u32` on 64-bit targets, otherwise to `UniformInt::<u32>` for
    /// cross-platform consistency.
    #[inline]
    fn sample_single<R: Rng + ?Sized, B1, B2>(
        low_b: B1,
        high_b: B2,
        rng: &mut R,
    ) -> Result<Self::X, Error>
    where
        B1: SampleBorrow<Self::X> + Sized,
        B2: SampleBorrow<Self::X> + Sized,
    {
        let low = *low_b.borrow();
        let high = *high_b.borrow();
        if !(low < high) {
            return Err(Error::EmptyRange);
        }

        if cfg!(target_pointer_width = "64") && high > (u32::MAX as usize) {
            return UniformInt::<u64>::sample_single(low as u64, high as u64, rng)
                .map(|x| x as usize);
        }

        UniformInt::<u32>::sample_single(low as u32, high as u32, rng).map(|x| x as usize)
    }

    /// Inclusive single-shot sampling; same delegation strategy as
    /// `sample_single`.
    #[inline]
    fn sample_single_inclusive<R: Rng + ?Sized, B1, B2>(
        low_b: B1,
        high_b: B2,
        rng: &mut R,
    ) -> Result<Self::X, Error>
    where
        B1: SampleBorrow<Self::X> + Sized,
        B2: SampleBorrow<Self::X> + Sized,
    {
        let low = *low_b.borrow();
        let high = *high_b.borrow();
        if !(low <= high) {
            return Err(Error::EmptyRange);
        }

        if cfg!(target_pointer_width = "64") && high > (u32::MAX as usize) {
            return UniformInt::<u64>::sample_single_inclusive(low as u64, high as u64, rng)
                .map(|x| x as usize);
        }

        UniformInt::<u32>::sample_single_inclusive(low as u32, high as u32, rng).map(|x| x as usize)
    }
}
569 | |
#[cfg(test)]
mod tests {
    use super::*;
    use crate::distr::{Distribution, Uniform};
    use core::fmt::Debug;
    use core::ops::Add;

    /// `new` with `low == high` describes an empty half-open range.
    #[test]
    fn test_uniform_bad_limits_equal_int() {
        assert_eq!(Uniform::new(10, 10), Err(Error::EmptyRange));
    }

    /// `new_inclusive` with `low == high` is a valid single-value range.
    #[test]
    fn test_uniform_good_limits_equal_int() {
        let mut rng = crate::test::rng(804);
        let dist = Uniform::new_inclusive(10, 10).unwrap();
        for _ in 0..20 {
            assert_eq!(rng.sample(dist), 10);
        }
    }

    /// Reversed bounds must yield an error, not panic or wrap.
    #[test]
    fn test_uniform_bad_limits_flipped_int() {
        assert_eq!(Uniform::new(10, 5), Err(Error::EmptyRange));
    }

    /// Range check: every sampling path (distribution object, by-reference
    /// bounds, single-shot) must stay within the requested bounds, for all
    /// scalar integer types and (with `simd_support`) all SIMD vector types.
    #[test]
    #[cfg_attr(miri, ignore)] // Miri is too slow
    fn test_integers() {
        let mut rng = crate::test::rng(251);
        macro_rules! t {
            // Core arm: for each (low, high) pair, exercise `new`,
            // `new_inclusive`, borrowed-bound variants, `sample_single` and
            // `sample_single_inclusive`, checking bounds with the supplied
            // `$le`/`$lt` comparators.
            ($ty:ident, $v:expr, $le:expr, $lt:expr) => {{
                for &(low, high) in $v.iter() {
                    let my_uniform = Uniform::new(low, high).unwrap();
                    for _ in 0..1000 {
                        let v: $ty = rng.sample(my_uniform);
                        assert!($le(low, v) && $lt(v, high));
                    }

                    let my_uniform = Uniform::new_inclusive(low, high).unwrap();
                    for _ in 0..1000 {
                        let v: $ty = rng.sample(my_uniform);
                        assert!($le(low, v) && $le(v, high));
                    }

                    let my_uniform = Uniform::new(&low, high).unwrap();
                    for _ in 0..1000 {
                        let v: $ty = rng.sample(my_uniform);
                        assert!($le(low, v) && $lt(v, high));
                    }

                    let my_uniform = Uniform::new_inclusive(&low, &high).unwrap();
                    for _ in 0..1000 {
                        let v: $ty = rng.sample(my_uniform);
                        assert!($le(low, v) && $le(v, high));
                    }

                    for _ in 0..1000 {
                        let v = <$ty as SampleUniform>::Sampler::sample_single(low, high, &mut rng).unwrap();
                        assert!($le(low, v) && $lt(v, high));
                    }

                    for _ in 0..1000 {
                        let v = <$ty as SampleUniform>::Sampler::sample_single_inclusive(low, high, &mut rng).unwrap();
                        assert!($le(low, v) && $le(v, high));
                    }
                }
            }};

            // scalar bulk
            ($($ty:ident),*) => {{
                $(t!(
                    $ty,
                    [(0, 10), (10, 127), ($ty::MIN, $ty::MAX)],
                    |x, y| x <= y,
                    |x, y| x < y
                );)*
            }};

            // simd bulk
            ($($ty:ident),* => $scalar:ident) => {{
                $(t!(
                    $ty,
                    [
                        ($ty::splat(0), $ty::splat(10)),
                        ($ty::splat(10), $ty::splat(127)),
                        ($ty::splat($scalar::MIN), $ty::splat($scalar::MAX)),
                    ],
                    |x: $ty, y| x.simd_le(y).all(),
                    |x: $ty, y| x.simd_lt(y).all()
                );)*
            }};
        }
        t!(i8, i16, i32, i64, i128, u8, u16, u32, u64, usize, u128);

        #[cfg(feature = "simd_support")]
        {
            t!(u8x4, u8x8, u8x16, u8x32, u8x64 => u8);
            t!(i8x4, i8x8, i8x16, i8x32, i8x64 => i8);
            t!(u16x2, u16x4, u16x8, u16x16, u16x32 => u16);
            t!(i16x2, i16x4, i16x8, i16x16, i16x32 => i16);
            t!(u32x2, u32x4, u32x8, u32x16 => u32);
            t!(i32x2, i32x4, i32x8, i32x16 => i32);
            t!(u64x2, u64x4, u64x8 => u64);
            t!(i64x2, i64x4, i64x8 => i64);
        }
    }

    /// `Uniform` can be built from a half-open `std::ops::Range`.
    #[test]
    fn test_uniform_from_std_range() {
        let r = Uniform::try_from(2u32..7).unwrap();
        assert_eq!(r.0.low, 2);
        assert_eq!(r.0.range, 5);
    }

    /// Empty/reversed `Range` conversions must fail.
    #[test]
    fn test_uniform_from_std_range_bad_limits() {
        #![allow(clippy::reversed_empty_ranges)]
        assert!(Uniform::try_from(100..10).is_err());
        assert!(Uniform::try_from(100..100).is_err());
    }

    /// `Uniform` can be built from an inclusive `std::ops::RangeInclusive`.
    #[test]
    fn test_uniform_from_std_range_inclusive() {
        let r = Uniform::try_from(2u32..=6).unwrap();
        assert_eq!(r.0.low, 2);
        assert_eq!(r.0.range, 5);
    }

    /// Reversed `RangeInclusive` conversions must fail.
    #[test]
    fn test_uniform_from_std_range_inclusive_bad_limits() {
        #![allow(clippy::reversed_empty_ranges)]
        assert!(Uniform::try_from(100..=10).is_err());
        assert!(Uniform::try_from(100..=99).is_err());
    }

    /// Pin the exact output sequences of the samplers for a fixed RNG seed.
    /// These expected values must NOT change without a breaking release:
    /// they guard the value-stability of `sample`, `sample_single` and
    /// `sample_single_inclusive` across platforms and refactors.
    #[test]
    fn value_stability() {
        // Runs each sampler path from the same seed twice (inclusive and
        // exclusive bounds) and compares against the pinned `expected` slice.
        fn test_samples<T: SampleUniform + Copy + Debug + PartialEq + Add<T>>(
            lb: T,
            ub: T,
            ub_excl: T,
            expected: &[T],
        ) where
            Uniform<T>: Distribution<T>,
        {
            let mut rng = crate::test::rng(897);
            let mut buf = [lb; 6];

            for x in &mut buf[0..3] {
                *x = T::Sampler::sample_single_inclusive(lb, ub, &mut rng).unwrap();
            }

            let distr = Uniform::new_inclusive(lb, ub).unwrap();
            for x in &mut buf[3..6] {
                *x = rng.sample(&distr);
            }
            assert_eq!(&buf, expected);

            let mut rng = crate::test::rng(897);

            for x in &mut buf[0..3] {
                *x = T::Sampler::sample_single(lb, ub_excl, &mut rng).unwrap();
            }

            let distr = Uniform::new(lb, ub_excl).unwrap();
            for x in &mut buf[3..6] {
                *x = rng.sample(&distr);
            }
            assert_eq!(&buf, expected);
        }

        test_samples(-105i8, 111, 112, &[-99, -48, 107, 72, -19, 56]);
        test_samples(2i16, 1352, 1353, &[43, 361, 1325, 1109, 539, 1005]);
        test_samples(
            -313853i32,
            13513,
            13514,
            &[-303803, -226673, 6912, -45605, -183505, -70668],
        );
        test_samples(
            131521i64,
            6542165,
            6542166,
            &[1838724, 5384489, 4893692, 3712948, 3951509, 4094926],
        );
        test_samples(
            -0x8000_0000_0000_0000_0000_0000_0000_0000i128,
            -1,
            0,
            &[
                -30725222750250982319765550926688025855,
                -75088619368053423329503924805178012357,
                -64950748766625548510467638647674468829,
                -41794017901603587121582892414659436495,
                -63623852319608406524605295913876414006,
                -17404679390297612013597359206379189023,
            ],
        );
        test_samples(11u8, 218, 219, &[17, 66, 214, 181, 93, 165]);
        test_samples(11u16, 218, 219, &[17, 66, 214, 181, 93, 165]);
        test_samples(11u32, 218, 219, &[17, 66, 214, 181, 93, 165]);
        test_samples(11u64, 218, 219, &[66, 181, 165, 127, 134, 139]);
        test_samples(11u128, 218, 219, &[181, 127, 139, 167, 141, 197]);
        // usize must match the u32 sequence above (portable 32-bit sampling).
        test_samples(11usize, 218, 219, &[17, 66, 214, 181, 93, 165]);

        #[cfg(feature = "simd_support")]
        {
            let lb = Simd::from([11u8, 0, 128, 127]);
            let ub = Simd::from([218, 254, 254, 254]);
            let ub_excl = ub + Simd::splat(1);
            test_samples(
                lb,
                ub,
                ub_excl,
                &[
                    Simd::from([13, 5, 237, 130]),
                    Simd::from([126, 186, 149, 161]),
                    Simd::from([103, 86, 234, 252]),
                    Simd::from([35, 18, 225, 231]),
                    Simd::from([106, 153, 246, 177]),
                    Simd::from([195, 168, 149, 222]),
                ],
            );
        }
    }
}
797 | |