1 | // Copyright 2018 Developers of the Rand project. |
2 | // |
3 | // Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or |
4 | // https://www.apache.org/licenses/LICENSE-2.0> or the MIT license |
5 | // <LICENSE-MIT or https://opensource.org/licenses/MIT>, at your |
6 | // option. This file may not be copied, modified, or distributed |
7 | // except according to those terms. |
8 | |
9 | //! Sequence-related functionality |
10 | //! |
11 | //! This module provides: |
12 | //! |
13 | //! * [`SliceRandom`] slice sampling and mutation |
14 | //! * [`IteratorRandom`] iterator sampling |
15 | //! * [`index::sample`] low-level API to choose multiple indices from |
16 | //! `0..length` |
17 | //! |
18 | //! Also see: |
19 | //! |
20 | //! * [`crate::distributions::WeightedIndex`] distribution which provides |
21 | //! weighted index sampling. |
22 | //! |
23 | //! In order to make results reproducible across 32-64 bit architectures, all |
24 | //! `usize` indices are sampled as a `u32` where possible (also providing a |
25 | //! small performance boost in some cases). |
26 | |
27 | |
28 | #[cfg (feature = "alloc" )] |
29 | #[cfg_attr (doc_cfg, doc(cfg(feature = "alloc" )))] |
30 | pub mod index; |
31 | |
32 | #[cfg (feature = "alloc" )] use core::ops::Index; |
33 | |
34 | #[cfg (feature = "alloc" )] use alloc::vec::Vec; |
35 | |
36 | #[cfg (feature = "alloc" )] |
37 | use crate::distributions::uniform::{SampleBorrow, SampleUniform}; |
38 | #[cfg (feature = "alloc" )] use crate::distributions::WeightedError; |
39 | use crate::Rng; |
40 | |
41 | /// Extension trait on slices, providing random mutation and sampling methods. |
42 | /// |
43 | /// This trait is implemented on all `[T]` slice types, providing several |
44 | /// methods for choosing and shuffling elements. You must `use` this trait: |
45 | /// |
46 | /// ``` |
47 | /// use rand::seq::SliceRandom; |
48 | /// |
49 | /// let mut rng = rand::thread_rng(); |
50 | /// let mut bytes = "Hello, random!" .to_string().into_bytes(); |
51 | /// bytes.shuffle(&mut rng); |
52 | /// let str = String::from_utf8(bytes).unwrap(); |
53 | /// println!("{}" , str); |
54 | /// ``` |
55 | /// Example output (non-deterministic): |
56 | /// ```none |
57 | /// l,nmroHado !le |
58 | /// ``` |
59 | pub trait SliceRandom { |
60 | /// The element type. |
61 | type Item; |
62 | |
63 | /// Returns a reference to one random element of the slice, or `None` if the |
64 | /// slice is empty. |
65 | /// |
66 | /// For slices, complexity is `O(1)`. |
67 | /// |
68 | /// # Example |
69 | /// |
70 | /// ``` |
71 | /// use rand::thread_rng; |
72 | /// use rand::seq::SliceRandom; |
73 | /// |
74 | /// let choices = [1, 2, 4, 8, 16, 32]; |
75 | /// let mut rng = thread_rng(); |
76 | /// println!("{:?}" , choices.choose(&mut rng)); |
77 | /// assert_eq!(choices[..0].choose(&mut rng), None); |
78 | /// ``` |
79 | fn choose<R>(&self, rng: &mut R) -> Option<&Self::Item> |
80 | where R: Rng + ?Sized; |
81 | |
82 | /// Returns a mutable reference to one random element of the slice, or |
83 | /// `None` if the slice is empty. |
84 | /// |
85 | /// For slices, complexity is `O(1)`. |
86 | fn choose_mut<R>(&mut self, rng: &mut R) -> Option<&mut Self::Item> |
87 | where R: Rng + ?Sized; |
88 | |
89 | /// Chooses `amount` elements from the slice at random, without repetition, |
90 | /// and in random order. The returned iterator is appropriate both for |
91 | /// collection into a `Vec` and filling an existing buffer (see example). |
92 | /// |
93 | /// In case this API is not sufficiently flexible, use [`index::sample`]. |
94 | /// |
95 | /// For slices, complexity is the same as [`index::sample`]. |
96 | /// |
97 | /// # Example |
98 | /// ``` |
99 | /// use rand::seq::SliceRandom; |
100 | /// |
101 | /// let mut rng = &mut rand::thread_rng(); |
102 | /// let sample = "Hello, audience!" .as_bytes(); |
103 | /// |
104 | /// // collect the results into a vector: |
105 | /// let v: Vec<u8> = sample.choose_multiple(&mut rng, 3).cloned().collect(); |
106 | /// |
107 | /// // store in a buffer: |
108 | /// let mut buf = [0u8; 5]; |
109 | /// for (b, slot) in sample.choose_multiple(&mut rng, buf.len()).zip(buf.iter_mut()) { |
110 | /// *slot = *b; |
111 | /// } |
112 | /// ``` |
113 | #[cfg (feature = "alloc" )] |
114 | #[cfg_attr (doc_cfg, doc(cfg(feature = "alloc" )))] |
115 | fn choose_multiple<R>(&self, rng: &mut R, amount: usize) -> SliceChooseIter<Self, Self::Item> |
116 | where R: Rng + ?Sized; |
117 | |
118 | /// Similar to [`choose`], but where the likelihood of each outcome may be |
119 | /// specified. |
120 | /// |
121 | /// The specified function `weight` maps each item `x` to a relative |
122 | /// likelihood `weight(x)`. The probability of each item being selected is |
123 | /// therefore `weight(x) / s`, where `s` is the sum of all `weight(x)`. |
124 | /// |
125 | /// For slices of length `n`, complexity is `O(n)`. |
126 | /// See also [`choose_weighted_mut`], [`distributions::weighted`]. |
127 | /// |
128 | /// # Example |
129 | /// |
130 | /// ``` |
131 | /// use rand::prelude::*; |
132 | /// |
133 | /// let choices = [('a' , 2), ('b' , 1), ('c' , 1)]; |
134 | /// let mut rng = thread_rng(); |
135 | /// // 50% chance to print 'a', 25% chance to print 'b', 25% chance to print 'c' |
136 | /// println!("{:?}" , choices.choose_weighted(&mut rng, |item| item.1).unwrap().0); |
137 | /// ``` |
138 | /// [`choose`]: SliceRandom::choose |
139 | /// [`choose_weighted_mut`]: SliceRandom::choose_weighted_mut |
140 | /// [`distributions::weighted`]: crate::distributions::weighted |
141 | #[cfg (feature = "alloc" )] |
142 | #[cfg_attr (doc_cfg, doc(cfg(feature = "alloc" )))] |
143 | fn choose_weighted<R, F, B, X>( |
144 | &self, rng: &mut R, weight: F, |
145 | ) -> Result<&Self::Item, WeightedError> |
146 | where |
147 | R: Rng + ?Sized, |
148 | F: Fn(&Self::Item) -> B, |
149 | B: SampleBorrow<X>, |
150 | X: SampleUniform |
151 | + for<'a> ::core::ops::AddAssign<&'a X> |
152 | + ::core::cmp::PartialOrd<X> |
153 | + Clone |
154 | + Default; |
155 | |
156 | /// Similar to [`choose_mut`], but where the likelihood of each outcome may |
157 | /// be specified. |
158 | /// |
159 | /// The specified function `weight` maps each item `x` to a relative |
160 | /// likelihood `weight(x)`. The probability of each item being selected is |
161 | /// therefore `weight(x) / s`, where `s` is the sum of all `weight(x)`. |
162 | /// |
163 | /// For slices of length `n`, complexity is `O(n)`. |
164 | /// See also [`choose_weighted`], [`distributions::weighted`]. |
165 | /// |
166 | /// [`choose_mut`]: SliceRandom::choose_mut |
167 | /// [`choose_weighted`]: SliceRandom::choose_weighted |
168 | /// [`distributions::weighted`]: crate::distributions::weighted |
169 | #[cfg (feature = "alloc" )] |
170 | #[cfg_attr (doc_cfg, doc(cfg(feature = "alloc" )))] |
171 | fn choose_weighted_mut<R, F, B, X>( |
172 | &mut self, rng: &mut R, weight: F, |
173 | ) -> Result<&mut Self::Item, WeightedError> |
174 | where |
175 | R: Rng + ?Sized, |
176 | F: Fn(&Self::Item) -> B, |
177 | B: SampleBorrow<X>, |
178 | X: SampleUniform |
179 | + for<'a> ::core::ops::AddAssign<&'a X> |
180 | + ::core::cmp::PartialOrd<X> |
181 | + Clone |
182 | + Default; |
183 | |
184 | /// Similar to [`choose_multiple`], but where the likelihood of each element's |
185 | /// inclusion in the output may be specified. The elements are returned in an |
186 | /// arbitrary, unspecified order. |
187 | /// |
188 | /// The specified function `weight` maps each item `x` to a relative |
189 | /// likelihood `weight(x)`. The probability of each item being selected is |
190 | /// therefore `weight(x) / s`, where `s` is the sum of all `weight(x)`. |
191 | /// |
192 | /// If all of the weights are equal, even if they are all zero, each element has |
193 | /// an equal likelihood of being selected. |
194 | /// |
195 | /// The complexity of this method depends on the feature `partition_at_index`. |
196 | /// If the feature is enabled, then for slices of length `n`, the complexity |
197 | /// is `O(n)` space and `O(n)` time. Otherwise, the complexity is `O(n)` space and |
198 | /// `O(n * log amount)` time. |
199 | /// |
200 | /// # Example |
201 | /// |
202 | /// ``` |
203 | /// use rand::prelude::*; |
204 | /// |
205 | /// let choices = [('a' , 2), ('b' , 1), ('c' , 1)]; |
206 | /// let mut rng = thread_rng(); |
207 | /// // First Draw * Second Draw = total odds |
208 | /// // ----------------------- |
209 | /// // (50% * 50%) + (25% * 67%) = 41.7% chance that the output is `['a', 'b']` in some order. |
210 | /// // (50% * 50%) + (25% * 67%) = 41.7% chance that the output is `['a', 'c']` in some order. |
211 | /// // (25% * 33%) + (25% * 33%) = 16.6% chance that the output is `['b', 'c']` in some order. |
212 | /// println!("{:?}" , choices.choose_multiple_weighted(&mut rng, 2, |item| item.1).unwrap().collect::<Vec<_>>()); |
213 | /// ``` |
214 | /// [`choose_multiple`]: SliceRandom::choose_multiple |
215 | // |
216 | // Note: this is feature-gated on std due to usage of f64::powf. |
217 | // If necessary, we may use alloc+libm as an alternative (see PR #1089). |
218 | #[cfg (feature = "std" )] |
219 | #[cfg_attr (doc_cfg, doc(cfg(feature = "std" )))] |
220 | fn choose_multiple_weighted<R, F, X>( |
221 | &self, rng: &mut R, amount: usize, weight: F, |
222 | ) -> Result<SliceChooseIter<Self, Self::Item>, WeightedError> |
223 | where |
224 | R: Rng + ?Sized, |
225 | F: Fn(&Self::Item) -> X, |
226 | X: Into<f64>; |
227 | |
228 | /// Shuffle a mutable slice in place. |
229 | /// |
230 | /// For slices of length `n`, complexity is `O(n)`. |
231 | /// |
232 | /// # Example |
233 | /// |
234 | /// ``` |
235 | /// use rand::seq::SliceRandom; |
236 | /// use rand::thread_rng; |
237 | /// |
238 | /// let mut rng = thread_rng(); |
239 | /// let mut y = [1, 2, 3, 4, 5]; |
240 | /// println!("Unshuffled: {:?}" , y); |
241 | /// y.shuffle(&mut rng); |
242 | /// println!("Shuffled: {:?}" , y); |
243 | /// ``` |
244 | fn shuffle<R>(&mut self, rng: &mut R) |
245 | where R: Rng + ?Sized; |
246 | |
247 | /// Shuffle a slice in place, but exit early. |
248 | /// |
249 | /// Returns two mutable slices from the source slice. The first contains |
250 | /// `amount` elements randomly permuted. The second has the remaining |
251 | /// elements that are not fully shuffled. |
252 | /// |
253 | /// This is an efficient method to select `amount` elements at random from |
254 | /// the slice, provided the slice may be mutated. |
255 | /// |
256 | /// If you only need to choose elements randomly and `amount > self.len()/2` |
257 | /// then you may improve performance by taking |
258 | /// `amount = values.len() - amount` and using only the second slice. |
259 | /// |
260 | /// If `amount` is greater than the number of elements in the slice, this |
261 | /// will perform a full shuffle. |
262 | /// |
263 | /// For slices, complexity is `O(m)` where `m = amount`. |
264 | fn partial_shuffle<R>( |
265 | &mut self, rng: &mut R, amount: usize, |
266 | ) -> (&mut [Self::Item], &mut [Self::Item]) |
267 | where R: Rng + ?Sized; |
268 | } |
269 | |
270 | /// Extension trait on iterators, providing random sampling methods. |
271 | /// |
272 | /// This trait is implemented on all iterators `I` where `I: Iterator + Sized` |
273 | /// and provides methods for |
274 | /// choosing one or more elements. You must `use` this trait: |
275 | /// |
276 | /// ``` |
277 | /// use rand::seq::IteratorRandom; |
278 | /// |
279 | /// let mut rng = rand::thread_rng(); |
280 | /// |
281 | /// let faces = "😀😎😐😕😠😢" ; |
282 | /// println!("I am {}!" , faces.chars().choose(&mut rng).unwrap()); |
283 | /// ``` |
284 | /// Example output (non-deterministic): |
285 | /// ```none |
286 | /// I am 😀! |
287 | /// ``` |
288 | pub trait IteratorRandom: Iterator + Sized { |
289 | /// Choose one element at random from the iterator. |
290 | /// |
291 | /// Returns `None` if and only if the iterator is empty. |
292 | /// |
293 | /// This method uses [`Iterator::size_hint`] for optimisation. With an |
294 | /// accurate hint and where [`Iterator::nth`] is a constant-time operation |
295 | /// this method can offer `O(1)` performance. Where no size hint is |
296 | /// available, complexity is `O(n)` where `n` is the iterator length. |
297 | /// Partial hints (where `lower > 0`) also improve performance. |
298 | /// |
299 | /// Note that the output values and the number of RNG samples used |
300 | /// depends on size hints. In particular, `Iterator` combinators that don't |
301 | /// change the values yielded but change the size hints may result in |
302 | /// `choose` returning different elements. If you want consistent results |
303 | /// and RNG usage consider using [`IteratorRandom::choose_stable`]. |
304 | fn choose<R>(mut self, rng: &mut R) -> Option<Self::Item> |
305 | where R: Rng + ?Sized { |
306 | let (mut lower, mut upper) = self.size_hint(); |
307 | let mut consumed = 0; |
308 | let mut result = None; |
309 | |
310 | // Handling for this condition outside the loop allows the optimizer to eliminate the loop |
311 | // when the Iterator is an ExactSizeIterator. This has a large performance impact on e.g. |
312 | // seq_iter_choose_from_1000. |
313 | if upper == Some(lower) { |
314 | return if lower == 0 { |
315 | None |
316 | } else { |
317 | self.nth(gen_index(rng, lower)) |
318 | }; |
319 | } |
320 | |
321 | // Continue until the iterator is exhausted |
322 | loop { |
323 | if lower > 1 { |
324 | let ix = gen_index(rng, lower + consumed); |
325 | let skip = if ix < lower { |
326 | result = self.nth(ix); |
327 | lower - (ix + 1) |
328 | } else { |
329 | lower |
330 | }; |
331 | if upper == Some(lower) { |
332 | return result; |
333 | } |
334 | consumed += lower; |
335 | if skip > 0 { |
336 | self.nth(skip - 1); |
337 | } |
338 | } else { |
339 | let elem = self.next(); |
340 | if elem.is_none() { |
341 | return result; |
342 | } |
343 | consumed += 1; |
344 | if gen_index(rng, consumed) == 0 { |
345 | result = elem; |
346 | } |
347 | } |
348 | |
349 | let hint = self.size_hint(); |
350 | lower = hint.0; |
351 | upper = hint.1; |
352 | } |
353 | } |
354 | |
355 | /// Choose one element at random from the iterator. |
356 | /// |
357 | /// Returns `None` if and only if the iterator is empty. |
358 | /// |
359 | /// This method is very similar to [`choose`] except that the result |
360 | /// only depends on the length of the iterator and the values produced by |
361 | /// `rng`. Notably for any iterator of a given length this will make the |
362 | /// same requests to `rng` and if the same sequence of values are produced |
363 | /// the same index will be selected from `self`. This may be useful if you |
364 | /// need consistent results no matter what type of iterator you are working |
365 | /// with. If you do not need this stability prefer [`choose`]. |
366 | /// |
367 | /// Note that this method still uses [`Iterator::size_hint`] to skip |
368 | /// constructing elements where possible, however the selection and `rng` |
369 | /// calls are the same in the face of this optimization. If you want to |
370 | /// force every element to be created regardless call `.inspect(|e| ())`. |
371 | /// |
372 | /// [`choose`]: IteratorRandom::choose |
373 | fn choose_stable<R>(mut self, rng: &mut R) -> Option<Self::Item> |
374 | where R: Rng + ?Sized { |
375 | let mut consumed = 0; |
376 | let mut result = None; |
377 | |
378 | loop { |
379 | // Currently the only way to skip elements is `nth()`. So we need to |
380 | // store what index to access next here. |
381 | // This should be replaced by `advance_by()` once it is stable: |
382 | // https://github.com/rust-lang/rust/issues/77404 |
383 | let mut next = 0; |
384 | |
385 | let (lower, _) = self.size_hint(); |
386 | if lower >= 2 { |
387 | let highest_selected = (0..lower) |
388 | .filter(|ix| gen_index(rng, consumed+ix+1) == 0) |
389 | .last(); |
390 | |
391 | consumed += lower; |
392 | next = lower; |
393 | |
394 | if let Some(ix) = highest_selected { |
395 | result = self.nth(ix); |
396 | next -= ix + 1; |
397 | debug_assert!(result.is_some(), "iterator shorter than size_hint().0" ); |
398 | } |
399 | } |
400 | |
401 | let elem = self.nth(next); |
402 | if elem.is_none() { |
403 | return result |
404 | } |
405 | |
406 | if gen_index(rng, consumed+1) == 0 { |
407 | result = elem; |
408 | } |
409 | consumed += 1; |
410 | } |
411 | } |
412 | |
413 | /// Collects values at random from the iterator into a supplied buffer |
414 | /// until that buffer is filled. |
415 | /// |
416 | /// Although the elements are selected randomly, the order of elements in |
417 | /// the buffer is neither stable nor fully random. If random ordering is |
418 | /// desired, shuffle the result. |
419 | /// |
420 | /// Returns the number of elements added to the buffer. This equals the length |
421 | /// of the buffer unless the iterator contains insufficient elements, in which |
422 | /// case this equals the number of elements available. |
423 | /// |
424 | /// Complexity is `O(n)` where `n` is the length of the iterator. |
425 | /// For slices, prefer [`SliceRandom::choose_multiple`]. |
426 | fn choose_multiple_fill<R>(mut self, rng: &mut R, buf: &mut [Self::Item]) -> usize |
427 | where R: Rng + ?Sized { |
428 | let amount = buf.len(); |
429 | let mut len = 0; |
430 | while len < amount { |
431 | if let Some(elem) = self.next() { |
432 | buf[len] = elem; |
433 | len += 1; |
434 | } else { |
435 | // Iterator exhausted; stop early |
436 | return len; |
437 | } |
438 | } |
439 | |
440 | // Continue, since the iterator was not exhausted |
441 | for (i, elem) in self.enumerate() { |
442 | let k = gen_index(rng, i + 1 + amount); |
443 | if let Some(slot) = buf.get_mut(k) { |
444 | *slot = elem; |
445 | } |
446 | } |
447 | len |
448 | } |
449 | |
450 | /// Collects `amount` values at random from the iterator into a vector. |
451 | /// |
452 | /// This is equivalent to `choose_multiple_fill` except for the result type. |
453 | /// |
454 | /// Although the elements are selected randomly, the order of elements in |
455 | /// the buffer is neither stable nor fully random. If random ordering is |
456 | /// desired, shuffle the result. |
457 | /// |
458 | /// The length of the returned vector equals `amount` unless the iterator |
459 | /// contains insufficient elements, in which case it equals the number of |
460 | /// elements available. |
461 | /// |
462 | /// Complexity is `O(n)` where `n` is the length of the iterator. |
463 | /// For slices, prefer [`SliceRandom::choose_multiple`]. |
464 | #[cfg (feature = "alloc" )] |
465 | #[cfg_attr (doc_cfg, doc(cfg(feature = "alloc" )))] |
466 | fn choose_multiple<R>(mut self, rng: &mut R, amount: usize) -> Vec<Self::Item> |
467 | where R: Rng + ?Sized { |
468 | let mut reservoir = Vec::with_capacity(amount); |
469 | reservoir.extend(self.by_ref().take(amount)); |
470 | |
471 | // Continue unless the iterator was exhausted |
472 | // |
473 | // note: this prevents iterators that "restart" from causing problems. |
474 | // If the iterator stops once, then so do we. |
475 | if reservoir.len() == amount { |
476 | for (i, elem) in self.enumerate() { |
477 | let k = gen_index(rng, i + 1 + amount); |
478 | if let Some(slot) = reservoir.get_mut(k) { |
479 | *slot = elem; |
480 | } |
481 | } |
482 | } else { |
483 | // Don't hang onto extra memory. There is a corner case where |
484 | // `amount` was much less than `self.len()`. |
485 | reservoir.shrink_to_fit(); |
486 | } |
487 | reservoir |
488 | } |
489 | } |
490 | |
491 | |
492 | impl<T> SliceRandom for [T] { |
493 | type Item = T; |
494 | |
495 | fn choose<R>(&self, rng: &mut R) -> Option<&Self::Item> |
496 | where R: Rng + ?Sized { |
497 | if self.is_empty() { |
498 | None |
499 | } else { |
500 | Some(&self[gen_index(rng, self.len())]) |
501 | } |
502 | } |
503 | |
504 | fn choose_mut<R>(&mut self, rng: &mut R) -> Option<&mut Self::Item> |
505 | where R: Rng + ?Sized { |
506 | if self.is_empty() { |
507 | None |
508 | } else { |
509 | let len = self.len(); |
510 | Some(&mut self[gen_index(rng, len)]) |
511 | } |
512 | } |
513 | |
514 | #[cfg (feature = "alloc" )] |
515 | fn choose_multiple<R>(&self, rng: &mut R, amount: usize) -> SliceChooseIter<Self, Self::Item> |
516 | where R: Rng + ?Sized { |
517 | let amount = ::core::cmp::min(amount, self.len()); |
518 | SliceChooseIter { |
519 | slice: self, |
520 | _phantom: Default::default(), |
521 | indices: index::sample(rng, self.len(), amount).into_iter(), |
522 | } |
523 | } |
524 | |
525 | #[cfg (feature = "alloc" )] |
526 | fn choose_weighted<R, F, B, X>( |
527 | &self, rng: &mut R, weight: F, |
528 | ) -> Result<&Self::Item, WeightedError> |
529 | where |
530 | R: Rng + ?Sized, |
531 | F: Fn(&Self::Item) -> B, |
532 | B: SampleBorrow<X>, |
533 | X: SampleUniform |
534 | + for<'a> ::core::ops::AddAssign<&'a X> |
535 | + ::core::cmp::PartialOrd<X> |
536 | + Clone |
537 | + Default, |
538 | { |
539 | use crate::distributions::{Distribution, WeightedIndex}; |
540 | let distr = WeightedIndex::new(self.iter().map(weight))?; |
541 | Ok(&self[distr.sample(rng)]) |
542 | } |
543 | |
544 | #[cfg (feature = "alloc" )] |
545 | fn choose_weighted_mut<R, F, B, X>( |
546 | &mut self, rng: &mut R, weight: F, |
547 | ) -> Result<&mut Self::Item, WeightedError> |
548 | where |
549 | R: Rng + ?Sized, |
550 | F: Fn(&Self::Item) -> B, |
551 | B: SampleBorrow<X>, |
552 | X: SampleUniform |
553 | + for<'a> ::core::ops::AddAssign<&'a X> |
554 | + ::core::cmp::PartialOrd<X> |
555 | + Clone |
556 | + Default, |
557 | { |
558 | use crate::distributions::{Distribution, WeightedIndex}; |
559 | let distr = WeightedIndex::new(self.iter().map(weight))?; |
560 | Ok(&mut self[distr.sample(rng)]) |
561 | } |
562 | |
563 | #[cfg (feature = "std" )] |
564 | fn choose_multiple_weighted<R, F, X>( |
565 | &self, rng: &mut R, amount: usize, weight: F, |
566 | ) -> Result<SliceChooseIter<Self, Self::Item>, WeightedError> |
567 | where |
568 | R: Rng + ?Sized, |
569 | F: Fn(&Self::Item) -> X, |
570 | X: Into<f64>, |
571 | { |
572 | let amount = ::core::cmp::min(amount, self.len()); |
573 | Ok(SliceChooseIter { |
574 | slice: self, |
575 | _phantom: Default::default(), |
576 | indices: index::sample_weighted( |
577 | rng, |
578 | self.len(), |
579 | |idx| weight(&self[idx]).into(), |
580 | amount, |
581 | )? |
582 | .into_iter(), |
583 | }) |
584 | } |
585 | |
586 | fn shuffle<R>(&mut self, rng: &mut R) |
587 | where R: Rng + ?Sized { |
588 | for i in (1..self.len()).rev() { |
589 | // invariant: elements with index > i have been locked in place. |
590 | self.swap(i, gen_index(rng, i + 1)); |
591 | } |
592 | } |
593 | |
594 | fn partial_shuffle<R>( |
595 | &mut self, rng: &mut R, amount: usize, |
596 | ) -> (&mut [Self::Item], &mut [Self::Item]) |
597 | where R: Rng + ?Sized { |
598 | // This applies Durstenfeld's algorithm for the |
599 | // [Fisher–Yates shuffle](https://en.wikipedia.org/wiki/Fisher%E2%80%93Yates_shuffle#The_modern_algorithm) |
600 | // for an unbiased permutation, but exits early after choosing `amount` |
601 | // elements. |
602 | |
603 | let len = self.len(); |
604 | let end = if amount >= len { 0 } else { len - amount }; |
605 | |
606 | for i in (end..len).rev() { |
607 | // invariant: elements with index > i have been locked in place. |
608 | self.swap(i, gen_index(rng, i + 1)); |
609 | } |
610 | let r = self.split_at_mut(end); |
611 | (r.1, r.0) |
612 | } |
613 | } |
614 | |
615 | impl<I> IteratorRandom for I where I: Iterator + Sized {} |
616 | |
617 | |
/// Iterator yielding references to the slice elements found at a set of
/// sampled indices.
///
/// Values of this type are returned by [`SliceRandom::choose_multiple`] and,
/// when the `std` feature is enabled, by
/// `SliceRandom::choose_multiple_weighted`.
#[cfg(feature = "alloc")]
#[cfg_attr(doc_cfg, doc(cfg(feature = "alloc")))]
#[derive(Debug)]
pub struct SliceChooseIter<'a, S: 'a + ?Sized, T: 'a> {
    // Sequence that the sampled indices refer to.
    slice: &'a S,
    // Ties the element type `T` to the struct without storing a `T`.
    _phantom: ::core::marker::PhantomData<T>,
    // Sampled indices that have not been yielded yet.
    indices: index::IndexVecIntoIter,
}
630 | |
#[cfg(feature = "alloc")]
impl<'a, S, T> Iterator for SliceChooseIter<'a, S, T>
where
    S: Index<usize, Output = T> + ?Sized + 'a,
    T: 'a,
{
    type Item = &'a T;

    // Takes the next sampled index and resolves it against the slice.
    fn next(&mut self) -> Option<Self::Item> {
        // TODO: investigate using SliceIndex::get_unchecked when stable
        let i: usize = self.indices.next()?;
        let slice: &'a S = self.slice;
        Some(&slice[i])
    }

    // Exact: one item is produced per remaining index.
    fn size_hint(&self) -> (usize, Option<usize>) {
        let remaining = self.indices.len();
        (remaining, Some(remaining))
    }
}
644 | |
#[cfg(feature = "alloc")]
impl<'a, S, T> ExactSizeIterator for SliceChooseIter<'a, S, T>
where
    S: Index<usize, Output = T> + ?Sized + 'a,
    T: 'a,
{
    // Number of sampled indices still waiting to be yielded.
    fn len(&self) -> usize {
        self.indices.len()
    }
}
653 | |
654 | |
655 | // Sample a number uniformly between 0 and `ubound`. Uses 32-bit sampling where |
656 | // possible, primarily in order to produce the same output on 32-bit and 64-bit |
657 | // platforms. |
658 | #[inline ] |
659 | fn gen_index<R: Rng + ?Sized>(rng: &mut R, ubound: usize) -> usize { |
660 | if ubound <= (core::u32::MAX as usize) { |
661 | rng.gen_range(0..ubound as u32) as usize |
662 | } else { |
663 | rng.gen_range(0..ubound) |
664 | } |
665 | } |
666 | |
667 | |
668 | #[cfg (test)] |
669 | mod test { |
670 | use super::*; |
671 | #[cfg (feature = "alloc" )] use crate::Rng; |
672 | #[cfg (all(feature = "alloc" , not(feature = "std" )))] use alloc::vec::Vec; |
673 | |
// Statistical check that `choose` and `choose_mut` spread their picks evenly
// over a 14-element slice, plus the empty-slice case.
#[test]
fn test_slice_choose() {
    let mut r = crate::test::rng(107);
    let chars = [
        'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n',
    ];
    let mut chosen = [0i32; 14];

    // The below all use a binomial distribution with n=1000, p=1/14.
    // binocdf(40, 1000, 1/14) ~= 2e-5; 1-binocdf(106, ..) ~= 2e-5
    for _ in 0..1000 {
        let picked = *chars.choose(&mut r).unwrap();
        let slot = (picked as usize) - ('a' as usize);
        chosen[slot] += 1;
    }
    for &count in chosen.iter() {
        assert!(40 < count && count < 106);
    }

    // Reset the tallies, then let `choose_mut` bump a random counter directly.
    for count in chosen.iter_mut() {
        *count = 0;
    }
    for _ in 0..1000 {
        *chosen.choose_mut(&mut r).unwrap() += 1;
    }
    for &count in chosen.iter() {
        assert!(40 < count && count < 106);
    }

    // An empty slice never yields an element.
    let mut v: [isize; 0] = [];
    assert_eq!(v.choose(&mut r), None);
    assert_eq!(v.choose_mut(&mut r), None);
}
703 | |
// Pins the exact outputs produced for fixed seeds. The calls must stay in
// this order because each one advances the shared RNG.
#[test]
fn value_stability_slice() {
    let mut rng = crate::test::rng(413);
    let chars = [
        'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n',
    ];
    let mut nums = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12];

    assert_eq!(chars.choose(&mut rng), Some(&'l'));
    assert_eq!(nums.choose_mut(&mut rng), Some(&mut 10));

    #[cfg(feature = "alloc")]
    {
        let picked: Vec<char> = chars.choose_multiple(&mut rng, 8).cloned().collect();
        assert_eq!(&picked, &['d', 'm', 'b', 'n', 'c', 'k', 'h', 'e']);

        assert_eq!(chars.choose_weighted(&mut rng, |_| 1), Ok(&'f'));
        assert_eq!(nums.choose_weighted_mut(&mut rng, |_| 1), Ok(&mut 5));
    }

    // The shuffles run on a freshly seeded generator.
    let mut rng = crate::test::rng(414);
    nums.shuffle(&mut rng);
    assert_eq!(nums, [9, 5, 3, 10, 7, 12, 8, 11, 6, 4, 0, 2, 1]);

    nums = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12];
    let (shuffled, rest) = nums.partial_shuffle(&mut rng, 6);
    assert_eq!(shuffled, &mut [7, 4, 8, 6, 9, 3]);
    assert_eq!(rest, &mut [0, 1, 2, 12, 11, 5, 10]);
}
737 | |
// Test helper: forwards `next` to the wrapped iterator and overrides nothing
// else, so callers see the default `size_hint` rather than the inner one.
#[derive(Clone)]
struct UnhintedIterator<I: Iterator + Clone> {
    iter: I,
}
impl<I> Iterator for UnhintedIterator<I>
where
    I: Iterator + Clone,
{
    type Item = I::Item;

    fn next(&mut self) -> Option<Self::Item> {
        self.iter.next()
    }
}
749 | |
// Test helper whose lower size hint counts down within a chunk: it reports
// `chunk_remaining`, which drops by one per `next` call and is refilled to at
// most `chunk_size` once it has reached zero. The upper hint is the true
// remaining length when `hint_total_size` is set, and `None` otherwise.
#[derive(Clone)]
struct ChunkHintedIterator<I: ExactSizeIterator + Iterator + Clone> {
    iter: I,
    chunk_remaining: usize,
    chunk_size: usize,
    hint_total_size: bool,
}
impl<I> Iterator for ChunkHintedIterator<I>
where
    I: ExactSizeIterator + Iterator + Clone,
{
    type Item = I::Item;

    fn next(&mut self) -> Option<Self::Item> {
        // Start a new chunk (capped by what is actually left) when the
        // current one has been used up.
        let available = match self.chunk_remaining {
            0 => ::core::cmp::min(self.chunk_size, self.iter.len()),
            left => left,
        };
        // Saturating: a refill of zero happens once the inner iterator is empty.
        self.chunk_remaining = available.saturating_sub(1);

        self.iter.next()
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        let upper = if self.hint_total_size {
            Some(self.iter.len())
        } else {
            None
        };
        (self.chunk_remaining, upper)
    }
}
780 | |
// Test helper whose lower size hint is the remaining length capped at
// `window_size`. The upper hint is the true remaining length when
// `hint_total_size` is set, and `None` otherwise.
#[derive(Clone)]
struct WindowHintedIterator<I: ExactSizeIterator + Iterator + Clone> {
    iter: I,
    window_size: usize,
    hint_total_size: bool,
}
impl<I> Iterator for WindowHintedIterator<I>
where
    I: ExactSizeIterator + Iterator + Clone,
{
    type Item = I::Item;

    fn next(&mut self) -> Option<Self::Item> {
        self.iter.next()
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        let lower = ::core::cmp::min(self.iter.len(), self.window_size);
        let upper = if self.hint_total_size {
            Some(self.iter.len())
        } else {
            None
        };
        (lower, upper)
    }
}
805 | |
// Statistical check of `IteratorRandom::choose` over nine-element iterators
// that differ only in the size hints they report.
#[test]
#[cfg_attr(miri, ignore)] // Miri is too slow
fn test_iterator_choose() {
    // Draws 1000 picks from clones of `iter` and checks every value's tally.
    fn test_iter<R, Iter>(r: &mut R, iter: Iter)
    where
        R: Rng + ?Sized,
        Iter: Iterator<Item = usize> + Clone,
    {
        let mut chosen = [0i32; 9];
        for _ in 0..1000 {
            let picked = iter.clone().choose(r).unwrap();
            chosen[picked] += 1;
        }
        for count in chosen.iter() {
            // Samples should follow Binomial(1000, 1/9)
            // Octave: binopdf(x, 1000, 1/9) gives the prob of *count == x
            // Note: have seen 153, which is unlikely but not impossible.
            assert!(
                72 < *count && *count < 154,
                "count not close to 1000/9: {}",
                count
            );
        }
    }

    let r = &mut crate::test::rng(109);

    // Iterators that know their exact size.
    test_iter(r, 0..9);
    test_iter(r, [0, 1, 2, 3, 4, 5, 6, 7, 8].iter().cloned());
    #[cfg(feature = "alloc")]
    test_iter(r, (0..9).collect::<Vec<_>>().into_iter());

    // No usable hint at all.
    test_iter(r, UnhintedIterator { iter: 0..9 });

    // Partial hints, without and with a known total.
    test_iter(r, ChunkHintedIterator {
        iter: 0..9,
        chunk_size: 4,
        chunk_remaining: 4,
        hint_total_size: false,
    });
    test_iter(r, ChunkHintedIterator {
        iter: 0..9,
        chunk_size: 4,
        chunk_remaining: 4,
        hint_total_size: true,
    });
    test_iter(r, WindowHintedIterator {
        iter: 0..9,
        window_size: 2,
        hint_total_size: false,
    });
    test_iter(r, WindowHintedIterator {
        iter: 0..9,
        window_size: 2,
        hint_total_size: true,
    });

    // Empty iterators yield nothing, with or without a hint.
    assert_eq!((0..0).choose(r), None);
    assert_eq!(UnhintedIterator { iter: 0..0 }.choose(r), None);
}
859 | |
// Statistical check of `IteratorRandom::choose_stable` over nine-element
// iterators that differ only in the size hints they report, followed by the
// empty-iterator case.
#[test]
#[cfg_attr(miri, ignore)] // Miri is too slow
fn test_iterator_choose_stable() {
    let r = &mut crate::test::rng(109);

    // Draws 1000 picks from clones of `iter` and checks every value's tally.
    fn test_iter<R: Rng + ?Sized, Iter: Iterator<Item = usize> + Clone>(r: &mut R, iter: Iter) {
        let mut chosen = [0i32; 9];
        for _ in 0..1000 {
            let picked = iter.clone().choose_stable(r).unwrap();
            chosen[picked] += 1;
        }
        for count in chosen.iter() {
            // Samples should follow Binomial(1000, 1/9)
            // Octave: binopdf(x, 1000, 1/9) gives the prob of *count == x
            // Note: have seen 153, which is unlikely but not impossible.
            assert!(
                72 < *count && *count < 154,
                "count not close to 1000/9: {}",
                count
            );
        }
    }

    test_iter(r, 0..9);
    test_iter(r, [0, 1, 2, 3, 4, 5, 6, 7, 8].iter().cloned());
    #[cfg(feature = "alloc")]
    test_iter(r, (0..9).collect::<Vec<_>>().into_iter());
    test_iter(r, UnhintedIterator { iter: 0..9 });
    test_iter(r, ChunkHintedIterator {
        iter: 0..9,
        chunk_size: 4,
        chunk_remaining: 4,
        hint_total_size: false,
    });
    test_iter(r, ChunkHintedIterator {
        iter: 0..9,
        chunk_size: 4,
        chunk_remaining: 4,
        hint_total_size: true,
    });
    test_iter(r, WindowHintedIterator {
        iter: 0..9,
        window_size: 2,
        hint_total_size: false,
    });
    test_iter(r, WindowHintedIterator {
        iter: 0..9,
        window_size: 2,
        hint_total_size: true,
    });

    // The empty case must go through `choose_stable`, the method under test
    // here; `choose` has its own empty-iterator check in `test_iterator_choose`.
    assert_eq!((0..0).choose_stable(r), None);
    assert_eq!(UnhintedIterator { iter: 0..0 }.choose_stable(r), None);
}
913 | |
/// Checks that `choose_stable` produces the same sequence of picks for the
/// same seed regardless of which iterator type supplies the elements.
#[test]
#[cfg_attr(miri, ignore)] // Miri is too slow
fn test_iterator_choose_stable_stability() {
    /// Histogram of 1000 `choose_stable` picks from clones of `source`,
    /// using an RNG that is re-seeded with 109 on every call.
    fn histogram(source: impl Iterator<Item = usize> + Clone) -> [i32; 9] {
        let rng = &mut crate::test::rng(109);
        let mut buckets = [0i32; 9];
        for _ in 0..1000 {
            let index = source.clone().choose_stable(rng).unwrap();
            buckets[index] += 1;
        }
        buckets
    }

    // The plain range serves as the baseline all other iterators must match.
    let expected = histogram(0..9);

    assert_eq!(histogram([0, 1, 2, 3, 4, 5, 6, 7, 8].iter().cloned()), expected);
    #[cfg(feature = "alloc")]
    assert_eq!(histogram((0..9).collect::<Vec<_>>().into_iter()), expected);
    assert_eq!(histogram(UnhintedIterator { iter: 0..9 }), expected);

    assert_eq!(
        histogram(ChunkHintedIterator {
            iter: 0..9,
            chunk_size: 4,
            chunk_remaining: 4,
            hint_total_size: false,
        }),
        expected
    );
    assert_eq!(
        histogram(ChunkHintedIterator {
            iter: 0..9,
            chunk_size: 4,
            chunk_remaining: 4,
            hint_total_size: true,
        }),
        expected
    );

    assert_eq!(
        histogram(WindowHintedIterator {
            iter: 0..9,
            window_size: 2,
            hint_total_size: false,
        }),
        expected
    );
    assert_eq!(
        histogram(WindowHintedIterator {
            iter: 0..9,
            window_size: 2,
            hint_total_size: true,
        }),
        expected
    );
}
956 | |
/// Checks `shuffle` on tiny slices and verifies that shuffling a
/// four-element array produces all 24 permutations with plausible
/// frequencies.
#[test]
#[cfg_attr(miri, ignore)] // Miri is too slow
fn test_shuffle() {
    let mut r = crate::test::rng(108);

    // Zero and one element: nothing to reorder.
    let empty: &mut [isize] = &mut [];
    empty.shuffle(&mut r);
    let mut one = [1];
    one.shuffle(&mut r);
    let b: &[_] = &[1];
    assert_eq!(one, b);

    // Two elements: either order is acceptable.
    let mut two = [1, 2];
    two.shuffle(&mut r);
    assert!(two == [1, 2] || two == [2, 1]);

    /// Moves the element at `pos` to the end of `slice`, shifting every
    /// element after it one place towards the front.
    fn move_last(slice: &mut [usize], pos: usize) {
        slice[pos..].rotate_left(1);
    }

    let mut counts = [0i32; 24];
    for _ in 0..10000 {
        let mut arr: [usize; 4] = [0, 1, 2, 3];
        arr.shuffle(&mut r);

        // Map the permutation to a unique index in 0..24: the position of
        // each value among the not-yet-processed elements is one digit of a
        // mixed-radix number with place values 6, 2, 1, 1.
        let mut permutation = 0usize;
        let mut pos_value = counts.len();
        for i in 0..4 {
            pos_value /= 4 - i;
            let pos = arr.iter().position(|&x| x == i).unwrap();
            // Already processed values sit at the end, so `i` must be found
            // within the first `4 - i` slots.
            assert!(pos < (4 - i));
            permutation += pos * pos_value;
            move_last(&mut arr, pos);
            assert_eq!(arr[3], i);
        }
        // Moving 0, 1, 2, 3 to the back in turn leaves the array sorted.
        for (i, &a) in arr.iter().enumerate() {
            assert_eq!(a, i);
        }
        counts[permutation] += 1;
    }
    for count in counts.iter() {
        // Binomial(10000, 1/24) with average 416.667
        // Octave: binocdf(n, 10000, 1/24)
        // 99.9% chance samples lie within this range:
        assert!(352 <= *count && *count <= 483, "count: {}", count);
    }
}
1006 | |
/// Checks the lengths of the two parts returned by `partial_shuffle` and a
/// couple of basic properties of their contents.
#[test]
fn test_partial_shuffle() {
    let mut rng = crate::test::rng(118);

    // Asking for 10 elements of an empty array gives two empty parts.
    let mut nothing: [u32; 0] = [];
    let (picked, rest) = nothing.partial_shuffle(&mut rng, 10);
    assert_eq!((picked.len(), rest.len()), (0, 0));

    // Asking for 2 of 5 elements splits the array into 2 + 3.
    let mut values = [1, 2, 3, 4, 5];
    let (picked, rest) = values.partial_shuffle(&mut rng, 2);
    assert_eq!((picked.len(), rest.len()), (2, 3));
    assert!(picked[0] != picked[1]);
    // First elements are only modified if selected, so at least one of the
    // three leading slots still holds its original value (1, 2 or 3):
    assert!(rest.iter().zip(1..).any(|(&current, original)| current == original));
}
1022 | |
/// Checks `choose_multiple` on an iterator: the sample sizes, the
/// pass-through behaviour when more elements are requested than exist, and
/// that sampled values come from the source range.
#[test]
#[cfg(feature = "alloc")]
fn test_sample_iter() {
    let min_val = 1;
    let max_val = 100;

    let mut r = crate::test::rng(401);
    // `min_val..max_val` is half-open, so `max_val` itself is never included.
    let vals = (min_val..max_val).collect::<Vec<i32>>();
    let small_sample = vals.iter().choose_multiple(&mut r, 5);
    let large_sample = vals.iter().choose_multiple(&mut r, vals.len() + 5);

    assert_eq!(small_sample.len(), 5);
    assert_eq!(large_sample.len(), vals.len());
    // no randomization happens when amount >= len
    assert_eq!(large_sample, vals.iter().collect::<Vec<_>>());

    // Every sampled value must lie in the half-open source range; the upper
    // bound is exclusive to match how `vals` was built.
    assert!(small_sample
        .iter()
        .all(|e| { **e >= min_val && **e < max_val }));
}
1043 | |
/// Checks that `choose_weighted` and `choose_weighted_mut` pick items with
/// frequencies close to their weights, and that invalid inputs produce the
/// expected `WeightedError` variants.
#[test]
#[cfg(feature = "alloc")]
#[cfg_attr(miri, ignore)] // Miri is too slow
fn test_weighted() {
    let mut r = crate::test::rng(406);
    const N_REPS: u32 = 3000;
    let weights = [1u32, 2, 3, 0, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7];
    let total_weight = weights.iter().sum::<u32>() as f32;

    // Asserts that every observed count is within 25% (relative) of the
    // count expected from its weight. An exact match skips the division so
    // that a zero-weight item with zero hits does not compute 0/0.
    let verify = |observed: [i32; 14]| {
        for (&weight, &count) in weights.iter().zip(observed.iter()) {
            let expected = (weight * N_REPS) as f32 / total_weight;
            let deviation = (count as f32 - expected).abs();
            let err = if deviation != 0.0 {
                deviation / expected
            } else {
                deviation
            };
            assert!(err <= 0.25);
        }
    };

    /// Weight accessor: the weight is the first tuple field.
    fn get_weight<T>(item: &(u32, T)) -> u32 {
        item.0
    }

    // choose_weighted
    let mut chosen = [0i32; 14];
    let mut items = [(0u32, 0usize); 14]; // (weight, index)
    for (slot, (index, &weight)) in items.iter_mut().zip(weights.iter().enumerate()) {
        *slot = (weight, index);
    }
    for _ in 0..N_REPS {
        let &(_, index) = items.choose_weighted(&mut r, get_weight).unwrap();
        chosen[index] += 1;
    }
    verify(chosen);

    // choose_weighted_mut
    let mut items = [(0u32, 0i32); 14]; // (weight, count)
    for (slot, &weight) in items.iter_mut().zip(weights.iter()) {
        *slot = (weight, 0);
    }
    for _ in 0..N_REPS {
        let entry = items.choose_weighted_mut(&mut r, get_weight).unwrap();
        entry.1 += 1;
    }
    // Copy the per-item counters into the histogram array for verification.
    for (bucket, &(_, count)) in chosen.iter_mut().zip(items.iter()) {
        *bucket = count;
    }
    verify(chosen);

    // Check error cases
    let empty_slice = &mut [10][0..0];
    assert_eq!(
        empty_slice.choose_weighted(&mut r, |_| 1),
        Err(WeightedError::NoItem)
    );
    assert_eq!(
        empty_slice.choose_weighted_mut(&mut r, |_| 1),
        Err(WeightedError::NoItem)
    );
    assert_eq!(
        ['x'].choose_weighted_mut(&mut r, |_| 0),
        Err(WeightedError::AllWeightsZero)
    );
    // A negative weight is rejected wherever it appears in the slice.
    assert_eq!(
        [0, -1].choose_weighted_mut(&mut r, |x| *x),
        Err(WeightedError::InvalidWeight)
    );
    assert_eq!(
        [-1, 0].choose_weighted_mut(&mut r, |x| *x),
        Err(WeightedError::InvalidWeight)
    );
}
1115 | |
/// Pins the exact values returned by `choose` for a fixed seed, so that
/// accidental changes to the sampling algorithm are detected.
#[test]
fn value_stability_choose() {
    /// Runs `choose` on `source` with an RNG freshly seeded with 411.
    fn pick<I: Iterator<Item = u32>>(source: I) -> Option<u32> {
        let mut rng = crate::test::rng(411);
        source.choose(&mut rng)
    }

    assert_eq!(pick([].iter().cloned()), None);
    assert_eq!(pick(0..100), Some(33));
    assert_eq!(pick(UnhintedIterator { iter: 0..100 }), Some(40));

    // Chunk-hinted iterators: same expected value with and without the
    // total-size hint.
    for &hint_total_size in [false, true].iter() {
        let chunked = ChunkHintedIterator {
            iter: 0..100,
            chunk_size: 32,
            chunk_remaining: 32,
            hint_total_size,
        };
        assert_eq!(pick(chunked), Some(39));
    }

    // Window-hinted iterators: likewise for both hint settings.
    for &hint_total_size in [false, true].iter() {
        let windowed = WindowHintedIterator {
            iter: 0..100,
            window_size: 32,
            hint_total_size,
        };
        assert_eq!(pick(windowed), Some(90));
    }
}
1161 | |
/// Pins the exact values returned by `choose_stable` for a fixed seed; every
/// non-empty iterator kind tested here is expected to give the same result.
#[test]
fn value_stability_choose_stable() {
    /// Runs `choose_stable` on `source` with an RNG freshly seeded with 411.
    fn pick<I: Iterator<Item = u32>>(source: I) -> Option<u32> {
        let mut rng = crate::test::rng(411);
        source.choose_stable(&mut rng)
    }

    assert_eq!(pick([].iter().cloned()), None);
    assert_eq!(pick(0..100), Some(40));
    assert_eq!(pick(UnhintedIterator { iter: 0..100 }), Some(40));

    // Chunk-hinted iterators, without and with the total-size hint.
    for &hint_total_size in [false, true].iter() {
        let chunked = ChunkHintedIterator {
            iter: 0..100,
            chunk_size: 32,
            chunk_remaining: 32,
            hint_total_size,
        };
        assert_eq!(pick(chunked), Some(40));
    }

    // Window-hinted iterators, without and with the total-size hint.
    for &hint_total_size in [false, true].iter() {
        let windowed = WindowHintedIterator {
            iter: 0..100,
            window_size: 32,
            hint_total_size,
        };
        assert_eq!(pick(windowed), Some(40));
    }
}
1207 | |
/// Pins the exact output of `choose_multiple_fill` (and, with `alloc`,
/// `choose_multiple`) for a fixed seed.
#[test]
fn value_stability_choose_multiple() {
    /// Fills an 8-element buffer from `source` with an RNG seeded with 412
    /// and compares the filled prefix against `expected`.
    fn check_fill<I: Iterator<Item = u32>>(source: I, expected: &[u32]) {
        let mut rng = crate::test::rng(412);
        let mut buf = [0u32; 8];
        let filled = source.choose_multiple_fill(&mut rng, &mut buf);
        assert_eq!(filled, expected.len());
        assert_eq!(&buf[..expected.len()], expected);
    }

    check_fill(0..4, &[0, 1, 2, 3]);
    check_fill(0..8, &[0, 1, 2, 3, 4, 5, 6, 7]);
    check_fill(0..100, &[58, 78, 80, 92, 43, 8, 96, 7]);

    #[cfg(feature = "alloc")]
    {
        /// Same check for the allocating variant, requesting exactly
        /// `expected.len()` elements.
        fn check_vec<I: Iterator<Item = u32>>(source: I, expected: &[u32]) {
            let mut rng = crate::test::rng(412);
            let sample = source.choose_multiple(&mut rng, expected.len());
            assert_eq!(sample, expected);
        }

        check_vec(0..4, &[0, 1, 2, 3]);
        check_vec(0..8, &[0, 1, 2, 3, 4, 5, 6, 7]);
        check_vec(0..100, &[58, 78, 80, 92, 43, 8, 96, 7]);
    }
}
1233 | |
/// Exercises `choose_multiple_weighted` with unusual weights: zero,
/// negative, NaN, infinite and negative-zero values, plus an empty slice.
#[test]
#[cfg(feature = "std")]
fn test_multiple_weighted_edge_cases() {
    use super::*;

    let mut rng = crate::test::rng(413);

    // Case 1: One of the weights is 0
    // The zero-weight item 'c' must never be selected.
    let choices = [('a', 2), ('b', 1), ('c', 0)];
    for _ in 0..100 {
        let picked = choices
            .choose_multiple_weighted(&mut rng, 2, |item| item.1)
            .unwrap()
            .collect::<Vec<_>>();

        assert_eq!(picked.len(), 2);
        assert!(picked.iter().all(|val| val.0 != 'c'));
    }

    // Case 2: All of the weights are 0
    // The requested number of items is still returned.
    let choices = [('a', 0), ('b', 0), ('c', 0)];
    let all_zero = choices.choose_multiple_weighted(&mut rng, 2, |item| item.1);
    assert_eq!(all_zero.unwrap().count(), 2);

    // Case 3: Negative weights
    let choices = [('a', -1), ('b', 1), ('c', 1)];
    let negative = choices.choose_multiple_weighted(&mut rng, 2, |item| item.1);
    assert_eq!(negative.unwrap_err(), WeightedError::InvalidWeight);

    // Case 4: Empty list
    let choices = [];
    let from_empty = choices.choose_multiple_weighted(&mut rng, 0, |_: &()| 0);
    assert_eq!(from_empty.unwrap().count(), 0);

    // Case 5: NaN weights
    let choices = [('a', core::f64::NAN), ('b', 1.0), ('c', 1.0)];
    let with_nan = choices.choose_multiple_weighted(&mut rng, 2, |item| item.1);
    assert_eq!(with_nan.unwrap_err(), WeightedError::InvalidWeight);

    // Case 6: +infinity weights
    // The infinitely weighted item 'a' must be part of every sample.
    let choices = [('a', core::f64::INFINITY), ('b', 1.0), ('c', 1.0)];
    for _ in 0..100 {
        let picked = choices
            .choose_multiple_weighted(&mut rng, 2, |item| item.1)
            .unwrap()
            .collect::<Vec<_>>();
        assert_eq!(picked.len(), 2);
        assert!(picked.iter().any(|val| val.0 == 'a'));
    }

    // Case 7: -infinity weights
    let choices = [('a', core::f64::NEG_INFINITY), ('b', 1.0), ('c', 1.0)];
    let with_neg_inf = choices.choose_multiple_weighted(&mut rng, 2, |item| item.1);
    assert_eq!(with_neg_inf.unwrap_err(), WeightedError::InvalidWeight);

    // Case 8: -0 weights
    // Negative zero is accepted as a valid weight.
    let choices = [('a', -0.0), ('b', 1.0), ('c', 1.0)];
    let with_neg_zero = choices.choose_multiple_weighted(&mut rng, 2, |item| item.1);
    assert!(with_neg_zero.is_ok());
}
1310 | |
/// Checks that the pairs produced by `choose_multiple_weighted` occur with
/// frequencies close to their theoretical probabilities.
#[test]
#[cfg(feature = "std")]
fn test_multiple_weighted_distributions() {
    use super::*;

    // Theoretical probabilities of the ordered outcomes:
    // AB: 0.5  * 0.5  = 0.250
    // AC: 0.5  * 0.5  = 0.250
    // BA: 0.25 * 0.67 = 0.167
    // BC: 0.25 * 0.33 = 0.083
    // CA: 0.25 * 0.67 = 0.167
    // CB: 0.25 * 0.33 = 0.083
    // Ignoring order, that is roughly 0.417 for {a, b}, 0.417 for {a, c}
    // and 0.167 for {b, c}.
    let choices = [('a', 2), ('b', 1), ('c', 1)];
    let mut rng = crate::test::rng(414);

    // Tallies for the unordered pairs {a, b}, {a, c} and {b, c}.
    let mut tallies = [0i32; 3];
    let expected_tallies = [4167, 4167, 1666];
    for _ in 0..10000 {
        let pair = choices
            .choose_multiple_weighted(&mut rng, 2, |item| item.1)
            .unwrap()
            .collect::<Vec<_>>();

        assert_eq!(pair.len(), 2);

        let bucket = match (pair[0].0, pair[1].0) {
            ('a', 'b') | ('b', 'a') => 0,
            ('a', 'c') | ('c', 'a') => 1,
            ('b', 'c') | ('c', 'b') => 2,
            _ => panic!("unexpected result"),
        };
        tallies[bucket] += 1;
    }

    // Every tally must be within 100 of its expected value.
    let within_tolerance = tallies
        .iter()
        .zip(expected_tallies.iter())
        .all(|(observed, expected)| (observed - expected).abs() <= 100);
    assert!(within_tolerance);
}
1356 | } |
1357 | |