1// Copyright 2018 Developers of the Rand project.
2//
3// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
4// https://www.apache.org/licenses/LICENSE-2.0> or the MIT license
5// <LICENSE-MIT or https://opensource.org/licenses/MIT>, at your
6// option. This file may not be copied, modified, or distributed
7// except according to those terms.
8
9//! Weighted index sampling
10
11use crate::distributions::uniform::{SampleBorrow, SampleUniform, UniformSampler};
12use crate::distributions::Distribution;
13use crate::Rng;
14use core::cmp::PartialOrd;
15use core::fmt;
16
17// Note that this whole module is only imported if feature="alloc" is enabled.
18use alloc::vec::Vec;
19
20#[cfg(feature = "serde1")]
21use serde::{Serialize, Deserialize};
22
23/// A distribution using weighted sampling of discrete items
24///
25/// Sampling a `WeightedIndex` distribution returns the index of a randomly
26/// selected element from the iterator used when the `WeightedIndex` was
27/// created. The chance of a given element being picked is proportional to the
28/// value of the element. The weights can use any type `X` for which an
29/// implementation of [`Uniform<X>`] exists.
30///
31/// # Performance
32///
33/// Time complexity of sampling from `WeightedIndex` is `O(log N)` where
34/// `N` is the number of weights. As an alternative,
35/// [`rand_distr::weighted_alias`](https://docs.rs/rand_distr/*/rand_distr/weighted_alias/index.html)
36/// supports `O(1)` sampling, but with much higher initialisation cost.
37///
38/// A `WeightedIndex<X>` contains a `Vec<X>` and a [`Uniform<X>`] and so its
39/// size is the sum of the size of those objects, possibly plus some alignment.
40///
41/// Creating a `WeightedIndex<X>` will allocate enough space to hold `N - 1`
42/// weights of type `X`, where `N` is the number of weights. However, since
43/// `Vec` doesn't guarantee a particular growth strategy, additional memory
44/// might be allocated but not used. Since the `WeightedIndex` object also
45/// contains, this might cause additional allocations, though for primitive
46/// types, [`Uniform<X>`] doesn't allocate any memory.
47///
48/// Sampling from `WeightedIndex` will result in a single call to
49/// `Uniform<X>::sample` (method of the [`Distribution`] trait), which typically
50/// will request a single value from the underlying [`RngCore`], though the
51/// exact number depends on the implementation of `Uniform<X>::sample`.
52///
53/// # Example
54///
55/// ```
56/// use rand::prelude::*;
57/// use rand::distributions::WeightedIndex;
58///
59/// let choices = ['a', 'b', 'c'];
60/// let weights = [2, 1, 1];
61/// let dist = WeightedIndex::new(&weights).unwrap();
62/// let mut rng = thread_rng();
63/// for _ in 0..100 {
64/// // 50% chance to print 'a', 25% chance to print 'b', 25% chance to print 'c'
65/// println!("{}", choices[dist.sample(&mut rng)]);
66/// }
67///
68/// let items = [('a', 0), ('b', 3), ('c', 7)];
69/// let dist2 = WeightedIndex::new(items.iter().map(|item| item.1)).unwrap();
70/// for _ in 0..100 {
71/// // 0% chance to print 'a', 30% chance to print 'b', 70% chance to print 'c'
72/// println!("{}", items[dist2.sample(&mut rng)].0);
73/// }
74/// ```
75///
76/// [`Uniform<X>`]: crate::distributions::Uniform
77/// [`RngCore`]: crate::RngCore
78#[derive(Debug, Clone, PartialEq)]
79#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
80#[cfg_attr(doc_cfg, doc(cfg(feature = "alloc")))]
81pub struct WeightedIndex<X: SampleUniform + PartialOrd> {
82 cumulative_weights: Vec<X>,
83 total_weight: X,
84 weight_distribution: X::Sampler,
85}
86
87impl<X: SampleUniform + PartialOrd> WeightedIndex<X> {
88 /// Creates a new a `WeightedIndex` [`Distribution`] using the values
89 /// in `weights`. The weights can use any type `X` for which an
90 /// implementation of [`Uniform<X>`] exists.
91 ///
92 /// Returns an error if the iterator is empty, if any weight is `< 0`, or
93 /// if its total value is 0.
94 ///
95 /// [`Uniform<X>`]: crate::distributions::uniform::Uniform
96 pub fn new<I>(weights: I) -> Result<WeightedIndex<X>, WeightedError>
97 where
98 I: IntoIterator,
99 I::Item: SampleBorrow<X>,
100 X: for<'a> ::core::ops::AddAssign<&'a X> + Clone + Default,
101 {
102 let mut iter = weights.into_iter();
103 let mut total_weight: X = iter.next().ok_or(WeightedError::NoItem)?.borrow().clone();
104
105 let zero = <X as Default>::default();
106 if !(total_weight >= zero) {
107 return Err(WeightedError::InvalidWeight);
108 }
109
110 let mut weights = Vec::<X>::with_capacity(iter.size_hint().0);
111 for w in iter {
112 // Note that `!(w >= x)` is not equivalent to `w < x` for partially
113 // ordered types due to NaNs which are equal to nothing.
114 if !(w.borrow() >= &zero) {
115 return Err(WeightedError::InvalidWeight);
116 }
117 weights.push(total_weight.clone());
118 total_weight += w.borrow();
119 }
120
121 if total_weight == zero {
122 return Err(WeightedError::AllWeightsZero);
123 }
124 let distr = X::Sampler::new(zero, total_weight.clone());
125
126 Ok(WeightedIndex {
127 cumulative_weights: weights,
128 total_weight,
129 weight_distribution: distr,
130 })
131 }
132
133 /// Update a subset of weights, without changing the number of weights.
134 ///
135 /// `new_weights` must be sorted by the index.
136 ///
137 /// Using this method instead of `new` might be more efficient if only a small number of
138 /// weights is modified. No allocations are performed, unless the weight type `X` uses
139 /// allocation internally.
140 ///
141 /// In case of error, `self` is not modified.
142 pub fn update_weights(&mut self, new_weights: &[(usize, &X)]) -> Result<(), WeightedError>
143 where X: for<'a> ::core::ops::AddAssign<&'a X>
144 + for<'a> ::core::ops::SubAssign<&'a X>
145 + Clone
146 + Default {
147 if new_weights.is_empty() {
148 return Ok(());
149 }
150
151 let zero = <X as Default>::default();
152
153 let mut total_weight = self.total_weight.clone();
154
155 // Check for errors first, so we don't modify `self` in case something
156 // goes wrong.
157 let mut prev_i = None;
158 for &(i, w) in new_weights {
159 if let Some(old_i) = prev_i {
160 if old_i >= i {
161 return Err(WeightedError::InvalidWeight);
162 }
163 }
164 if !(*w >= zero) {
165 return Err(WeightedError::InvalidWeight);
166 }
167 if i > self.cumulative_weights.len() {
168 return Err(WeightedError::TooMany);
169 }
170
171 let mut old_w = if i < self.cumulative_weights.len() {
172 self.cumulative_weights[i].clone()
173 } else {
174 self.total_weight.clone()
175 };
176 if i > 0 {
177 old_w -= &self.cumulative_weights[i - 1];
178 }
179
180 total_weight -= &old_w;
181 total_weight += w;
182 prev_i = Some(i);
183 }
184 if total_weight <= zero {
185 return Err(WeightedError::AllWeightsZero);
186 }
187
188 // Update the weights. Because we checked all the preconditions in the
189 // previous loop, this should never panic.
190 let mut iter = new_weights.iter();
191
192 let mut prev_weight = zero.clone();
193 let mut next_new_weight = iter.next();
194 let &(first_new_index, _) = next_new_weight.unwrap();
195 let mut cumulative_weight = if first_new_index > 0 {
196 self.cumulative_weights[first_new_index - 1].clone()
197 } else {
198 zero.clone()
199 };
200 for i in first_new_index..self.cumulative_weights.len() {
201 match next_new_weight {
202 Some(&(j, w)) if i == j => {
203 cumulative_weight += w;
204 next_new_weight = iter.next();
205 }
206 _ => {
207 let mut tmp = self.cumulative_weights[i].clone();
208 tmp -= &prev_weight; // We know this is positive.
209 cumulative_weight += &tmp;
210 }
211 }
212 prev_weight = cumulative_weight.clone();
213 core::mem::swap(&mut prev_weight, &mut self.cumulative_weights[i]);
214 }
215
216 self.total_weight = total_weight;
217 self.weight_distribution = X::Sampler::new(zero, self.total_weight.clone());
218
219 Ok(())
220 }
221}
222
223impl<X> Distribution<usize> for WeightedIndex<X>
224where X: SampleUniform + PartialOrd
225{
226 fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> usize {
227 use ::core::cmp::Ordering;
228 let chosen_weight: X = self.weight_distribution.sample(rng);
229 // Find the first item which has a weight *higher* than the chosen weight.
230 self.cumulative_weights
231 .binary_search_by(|w: &X| {
232 if *w <= chosen_weight {
233 Ordering::Less
234 } else {
235 Ordering::Greater
236 }
237 })
238 .unwrap_err()
239 }
240}
241
#[cfg(test)]
mod test {
    use super::*;

    // Round-trips a `WeightedIndex` through bincode and checks that the
    // cumulative weights and the total weight survive unchanged.
    #[cfg(feature = "serde1")]
    #[test]
    fn test_weightedindex_serde1() {
        let weighted_index = WeightedIndex::new(&[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]).unwrap();

        let ser_weighted_index = bincode::serialize(&weighted_index).unwrap();
        let de_weighted_index: WeightedIndex<i32> =
            bincode::deserialize(&ser_weighted_index).unwrap();

        assert_eq!(
            de_weighted_index.cumulative_weights,
            weighted_index.cumulative_weights
        );
        assert_eq!(de_weighted_index.total_weight, weighted_index.total_weight);
    }

    // Despite its name, this test checks that NaN weights are *rejected* with
    // `InvalidWeight`. It covers NaN as the first weight, as the only weight,
    // as a later weight, and as a value passed to `update_weights`.
    #[test]
    fn test_accepting_nan(){
        assert_eq!(
            WeightedIndex::new(&[core::f32::NAN, 0.5]).unwrap_err(),
            WeightedError::InvalidWeight,
        );
        assert_eq!(
            WeightedIndex::new(&[core::f32::NAN]).unwrap_err(),
            WeightedError::InvalidWeight,
        );
        assert_eq!(
            WeightedIndex::new(&[0.5, core::f32::NAN]).unwrap_err(),
            WeightedError::InvalidWeight,
        );

        assert_eq!(
            WeightedIndex::new(&[0.5, 7.0])
                .unwrap()
                .update_weights(&[(0, &core::f32::NAN)])
                .unwrap_err(),
            WeightedError::InvalidWeight,
        )
    }


    // Statistical check of the sampling frequencies, plus deterministic edge
    // cases and error reporting of `new`. A single RNG `r` is shared by all
    // sections, so the order of the sections determines which values each
    // one draws.
    #[test]
    #[cfg_attr(miri, ignore)] // Miri is too slow
    fn test_weightedindex() {
        let mut r = crate::test::rng(700);
        const N_REPS: u32 = 5000;
        let weights = [1u32, 2, 3, 0, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7];
        let total_weight = weights.iter().sum::<u32>() as f32;

        // Each observed count must be within 25% (relative) of its expected
        // count. A zero-weight entry expects 0 hits: 0 hits gives err == 0,
        // while any hit divides by zero, gives infinity and fails.
        let verify = |result: [i32; 14]| {
            for (i, count) in result.iter().enumerate() {
                let exp = (weights[i] * N_REPS) as f32 / total_weight;
                let mut err = (*count as f32 - exp).abs();
                if err != 0.0 {
                    err /= exp;
                }
                assert!(err <= 0.25);
            }
        };

        // WeightedIndex from vec
        let mut chosen = [0i32; 14];
        let distr = WeightedIndex::new(weights.to_vec()).unwrap();
        for _ in 0..N_REPS {
            chosen[distr.sample(&mut r)] += 1;
        }
        verify(chosen);

        // WeightedIndex from slice
        chosen = [0i32; 14];
        let distr = WeightedIndex::new(&weights[..]).unwrap();
        for _ in 0..N_REPS {
            chosen[distr.sample(&mut r)] += 1;
        }
        verify(chosen);

        // WeightedIndex from iterator
        chosen = [0i32; 14];
        let distr = WeightedIndex::new(weights.iter()).unwrap();
        for _ in 0..N_REPS {
            chosen[distr.sample(&mut r)] += 1;
        }
        verify(chosen);

        // Only one non-zero weight: the result is deterministic.
        for _ in 0..5 {
            assert_eq!(WeightedIndex::new(&[0, 1]).unwrap().sample(&mut r), 1);
            assert_eq!(WeightedIndex::new(&[1, 0]).unwrap().sample(&mut r), 0);
            assert_eq!(
                WeightedIndex::new(&[0, 0, 0, 0, 10, 0])
                    .unwrap()
                    .sample(&mut r),
                4
            );
        }

        // Error cases: empty input, all-zero total, and negative weights at
        // a middle, a leading, and the only position.
        assert_eq!(
            WeightedIndex::new(&[10][0..0]).unwrap_err(),
            WeightedError::NoItem
        );
        assert_eq!(
            WeightedIndex::new(&[0]).unwrap_err(),
            WeightedError::AllWeightsZero
        );
        assert_eq!(
            WeightedIndex::new(&[10, 20, -1, 30]).unwrap_err(),
            WeightedError::InvalidWeight
        );
        assert_eq!(
            WeightedIndex::new(&[-10, 20, 1, 30]).unwrap_err(),
            WeightedError::InvalidWeight
        );
        assert_eq!(
            WeightedIndex::new(&[-10]).unwrap_err(),
            WeightedError::InvalidWeight
        );
    }

    // Each case is (initial weights, sorted update list, expected weights).
    // After `update_weights`, the distribution's internal state must match a
    // `WeightedIndex` built directly from the expected weights.
    #[test]
    fn test_update_weights() {
        let data = [
            (
                &[10u32, 2, 3, 4][..],
                &[(1, &100), (2, &4)][..], // positive change
                &[10, 100, 4, 4][..],
            ),
            (
                &[1u32, 2, 3, 0, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7][..],
                &[(2, &1), (5, &1), (13, &100)][..], // negative change and last element
                &[1u32, 2, 1, 0, 5, 1, 7, 1, 2, 3, 4, 5, 6, 100][..],
            ),
        ];

        for (weights, update, expected_weights) in data.iter() {
            let total_weight = weights.iter().sum::<u32>();
            let mut distr = WeightedIndex::new(weights.to_vec()).unwrap();
            assert_eq!(distr.total_weight, total_weight);

            distr.update_weights(update).unwrap();
            let expected_total_weight = expected_weights.iter().sum::<u32>();
            let expected_distr = WeightedIndex::new(expected_weights.to_vec()).unwrap();
            assert_eq!(distr.total_weight, expected_total_weight);
            assert_eq!(distr.total_weight, expected_distr.total_weight);
            assert_eq!(distr.cumulative_weights, expected_distr.cumulative_weights);
        }
    }

    // Pins exact sample sequences for a fixed seed (701) and several weight
    // types, so that any change to the sampling algorithm is detected.
    #[test]
    fn value_stability() {
        fn test_samples<X: SampleUniform + PartialOrd, I>(
            weights: I, buf: &mut [usize], expected: &[usize],
        ) where
            I: IntoIterator,
            I::Item: SampleBorrow<X>,
            X: for<'a> ::core::ops::AddAssign<&'a X> + Clone + Default,
        {
            assert_eq!(buf.len(), expected.len());
            let distr = WeightedIndex::new(weights).unwrap();
            let mut rng = crate::test::rng(701);
            for r in buf.iter_mut() {
                *r = rng.sample(&distr);
            }
            assert_eq!(buf, expected);
        }

        let mut buf = [0; 10];
        test_samples(&[1i32, 1, 1, 1, 1, 1, 1, 1, 1], &mut buf, &[
            0, 6, 2, 6, 3, 4, 7, 8, 2, 5,
        ]);
        test_samples(&[0.7f32, 0.1, 0.1, 0.1], &mut buf, &[
            0, 0, 0, 1, 0, 0, 2, 3, 0, 0,
        ]);
        test_samples(&[1.0f64, 0.999, 0.998, 0.997], &mut buf, &[
            2, 2, 1, 3, 2, 1, 3, 3, 2, 1,
        ]);
    }

    // `WeightedIndex` derives `PartialEq`; identical inputs compare equal.
    #[test]
    fn weighted_index_distributions_can_be_compared() {
        assert_eq!(WeightedIndex::new(&[1, 2]), WeightedIndex::new(&[1, 2]));
    }
}
427
/// Error type returned from [`WeightedIndex::new`] and
/// [`WeightedIndex::update_weights`].
#[cfg_attr(doc_cfg, doc(cfg(feature = "alloc")))]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum WeightedError {
    /// The provided weight collection contains no items.
    NoItem,

    /// A weight is either less than zero, greater than the supported maximum,
    /// NaN, or otherwise invalid.
    ///
    /// [`WeightedIndex::update_weights`] also reports this when the update
    /// indices are not strictly increasing.
    InvalidWeight,

    /// All items in the provided weight collection are zero.
    AllWeightsZero,

    /// Too many weights are provided (length greater than `u32::MAX`).
    ///
    /// [`WeightedIndex::update_weights`] also reports this when an update
    /// index is out of range.
    TooMany,
}

#[cfg(feature = "std")]
impl std::error::Error for WeightedError {}

impl fmt::Display for WeightedError {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        // One fixed, human-readable message per variant.
        f.write_str(match *self {
            WeightedError::NoItem => "No weights provided in distribution",
            WeightedError::InvalidWeight => "A weight is invalid in distribution",
            WeightedError::AllWeightsZero => "All weights are zero in distribution",
            WeightedError::TooMany => "Too many weights (hit u32::MAX) in distribution",
        })
    }
}
459