1// Copyright 2018-2024 Developers of the Rand project.
2//
3// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
4// https://www.apache.org/licenses/LICENSE-2.0> or the MIT license
5// <LICENSE-MIT or https://opensource.org/licenses/MIT>, at your
6// option. This file may not be copied, modified, or distributed
7// except according to those terms.
8
9//! `IteratorRandom`
10
11use super::coin_flipper::CoinFlipper;
12#[allow(unused)]
13use super::IndexedRandom;
14use crate::Rng;
15#[cfg(feature = "alloc")]
16use alloc::vec::Vec;
17
18/// Extension trait on iterators, providing random sampling methods.
19///
20/// This trait is implemented on all iterators `I` where `I: Iterator + Sized`
21/// and provides methods for
22/// choosing one or more elements. You must `use` this trait:
23///
24/// ```
25/// use rand::seq::IteratorRandom;
26///
27/// let faces = "😀😎😐😕😠😢";
28/// println!("I am {}!", faces.chars().choose(&mut rand::rng()).unwrap());
29/// ```
30/// Example output (non-deterministic):
31/// ```none
32/// I am 😀!
33/// ```
34pub trait IteratorRandom: Iterator + Sized {
35 /// Uniformly sample one element
36 ///
37 /// Assuming that the [`Iterator::size_hint`] is correct, this method
38 /// returns one uniformly-sampled random element of the slice, or `None`
39 /// only if the slice is empty. Incorrect bounds on the `size_hint` may
40 /// cause this method to incorrectly return `None` if fewer elements than
41 /// the advertised `lower` bound are present and may prevent sampling of
42 /// elements beyond an advertised `upper` bound (i.e. incorrect `size_hint`
43 /// is memory-safe, but may result in unexpected `None` result and
44 /// non-uniform distribution).
45 ///
46 /// With an accurate [`Iterator::size_hint`] and where [`Iterator::nth`] is
47 /// a constant-time operation, this method can offer `O(1)` performance.
48 /// Where no size hint is
49 /// available, complexity is `O(n)` where `n` is the iterator length.
50 /// Partial hints (where `lower > 0`) also improve performance.
51 ///
52 /// Note further that [`Iterator::size_hint`] may affect the number of RNG
53 /// samples used as well as the result (while remaining uniform sampling).
54 /// Consider instead using [`IteratorRandom::choose_stable`] to avoid
55 /// [`Iterator`] combinators which only change size hints from affecting the
56 /// results.
57 ///
58 /// # Example
59 ///
60 /// ```
61 /// use rand::seq::IteratorRandom;
62 ///
63 /// let words = "Mary had a little lamb".split(' ');
64 /// println!("{}", words.choose(&mut rand::rng()).unwrap());
65 /// ```
66 fn choose<R>(mut self, rng: &mut R) -> Option<Self::Item>
67 where
68 R: Rng + ?Sized,
69 {
70 let (mut lower, mut upper) = self.size_hint();
71 let mut result = None;
72
73 // Handling for this condition outside the loop allows the optimizer to eliminate the loop
74 // when the Iterator is an ExactSizeIterator. This has a large performance impact on e.g.
75 // seq_iter_choose_from_1000.
76 if upper == Some(lower) {
77 return match lower {
78 0 => None,
79 1 => self.next(),
80 _ => self.nth(rng.random_range(..lower)),
81 };
82 }
83
84 let mut coin_flipper = CoinFlipper::new(rng);
85 let mut consumed = 0;
86
87 // Continue until the iterator is exhausted
88 loop {
89 if lower > 1 {
90 let ix = coin_flipper.rng.random_range(..lower + consumed);
91 let skip = if ix < lower {
92 result = self.nth(ix);
93 lower - (ix + 1)
94 } else {
95 lower
96 };
97 if upper == Some(lower) {
98 return result;
99 }
100 consumed += lower;
101 if skip > 0 {
102 self.nth(skip - 1);
103 }
104 } else {
105 let elem = self.next();
106 if elem.is_none() {
107 return result;
108 }
109 consumed += 1;
110 if coin_flipper.random_ratio_one_over(consumed) {
111 result = elem;
112 }
113 }
114
115 let hint = self.size_hint();
116 lower = hint.0;
117 upper = hint.1;
118 }
119 }
120
121 /// Uniformly sample one element (stable)
122 ///
123 /// This method is very similar to [`choose`] except that the result
124 /// only depends on the length of the iterator and the values produced by
125 /// `rng`. Notably for any iterator of a given length this will make the
126 /// same requests to `rng` and if the same sequence of values are produced
127 /// the same index will be selected from `self`. This may be useful if you
128 /// need consistent results no matter what type of iterator you are working
129 /// with. If you do not need this stability prefer [`choose`].
130 ///
131 /// Note that this method still uses [`Iterator::size_hint`] to skip
132 /// constructing elements where possible, however the selection and `rng`
133 /// calls are the same in the face of this optimization. If you want to
134 /// force every element to be created regardless call `.inspect(|e| ())`.
135 ///
136 /// [`choose`]: IteratorRandom::choose
137 fn choose_stable<R>(mut self, rng: &mut R) -> Option<Self::Item>
138 where
139 R: Rng + ?Sized,
140 {
141 let mut consumed = 0;
142 let mut result = None;
143 let mut coin_flipper = CoinFlipper::new(rng);
144
145 loop {
146 // Currently the only way to skip elements is `nth()`. So we need to
147 // store what index to access next here.
148 // This should be replaced by `advance_by()` once it is stable:
149 // https://github.com/rust-lang/rust/issues/77404
150 let mut next = 0;
151
152 let (lower, _) = self.size_hint();
153 if lower >= 2 {
154 let highest_selected = (0..lower)
155 .filter(|ix| coin_flipper.random_ratio_one_over(consumed + ix + 1))
156 .last();
157
158 consumed += lower;
159 next = lower;
160
161 if let Some(ix) = highest_selected {
162 result = self.nth(ix);
163 next -= ix + 1;
164 debug_assert!(result.is_some(), "iterator shorter than size_hint().0");
165 }
166 }
167
168 let elem = self.nth(next);
169 if elem.is_none() {
170 return result;
171 }
172
173 if coin_flipper.random_ratio_one_over(consumed + 1) {
174 result = elem;
175 }
176 consumed += 1;
177 }
178 }
179
180 /// Uniformly sample `amount` distinct elements into a buffer
181 ///
182 /// Collects values at random from the iterator into a supplied buffer
183 /// until that buffer is filled.
184 ///
185 /// Although the elements are selected randomly, the order of elements in
186 /// the buffer is neither stable nor fully random. If random ordering is
187 /// desired, shuffle the result.
188 ///
189 /// Returns the number of elements added to the buffer. This equals the length
190 /// of the buffer unless the iterator contains insufficient elements, in which
191 /// case this equals the number of elements available.
192 ///
193 /// Complexity is `O(n)` where `n` is the length of the iterator.
194 /// For slices, prefer [`IndexedRandom::choose_multiple`].
195 fn choose_multiple_fill<R>(mut self, rng: &mut R, buf: &mut [Self::Item]) -> usize
196 where
197 R: Rng + ?Sized,
198 {
199 let amount = buf.len();
200 let mut len = 0;
201 while len < amount {
202 if let Some(elem) = self.next() {
203 buf[len] = elem;
204 len += 1;
205 } else {
206 // Iterator exhausted; stop early
207 return len;
208 }
209 }
210
211 // Continue, since the iterator was not exhausted
212 for (i, elem) in self.enumerate() {
213 let k = rng.random_range(..i + 1 + amount);
214 if let Some(slot) = buf.get_mut(k) {
215 *slot = elem;
216 }
217 }
218 len
219 }
220
221 /// Uniformly sample `amount` distinct elements into a [`Vec`]
222 ///
223 /// This is equivalent to `choose_multiple_fill` except for the result type.
224 ///
225 /// Although the elements are selected randomly, the order of elements in
226 /// the buffer is neither stable nor fully random. If random ordering is
227 /// desired, shuffle the result.
228 ///
229 /// The length of the returned vector equals `amount` unless the iterator
230 /// contains insufficient elements, in which case it equals the number of
231 /// elements available.
232 ///
233 /// Complexity is `O(n)` where `n` is the length of the iterator.
234 /// For slices, prefer [`IndexedRandom::choose_multiple`].
235 #[cfg(feature = "alloc")]
236 fn choose_multiple<R>(mut self, rng: &mut R, amount: usize) -> Vec<Self::Item>
237 where
238 R: Rng + ?Sized,
239 {
240 let mut reservoir = Vec::with_capacity(amount);
241 reservoir.extend(self.by_ref().take(amount));
242
243 // Continue unless the iterator was exhausted
244 //
245 // note: this prevents iterators that "restart" from causing problems.
246 // If the iterator stops once, then so do we.
247 if reservoir.len() == amount {
248 for (i, elem) in self.enumerate() {
249 let k = rng.random_range(..i + 1 + amount);
250 if let Some(slot) = reservoir.get_mut(k) {
251 *slot = elem;
252 }
253 }
254 } else {
255 // Don't hang onto extra memory. There is a corner case where
256 // `amount` was much less than `self.len()`.
257 reservoir.shrink_to_fit();
258 }
259 reservoir
260 }
261}
262
263impl<I> IteratorRandom for I where I: Iterator + Sized {}
264
#[cfg(test)]
mod test {
    use super::*;
    #[cfg(all(feature = "alloc", not(feature = "std")))]
    use alloc::vec::Vec;

    /// Wrapper that forwards `next` only, so `size_hint` falls back to the
    /// default provided by [`Iterator`] regardless of what `iter` reports.
    #[derive(Clone)]
    struct UnhintedIterator<I: Iterator + Clone> {
        iter: I,
    }
    impl<I: Iterator + Clone> Iterator for UnhintedIterator<I> {
        type Item = I::Item;

        fn next(&mut self) -> Option<Self::Item> {
            self.iter.next()
        }
    }

    /// Wrapper whose lower size hint is the number of elements left in the
    /// current chunk of at most `chunk_size` elements. The upper hint is the
    /// true remaining length when `hint_total_size` is set and `None`
    /// otherwise.
    #[derive(Clone)]
    struct ChunkHintedIterator<I: ExactSizeIterator + Iterator + Clone> {
        iter: I,
        chunk_remaining: usize,
        chunk_size: usize,
        hint_total_size: bool,
    }
    impl<I: ExactSizeIterator + Iterator + Clone> Iterator for ChunkHintedIterator<I> {
        type Item = I::Item;

        fn next(&mut self) -> Option<Self::Item> {
            // Start a new chunk once the previous one has been used up.
            if self.chunk_remaining == 0 {
                self.chunk_remaining = core::cmp::min(self.chunk_size, self.iter.len());
            }
            self.chunk_remaining = self.chunk_remaining.saturating_sub(1);

            self.iter.next()
        }

        fn size_hint(&self) -> (usize, Option<usize>) {
            (
                self.chunk_remaining,
                if self.hint_total_size {
                    Some(self.iter.len())
                } else {
                    None
                },
            )
        }
    }

    /// Wrapper whose lower size hint is the remaining length capped at
    /// `window_size`. The upper hint is the true remaining length when
    /// `hint_total_size` is set and `None` otherwise.
    #[derive(Clone)]
    struct WindowHintedIterator<I: ExactSizeIterator + Iterator + Clone> {
        iter: I,
        window_size: usize,
        hint_total_size: bool,
    }
    impl<I: ExactSizeIterator + Iterator + Clone> Iterator for WindowHintedIterator<I> {
        type Item = I::Item;

        fn next(&mut self) -> Option<Self::Item> {
            self.iter.next()
        }

        fn size_hint(&self) -> (usize, Option<usize>) {
            (
                core::cmp::min(self.iter.len(), self.window_size),
                if self.hint_total_size {
                    Some(self.iter.len())
                } else {
                    None
                },
            )
        }
    }

    #[test]
    #[cfg_attr(miri, ignore)] // Miri is too slow
    fn test_iterator_choose() {
        let r = &mut crate::test::rng(109);
        fn test_iter<R: Rng + ?Sized, Iter: Iterator<Item = usize> + Clone>(r: &mut R, iter: Iter) {
            let mut chosen = [0i32; 9];
            for _ in 0..1000 {
                let picked = iter.clone().choose(r).unwrap();
                chosen[picked] += 1;
            }
            for count in chosen.iter() {
                // Samples should follow Binomial(1000, 1/9)
                // Octave: binopdf(x, 1000, 1/9) gives the prob of *count == x
                // Note: have seen 153, which is unlikely but not impossible.
                assert!(
                    72 < *count && *count < 154,
                    "count not close to 1000/9: {}",
                    count
                );
            }
        }

        test_iter(r, 0..9);
        test_iter(r, [0, 1, 2, 3, 4, 5, 6, 7, 8].iter().cloned());
        #[cfg(feature = "alloc")]
        test_iter(r, (0..9).collect::<Vec<_>>().into_iter());
        test_iter(r, UnhintedIterator { iter: 0..9 });
        test_iter(
            r,
            ChunkHintedIterator {
                iter: 0..9,
                chunk_size: 4,
                chunk_remaining: 4,
                hint_total_size: false,
            },
        );
        test_iter(
            r,
            ChunkHintedIterator {
                iter: 0..9,
                chunk_size: 4,
                chunk_remaining: 4,
                hint_total_size: true,
            },
        );
        test_iter(
            r,
            WindowHintedIterator {
                iter: 0..9,
                window_size: 2,
                hint_total_size: false,
            },
        );
        test_iter(
            r,
            WindowHintedIterator {
                iter: 0..9,
                window_size: 2,
                hint_total_size: true,
            },
        );

        assert_eq!((0..0).choose(r), None);
        assert_eq!(UnhintedIterator { iter: 0..0 }.choose(r), None);
    }

    #[test]
    #[cfg_attr(miri, ignore)] // Miri is too slow
    fn test_iterator_choose_stable() {
        let r = &mut crate::test::rng(109);
        fn test_iter<R: Rng + ?Sized, Iter: Iterator<Item = usize> + Clone>(r: &mut R, iter: Iter) {
            let mut chosen = [0i32; 9];
            for _ in 0..1000 {
                let picked = iter.clone().choose_stable(r).unwrap();
                chosen[picked] += 1;
            }
            for count in chosen.iter() {
                // Samples should follow Binomial(1000, 1/9)
                // Octave: binopdf(x, 1000, 1/9) gives the prob of *count == x
                // Note: have seen 153, which is unlikely but not impossible.
                assert!(
                    72 < *count && *count < 154,
                    "count not close to 1000/9: {}",
                    count
                );
            }
        }

        test_iter(r, 0..9);
        test_iter(r, [0, 1, 2, 3, 4, 5, 6, 7, 8].iter().cloned());
        #[cfg(feature = "alloc")]
        test_iter(r, (0..9).collect::<Vec<_>>().into_iter());
        test_iter(r, UnhintedIterator { iter: 0..9 });
        test_iter(
            r,
            ChunkHintedIterator {
                iter: 0..9,
                chunk_size: 4,
                chunk_remaining: 4,
                hint_total_size: false,
            },
        );
        test_iter(
            r,
            ChunkHintedIterator {
                iter: 0..9,
                chunk_size: 4,
                chunk_remaining: 4,
                hint_total_size: true,
            },
        );
        test_iter(
            r,
            WindowHintedIterator {
                iter: 0..9,
                window_size: 2,
                hint_total_size: false,
            },
        );
        test_iter(
            r,
            WindowHintedIterator {
                iter: 0..9,
                window_size: 2,
                hint_total_size: true,
            },
        );

        // Empty iterators must yield `None` from `choose_stable` itself.
        assert_eq!((0..0).choose_stable(r), None);
        assert_eq!(UnhintedIterator { iter: 0..0 }.choose_stable(r), None);
    }

    #[test]
    #[cfg_attr(miri, ignore)] // Miri is too slow
    fn test_iterator_choose_stable_stability() {
        // Every call starts from the same seed, so identical histograms mean
        // the selection did not depend on the iterator's size hints.
        fn test_iter(iter: impl Iterator<Item = usize> + Clone) -> [i32; 9] {
            let r = &mut crate::test::rng(109);
            let mut chosen = [0i32; 9];
            for _ in 0..1000 {
                let picked = iter.clone().choose_stable(r).unwrap();
                chosen[picked] += 1;
            }
            chosen
        }

        let reference = test_iter(0..9);
        assert_eq!(
            test_iter([0, 1, 2, 3, 4, 5, 6, 7, 8].iter().cloned()),
            reference
        );

        #[cfg(feature = "alloc")]
        assert_eq!(test_iter((0..9).collect::<Vec<_>>().into_iter()), reference);
        assert_eq!(test_iter(UnhintedIterator { iter: 0..9 }), reference);
        assert_eq!(
            test_iter(ChunkHintedIterator {
                iter: 0..9,
                chunk_size: 4,
                chunk_remaining: 4,
                hint_total_size: false,
            }),
            reference
        );
        assert_eq!(
            test_iter(ChunkHintedIterator {
                iter: 0..9,
                chunk_size: 4,
                chunk_remaining: 4,
                hint_total_size: true,
            }),
            reference
        );
        assert_eq!(
            test_iter(WindowHintedIterator {
                iter: 0..9,
                window_size: 2,
                hint_total_size: false,
            }),
            reference
        );
        assert_eq!(
            test_iter(WindowHintedIterator {
                iter: 0..9,
                window_size: 2,
                hint_total_size: true,
            }),
            reference
        );
    }

    #[test]
    #[cfg(feature = "alloc")]
    fn test_sample_iter() {
        let min_val = 1;
        let max_val = 100;

        let mut r = crate::test::rng(401);
        let vals = (min_val..max_val).collect::<Vec<i32>>();
        let small_sample = vals.iter().choose_multiple(&mut r, 5);
        let large_sample = vals.iter().choose_multiple(&mut r, vals.len() + 5);

        assert_eq!(small_sample.len(), 5);
        assert_eq!(large_sample.len(), vals.len());
        // no randomization happens when amount >= len
        assert_eq!(large_sample, vals.iter().collect::<Vec<_>>());

        // `vals` was built from the half-open range, so `max_val` itself must
        // never be sampled.
        assert!(small_sample
            .iter()
            .all(|e| (min_val..max_val).contains(*e)));
    }

    #[test]
    fn value_stability_choose() {
        fn choose<I: Iterator<Item = u32>>(iter: I) -> Option<u32> {
            let mut rng = crate::test::rng(411);
            iter.choose(&mut rng)
        }

        assert_eq!(choose([].iter().cloned()), None);
        assert_eq!(choose(0..100), Some(33));
        assert_eq!(choose(UnhintedIterator { iter: 0..100 }), Some(27));
        assert_eq!(
            choose(ChunkHintedIterator {
                iter: 0..100,
                chunk_size: 32,
                chunk_remaining: 32,
                hint_total_size: false,
            }),
            Some(91)
        );
        assert_eq!(
            choose(ChunkHintedIterator {
                iter: 0..100,
                chunk_size: 32,
                chunk_remaining: 32,
                hint_total_size: true,
            }),
            Some(91)
        );
        assert_eq!(
            choose(WindowHintedIterator {
                iter: 0..100,
                window_size: 32,
                hint_total_size: false,
            }),
            Some(34)
        );
        assert_eq!(
            choose(WindowHintedIterator {
                iter: 0..100,
                window_size: 32,
                hint_total_size: true,
            }),
            Some(34)
        );
    }

    #[test]
    fn value_stability_choose_stable() {
        fn choose<I: Iterator<Item = u32>>(iter: I) -> Option<u32> {
            let mut rng = crate::test::rng(411);
            iter.choose_stable(&mut rng)
        }

        assert_eq!(choose([].iter().cloned()), None);
        assert_eq!(choose(0..100), Some(27));
        assert_eq!(choose(UnhintedIterator { iter: 0..100 }), Some(27));
        assert_eq!(
            choose(ChunkHintedIterator {
                iter: 0..100,
                chunk_size: 32,
                chunk_remaining: 32,
                hint_total_size: false,
            }),
            Some(27)
        );
        assert_eq!(
            choose(ChunkHintedIterator {
                iter: 0..100,
                chunk_size: 32,
                chunk_remaining: 32,
                hint_total_size: true,
            }),
            Some(27)
        );
        assert_eq!(
            choose(WindowHintedIterator {
                iter: 0..100,
                window_size: 32,
                hint_total_size: false,
            }),
            Some(27)
        );
        assert_eq!(
            choose(WindowHintedIterator {
                iter: 0..100,
                window_size: 32,
                hint_total_size: true,
            }),
            Some(27)
        );
    }

    #[test]
    fn value_stability_choose_multiple() {
        fn do_test<I: Clone + Iterator<Item = u32>>(iter: I, v: &[u32]) {
            let mut rng = crate::test::rng(412);
            let mut buf = [0u32; 8];
            assert_eq!(
                iter.clone().choose_multiple_fill(&mut rng, &mut buf),
                v.len()
            );
            assert_eq!(&buf[0..v.len()], v);

            #[cfg(feature = "alloc")]
            {
                let mut rng = crate::test::rng(412);
                assert_eq!(iter.choose_multiple(&mut rng, v.len()), v);
            }
        }

        do_test(0..4, &[0, 1, 2, 3]);
        do_test(0..8, &[0, 1, 2, 3, 4, 5, 6, 7]);
        do_test(0..100, &[77, 95, 38, 23, 25, 8, 58, 40]);
    }
}
665