1// Copyright 2018 Developers of the Rand project.
2//
3// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
4// https://www.apache.org/licenses/LICENSE-2.0> or the MIT license
5// <LICENSE-MIT or https://opensource.org/licenses/MIT>, at your
6// option. This file may not be copied, modified, or distributed
7// except according to those terms.
8
9use super::{Error, Weight};
10use crate::distr::uniform::{SampleBorrow, SampleUniform, UniformSampler};
11use crate::distr::Distribution;
12use crate::Rng;
13
14// Note that this whole module is only imported if feature="alloc" is enabled.
15use alloc::vec::Vec;
16use core::fmt::{self, Debug};
17
18#[cfg(feature = "serde")]
19use serde::{Deserialize, Serialize};
20
21/// A distribution using weighted sampling of discrete items.
22///
23/// Sampling a `WeightedIndex` distribution returns the index of a randomly
24/// selected element from the iterator used when the `WeightedIndex` was
25/// created. The chance of a given element being picked is proportional to the
26/// weight of the element. The weights can use any type `X` for which an
27/// implementation of [`Uniform<X>`] exists. The implementation guarantees that
28/// elements with zero weight are never picked, even when the weights are
29/// floating point numbers.
30///
31/// # Performance
32///
33/// Time complexity of sampling from `WeightedIndex` is `O(log N)` where
34/// `N` is the number of weights.
35/// See also [`rand_distr::weighted`] for alternative implementations supporting
36/// potentially-faster sampling or a more easily modifiable tree structure.
37///
38/// A `WeightedIndex<X>` contains a `Vec<X>` and a [`Uniform<X>`] and so its
39/// size is the sum of the size of those objects, possibly plus some alignment.
40///
41/// Creating a `WeightedIndex<X>` will allocate enough space to hold `N - 1`
42/// weights of type `X`, where `N` is the number of weights. However, since
43/// `Vec` doesn't guarantee a particular growth strategy, additional memory
44/// might be allocated but not used. Since the `WeightedIndex` object also
45/// contains an instance of `X::Sampler`, this might cause additional allocations,
46/// though for primitive types, [`Uniform<X>`] doesn't allocate any memory.
47///
48/// Sampling from `WeightedIndex` will result in a single call to
49/// `Uniform<X>::sample` (method of the [`Distribution`] trait), which typically
50/// will request a single value from the underlying [`RngCore`], though the
51/// exact number depends on the implementation of `Uniform<X>::sample`.
52///
53/// # Example
54///
55/// ```
56/// use rand::prelude::*;
57/// use rand::distr::weighted::WeightedIndex;
58///
59/// let choices = ['a', 'b', 'c'];
60/// let weights = [2, 1, 1];
61/// let dist = WeightedIndex::new(&weights).unwrap();
62/// let mut rng = rand::rng();
63/// for _ in 0..100 {
64/// // 50% chance to print 'a', 25% chance to print 'b', 25% chance to print 'c'
65/// println!("{}", choices[dist.sample(&mut rng)]);
66/// }
67///
68/// let items = [('a', 0.0), ('b', 3.0), ('c', 7.0)];
69/// let dist2 = WeightedIndex::new(items.iter().map(|item| item.1)).unwrap();
70/// for _ in 0..100 {
71/// // 0% chance to print 'a', 30% chance to print 'b', 70% chance to print 'c'
72/// println!("{}", items[dist2.sample(&mut rng)].0);
73/// }
74/// ```
75///
76/// [`Uniform<X>`]: crate::distr::Uniform
77/// [`RngCore`]: crate::RngCore
78/// [`rand_distr::weighted`]: https://docs.rs/rand_distr/latest/rand_distr/weighted/index.html
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct WeightedIndex<X: SampleUniform + PartialOrd> {
    // Running sums of the weights: entry `i` is the sum of weights `0..=i`.
    // `new` pushes the running sum *before* adding each subsequent weight,
    // so the final sum is not stored here (it lives in `total_weight`) and
    // this vector holds one entry fewer than the number of weights.
    cumulative_weights: Vec<X>,
    // Sum of all weights.
    total_weight: X,
    // Sampler built via `X::Sampler::new(zero, total_weight)`; `sample`
    // draws one value from it and locates that value in
    // `cumulative_weights`.
    weight_distribution: X::Sampler,
}
86
87impl<X: SampleUniform + PartialOrd> WeightedIndex<X> {
88 /// Creates a new a `WeightedIndex` [`Distribution`] using the values
89 /// in `weights`. The weights can use any type `X` for which an
90 /// implementation of [`Uniform<X>`] exists.
91 ///
92 /// Error cases:
93 /// - [`Error::InvalidInput`] when the iterator `weights` is empty.
94 /// - [`Error::InvalidWeight`] when a weight is not-a-number or negative.
95 /// - [`Error::InsufficientNonZero`] when the sum of all weights is zero.
96 /// - [`Error::Overflow`] when the sum of all weights overflows.
97 ///
98 /// [`Uniform<X>`]: crate::distr::uniform::Uniform
99 pub fn new<I>(weights: I) -> Result<WeightedIndex<X>, Error>
100 where
101 I: IntoIterator,
102 I::Item: SampleBorrow<X>,
103 X: Weight,
104 {
105 let mut iter = weights.into_iter();
106 let mut total_weight: X = iter.next().ok_or(Error::InvalidInput)?.borrow().clone();
107
108 let zero = X::ZERO;
109 if !(total_weight >= zero) {
110 return Err(Error::InvalidWeight);
111 }
112
113 let mut weights = Vec::<X>::with_capacity(iter.size_hint().0);
114 for w in iter {
115 // Note that `!(w >= x)` is not equivalent to `w < x` for partially
116 // ordered types due to NaNs which are equal to nothing.
117 if !(w.borrow() >= &zero) {
118 return Err(Error::InvalidWeight);
119 }
120 weights.push(total_weight.clone());
121
122 if let Err(()) = total_weight.checked_add_assign(w.borrow()) {
123 return Err(Error::Overflow);
124 }
125 }
126
127 if total_weight == zero {
128 return Err(Error::InsufficientNonZero);
129 }
130 let distr = X::Sampler::new(zero, total_weight.clone()).unwrap();
131
132 Ok(WeightedIndex {
133 cumulative_weights: weights,
134 total_weight,
135 weight_distribution: distr,
136 })
137 }
138
139 /// Update a subset of weights, without changing the number of weights.
140 ///
141 /// `new_weights` must be sorted by the index.
142 ///
143 /// Using this method instead of `new` might be more efficient if only a small number of
144 /// weights is modified. No allocations are performed, unless the weight type `X` uses
145 /// allocation internally.
146 ///
147 /// In case of error, `self` is not modified. Error cases:
148 /// - [`Error::InvalidInput`] when `new_weights` are not ordered by
149 /// index or an index is too large.
150 /// - [`Error::InvalidWeight`] when a weight is not-a-number or negative.
151 /// - [`Error::InsufficientNonZero`] when the sum of all weights is zero.
152 /// Note that due to floating-point loss of precision, this case is not
153 /// always correctly detected; usage of a fixed-point weight type may be
154 /// preferred.
155 ///
156 /// Updates take `O(N)` time. If you need to frequently update weights, consider
157 /// [`rand_distr::weighted_tree`](https://docs.rs/rand_distr/*/rand_distr/weighted_tree/index.html)
158 /// as an alternative where an update is `O(log N)`.
159 pub fn update_weights(&mut self, new_weights: &[(usize, &X)]) -> Result<(), Error>
160 where
161 X: for<'a> core::ops::AddAssign<&'a X>
162 + for<'a> core::ops::SubAssign<&'a X>
163 + Clone
164 + Default,
165 {
166 if new_weights.is_empty() {
167 return Ok(());
168 }
169
170 let zero = <X as Default>::default();
171
172 let mut total_weight = self.total_weight.clone();
173
174 // Check for errors first, so we don't modify `self` in case something
175 // goes wrong.
176 let mut prev_i = None;
177 for &(i, w) in new_weights {
178 if let Some(old_i) = prev_i {
179 if old_i >= i {
180 return Err(Error::InvalidInput);
181 }
182 }
183 if !(*w >= zero) {
184 return Err(Error::InvalidWeight);
185 }
186 if i > self.cumulative_weights.len() {
187 return Err(Error::InvalidInput);
188 }
189
190 let mut old_w = if i < self.cumulative_weights.len() {
191 self.cumulative_weights[i].clone()
192 } else {
193 self.total_weight.clone()
194 };
195 if i > 0 {
196 old_w -= &self.cumulative_weights[i - 1];
197 }
198
199 total_weight -= &old_w;
200 total_weight += w;
201 prev_i = Some(i);
202 }
203 if total_weight <= zero {
204 return Err(Error::InsufficientNonZero);
205 }
206
207 // Update the weights. Because we checked all the preconditions in the
208 // previous loop, this should never panic.
209 let mut iter = new_weights.iter();
210
211 let mut prev_weight = zero.clone();
212 let mut next_new_weight = iter.next();
213 let &(first_new_index, _) = next_new_weight.unwrap();
214 let mut cumulative_weight = if first_new_index > 0 {
215 self.cumulative_weights[first_new_index - 1].clone()
216 } else {
217 zero.clone()
218 };
219 for i in first_new_index..self.cumulative_weights.len() {
220 match next_new_weight {
221 Some(&(j, w)) if i == j => {
222 cumulative_weight += w;
223 next_new_weight = iter.next();
224 }
225 _ => {
226 let mut tmp = self.cumulative_weights[i].clone();
227 tmp -= &prev_weight; // We know this is positive.
228 cumulative_weight += &tmp;
229 }
230 }
231 prev_weight = cumulative_weight.clone();
232 core::mem::swap(&mut prev_weight, &mut self.cumulative_weights[i]);
233 }
234
235 self.total_weight = total_weight;
236 self.weight_distribution = X::Sampler::new(zero, self.total_weight.clone()).unwrap();
237
238 Ok(())
239 }
240}
241
/// A lazy-loading iterator over the weights of a `WeightedIndex` distribution.
/// This is returned by [`WeightedIndex::weights`].
pub struct WeightedIndexIter<'a, X: SampleUniform + PartialOrd> {
    // Distribution whose weights are being iterated.
    weighted_index: &'a WeightedIndex<X>,
    // Index of the next weight to yield; it is passed to
    // `WeightedIndex::weight` and advanced after each yielded item.
    index: usize,
}
248
249impl<X> Debug for WeightedIndexIter<'_, X>
250where
251 X: SampleUniform + PartialOrd + Debug,
252 X::Sampler: Debug,
253{
254 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
255 f.debug_struct("WeightedIndexIter")
256 .field("weighted_index", &self.weighted_index)
257 .field("index", &self.index)
258 .finish()
259 }
260}
261
262impl<X> Clone for WeightedIndexIter<'_, X>
263where
264 X: SampleUniform + PartialOrd,
265{
266 fn clone(&self) -> Self {
267 WeightedIndexIter {
268 weighted_index: self.weighted_index,
269 index: self.index,
270 }
271 }
272}
273
274impl<X> Iterator for WeightedIndexIter<'_, X>
275where
276 X: for<'b> core::ops::SubAssign<&'b X> + SampleUniform + PartialOrd + Clone,
277{
278 type Item = X;
279
280 fn next(&mut self) -> Option<Self::Item> {
281 match self.weighted_index.weight(self.index) {
282 None => None,
283 Some(weight) => {
284 self.index += 1;
285 Some(weight)
286 }
287 }
288 }
289}
290
291impl<X: SampleUniform + PartialOrd + Clone> WeightedIndex<X> {
292 /// Returns the weight at the given index, if it exists.
293 ///
294 /// If the index is out of bounds, this will return `None`.
295 ///
296 /// # Example
297 ///
298 /// ```
299 /// use rand::distr::weighted::WeightedIndex;
300 ///
301 /// let weights = [0, 1, 2];
302 /// let dist = WeightedIndex::new(&weights).unwrap();
303 /// assert_eq!(dist.weight(0), Some(0));
304 /// assert_eq!(dist.weight(1), Some(1));
305 /// assert_eq!(dist.weight(2), Some(2));
306 /// assert_eq!(dist.weight(3), None);
307 /// ```
308 pub fn weight(&self, index: usize) -> Option<X>
309 where
310 X: for<'a> core::ops::SubAssign<&'a X>,
311 {
312 use core::cmp::Ordering::*;
313
314 let mut weight = match index.cmp(&self.cumulative_weights.len()) {
315 Less => self.cumulative_weights[index].clone(),
316 Equal => self.total_weight.clone(),
317 Greater => return None,
318 };
319
320 if index > 0 {
321 weight -= &self.cumulative_weights[index - 1];
322 }
323 Some(weight)
324 }
325
326 /// Returns a lazy-loading iterator containing the current weights of this distribution.
327 ///
328 /// If this distribution has not been updated since its creation, this will return the
329 /// same weights as were passed to `new`.
330 ///
331 /// # Example
332 ///
333 /// ```
334 /// use rand::distr::weighted::WeightedIndex;
335 ///
336 /// let weights = [1, 2, 3];
337 /// let mut dist = WeightedIndex::new(&weights).unwrap();
338 /// assert_eq!(dist.weights().collect::<Vec<_>>(), vec![1, 2, 3]);
339 /// dist.update_weights(&[(0, &2)]).unwrap();
340 /// assert_eq!(dist.weights().collect::<Vec<_>>(), vec![2, 2, 3]);
341 /// ```
342 pub fn weights(&self) -> WeightedIndexIter<'_, X>
343 where
344 X: for<'a> core::ops::SubAssign<&'a X>,
345 {
346 WeightedIndexIter {
347 weighted_index: self,
348 index: 0,
349 }
350 }
351
352 /// Returns the sum of all weights in this distribution.
353 pub fn total_weight(&self) -> X {
354 self.total_weight.clone()
355 }
356}
357
358impl<X> Distribution<usize> for WeightedIndex<X>
359where
360 X: SampleUniform + PartialOrd,
361{
362 fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> usize {
363 let chosen_weight: X = self.weight_distribution.sample(rng);
364 // Find the first item which has a weight *higher* than the chosen weight.
365 self.cumulative_weights
366 .partition_point(|w| w <= &chosen_weight)
367 }
368}
369
#[cfg(test)]
mod test {
    use super::*;

    /// A distribution survives a bincode round trip with identical fields.
    #[cfg(feature = "serde")]
    #[test]
    fn test_weightedindex_serde() {
        let weighted_index = WeightedIndex::new([1, 2, 3, 4, 5, 6, 7, 8, 9, 10]).unwrap();

        let bytes = bincode::serialize(&weighted_index).unwrap();
        let de_weighted_index: WeightedIndex<i32> = bincode::deserialize(&bytes).unwrap();

        assert_eq!(de_weighted_index.cumulative_weights, weighted_index.cumulative_weights);
        assert_eq!(de_weighted_index.total_weight, weighted_index.total_weight);
    }

    /// NaN weights are rejected both at construction and on update.
    #[test]
    fn test_accepting_nan() {
        let nan_first = WeightedIndex::new([f32::NAN, 0.5]);
        assert_eq!(nan_first.unwrap_err(), Error::InvalidWeight);

        let nan_only = WeightedIndex::new([f32::NAN]);
        assert_eq!(nan_only.unwrap_err(), Error::InvalidWeight);

        let nan_last = WeightedIndex::new([0.5, f32::NAN]);
        assert_eq!(nan_last.unwrap_err(), Error::InvalidWeight);

        let mut distr = WeightedIndex::new([0.5, 7.0]).unwrap();
        let updated = distr.update_weights(&[(0, &f32::NAN)]);
        assert_eq!(updated.unwrap_err(), Error::InvalidWeight);
    }

    /// Sampled frequencies track the weights, zero weights are never picked,
    /// and invalid inputs produce the documented errors.
    #[test]
    #[cfg_attr(miri, ignore)] // Miri is too slow
    fn test_weightedindex() {
        const N_REPS: u32 = 5000;
        let mut r = crate::test::rng(700);
        let weights = [1u32, 2, 3, 0, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7];
        let total_weight = weights.iter().sum::<u32>() as f32;

        // Each observed count must be within 25% of its expected value.
        let verify = |result: [i32; 14]| {
            for (&weight, &count) in weights.iter().zip(result.iter()) {
                let exp = (weight * N_REPS) as f32 / total_weight;
                let abs_err = (count as f32 - exp).abs();
                // Keep an exact match as 0 rather than dividing by `exp`
                // (which is 0 for a zero weight).
                let err = if abs_err != 0.0 {
                    abs_err / exp
                } else {
                    abs_err
                };
                assert!(err <= 0.25);
            }
        };

        // WeightedIndex from vec
        let from_vec = WeightedIndex::new(weights.to_vec()).unwrap();
        let mut chosen = [0i32; 14];
        for _ in 0..N_REPS {
            let picked = from_vec.sample(&mut r);
            chosen[picked] += 1;
        }
        verify(chosen);

        // WeightedIndex from slice
        let from_slice = WeightedIndex::new(&weights[..]).unwrap();
        let mut chosen = [0i32; 14];
        for _ in 0..N_REPS {
            let picked = from_slice.sample(&mut r);
            chosen[picked] += 1;
        }
        verify(chosen);

        // WeightedIndex from iterator
        let from_iter = WeightedIndex::new(weights.iter()).unwrap();
        let mut chosen = [0i32; 14];
        for _ in 0..N_REPS {
            let picked = from_iter.sample(&mut r);
            chosen[picked] += 1;
        }
        verify(chosen);

        // With a single non-zero weight the outcome is deterministic.
        for _ in 0..5 {
            let second_only = WeightedIndex::new([0, 1]).unwrap();
            assert_eq!(second_only.sample(&mut r), 1);

            let first_only = WeightedIndex::new([1, 0]).unwrap();
            assert_eq!(first_only.sample(&mut r), 0);

            let fifth_only = WeightedIndex::new([0, 0, 0, 0, 10, 0]).unwrap();
            assert_eq!(fifth_only.sample(&mut r), 4);
        }

        let empty = WeightedIndex::new(&[10][0..0]);
        assert_eq!(empty.unwrap_err(), Error::InvalidInput);

        let all_zero = WeightedIndex::new([0]);
        assert_eq!(all_zero.unwrap_err(), Error::InsufficientNonZero);

        let negative_inside = WeightedIndex::new([10, 20, -1, 30]);
        assert_eq!(negative_inside.unwrap_err(), Error::InvalidWeight);

        let negative_first = WeightedIndex::new([-10, 20, 1, 30]);
        assert_eq!(negative_first.unwrap_err(), Error::InvalidWeight);

        let negative_only = WeightedIndex::new([-10]);
        assert_eq!(negative_only.unwrap_err(), Error::InvalidWeight);
    }

    /// Updating weights yields the same state as building from scratch.
    #[test]
    fn test_update_weights() {
        let data = [
            (
                &[10u32, 2, 3, 4][..],
                &[(1, &100), (2, &4)][..], // positive change
                &[10, 100, 4, 4][..],
            ),
            (
                &[1u32, 2, 3, 0, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7][..],
                &[(2, &1), (5, &1), (13, &100)][..], // negative change and last element
                &[1u32, 2, 1, 0, 5, 1, 7, 1, 2, 3, 4, 5, 6, 100][..],
            ),
        ];

        for &(weights, update, expected_weights) in data.iter() {
            let total_weight: u32 = weights.iter().sum();
            let mut distr = WeightedIndex::new(weights.to_vec()).unwrap();
            assert_eq!(distr.total_weight, total_weight);

            distr.update_weights(update).unwrap();

            let expected_total_weight: u32 = expected_weights.iter().sum();
            let expected_distr = WeightedIndex::new(expected_weights.to_vec()).unwrap();
            assert_eq!(distr.total_weight, expected_total_weight);
            assert_eq!(distr.total_weight, expected_distr.total_weight);
            assert_eq!(distr.cumulative_weights, expected_distr.cumulative_weights);
        }
    }

    /// Each invalid update is rejected with the expected error.
    #[test]
    fn test_update_weights_errors() {
        let data = [
            (
                &[1i32, 0, 0][..],
                &[(0, &0)][..],
                Error::InsufficientNonZero,
            ),
            (
                &[10, 10, 10, 10][..],
                &[(1, &-11)][..],
                Error::InvalidWeight, // A weight is negative
            ),
            (
                &[1, 2, 3, 4, 5][..],
                &[(1, &5), (0, &5)][..], // Wrong order
                Error::InvalidInput,
            ),
            (
                &[1][..],
                &[(1, &1)][..], // Index too large
                Error::InvalidInput,
            ),
        ];

        for (weights, update, err) in &data {
            let total_weight: i32 = weights.iter().sum();
            let mut distr = WeightedIndex::new(weights.to_vec()).unwrap();
            assert_eq!(distr.total_weight, total_weight);

            if let Err(e) = distr.update_weights(update) {
                assert_eq!(e, *err);
            } else {
                panic!("Expected update_weights to fail, but it succeeded");
            }
        }
    }

    /// `weight(i)` returns each input weight and `None` one past the end.
    #[test]
    fn test_weight_at() {
        let data = [
            &[1][..],
            &[10, 2, 3, 4][..],
            &[1, 2, 3, 0, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7][..],
            &[u32::MAX][..],
        ];

        for weights in &data {
            let distr = WeightedIndex::new(weights.to_vec()).unwrap();
            for (i, &weight) in weights.iter().enumerate() {
                assert_eq!(distr.weight(i), Some(weight));
            }
            let one_past_end = weights.len();
            assert_eq!(distr.weight(one_past_end), None);
        }
    }

    /// `weights()` yields exactly the weights the distribution was built from.
    #[test]
    fn test_weights() {
        let data = [
            &[1][..],
            &[10, 2, 3, 4][..],
            &[1, 2, 3, 0, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7][..],
            &[u32::MAX][..],
        ];

        for weights in &data {
            let distr = WeightedIndex::new(weights.to_vec()).unwrap();
            let collected: Vec<_> = distr.weights().collect();
            assert_eq!(collected, weights.to_vec());
        }
    }

    /// Pins the exact sample sequence produced for a fixed seed.
    #[test]
    fn value_stability() {
        // Fills `buf` with samples from a fresh seeded RNG and compares them
        // with `expected`.
        fn test_samples<X: Weight + SampleUniform + PartialOrd, I>(
            weights: I,
            buf: &mut [usize],
            expected: &[usize],
        ) where
            I: IntoIterator,
            I::Item: SampleBorrow<X>,
        {
            assert_eq!(buf.len(), expected.len());
            let distr = WeightedIndex::new(weights).unwrap();
            let mut rng = crate::test::rng(701);
            buf.iter_mut().for_each(|slot| *slot = rng.sample(&distr));
            assert_eq!(buf, expected);
        }

        let mut buf = [0; 10];

        let uniform_ints = [1i32, 1, 1, 1, 1, 1, 1, 1, 1];
        test_samples(uniform_ints, &mut buf, &[0, 6, 2, 6, 3, 4, 7, 8, 2, 5]);

        let skewed_f32 = [0.7f32, 0.1, 0.1, 0.1];
        test_samples(skewed_f32, &mut buf, &[0, 0, 0, 1, 0, 0, 2, 3, 0, 0]);

        let near_uniform_f64 = [1.0f64, 0.999, 0.998, 0.997];
        test_samples(near_uniform_f64, &mut buf, &[2, 2, 1, 3, 2, 1, 3, 3, 2, 1]);
    }

    /// Two distributions built from equal weights compare equal.
    #[test]
    fn weighted_index_distributions_can_be_compared() {
        let left = WeightedIndex::new([1, 2]);
        let right = WeightedIndex::new([1, 2]);
        assert_eq!(left, right);
    }

    /// A weight sum exceeding the integer range is reported as `Overflow`.
    #[test]
    fn overflow() {
        let result = WeightedIndex::new([2, usize::MAX]);
        assert_eq!(result, Err(Error::Overflow));
    }
}
632