1 | // Copyright 2018-2023 Developers of the Rand project. |
2 | // |
3 | // Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or |
4 | // https://www.apache.org/licenses/LICENSE-2.0> or the MIT license |
5 | // <LICENSE-MIT or https://opensource.org/licenses/MIT>, at your |
6 | // option. This file may not be copied, modified, or distributed |
7 | // except according to those terms. |
8 | |
9 | use crate::RngCore; |
10 | |
11 | pub(crate) struct CoinFlipper<R: RngCore> { |
12 | pub rng: R, |
13 | chunk: u32, // TODO(opt): this should depend on RNG word size |
14 | chunk_remaining: u32, |
15 | } |
16 | |
17 | impl<R: RngCore> CoinFlipper<R> { |
18 | pub fn new(rng: R) -> Self { |
19 | Self { |
20 | rng, |
21 | chunk: 0, |
22 | chunk_remaining: 0, |
23 | } |
24 | } |
25 | |
26 | #[inline ] |
27 | /// Returns true with a probability of 1 / d |
28 | /// Uses an expected two bits of randomness |
29 | /// Panics if d == 0 |
30 | pub fn random_ratio_one_over(&mut self, d: usize) -> bool { |
31 | debug_assert_ne!(d, 0); |
32 | // This uses the same logic as `random_ratio` but is optimized for the case that |
33 | // the starting numerator is one (which it always is for `Sequence::Choose()`) |
34 | |
35 | // In this case (but not `random_ratio`), this way of calculating c is always accurate |
36 | let c = (usize::BITS - 1 - d.leading_zeros()).min(32); |
37 | |
38 | if self.flip_c_heads(c) { |
39 | let numerator = 1 << c; |
40 | self.random_ratio(numerator, d) |
41 | } else { |
42 | false |
43 | } |
44 | } |
45 | |
46 | #[inline ] |
47 | /// Returns true with a probability of n / d |
48 | /// Uses an expected two bits of randomness |
49 | fn random_ratio(&mut self, mut n: usize, d: usize) -> bool { |
50 | // Explanation: |
51 | // We are trying to return true with a probability of n / d |
52 | // If n >= d, we can just return true |
53 | // Otherwise there are two possibilities 2n < d and 2n >= d |
54 | // In either case we flip a coin. |
55 | // If 2n < d |
56 | // If it comes up tails, return false |
57 | // If it comes up heads, double n and start again |
58 | // This is fair because (0.5 * 0) + (0.5 * 2n / d) = n / d and 2n is less than d |
59 | // (if 2n was greater than d we would effectively round it down to 1 |
60 | // by returning true) |
61 | // If 2n >= d |
62 | // If it comes up tails, set n to 2n - d and start again |
63 | // If it comes up heads, return true |
64 | // This is fair because (0.5 * 1) + (0.5 * (2n - d) / d) = n / d |
65 | // Note that if 2n = d and the coin comes up tails, n will be set to 0 |
66 | // before restarting which is equivalent to returning false. |
67 | |
68 | // As a performance optimization we can flip multiple coins at once |
69 | // This is efficient because we can use the `lzcnt` intrinsic |
70 | // We can check up to 32 flips at once but we only receive one bit of information |
71 | // - all heads or at least one tail. |
72 | |
73 | // Let c be the number of coins to flip. 1 <= c <= 32 |
74 | // If 2n < d, n * 2^c < d |
75 | // If the result is all heads, then set n to n * 2^c |
76 | // If there was at least one tail, return false |
77 | // If 2n >= d, the order of results matters so we flip one coin at a time so c = 1 |
78 | // Ideally, c will be as high as possible within these constraints |
79 | |
80 | while n < d { |
81 | // Find a good value for c by counting leading zeros |
82 | // This will either give the highest possible c, or 1 less than that |
83 | let c = n |
84 | .leading_zeros() |
85 | .saturating_sub(d.leading_zeros() + 1) |
86 | .clamp(1, 32); |
87 | |
88 | if self.flip_c_heads(c) { |
89 | // All heads |
90 | // Set n to n * 2^c |
91 | // If 2n >= d, the while loop will exit and we will return `true` |
92 | // If n * 2^c > `usize::MAX` we always return `true` anyway |
93 | n = n.saturating_mul(2_usize.pow(c)); |
94 | } else { |
95 | // At least one tail |
96 | if c == 1 { |
97 | // Calculate 2n - d. |
98 | // We need to use wrapping as 2n might be greater than `usize::MAX` |
99 | let next_n = n.wrapping_add(n).wrapping_sub(d); |
100 | if next_n == 0 || next_n > n { |
101 | // This will happen if 2n < d |
102 | return false; |
103 | } |
104 | n = next_n; |
105 | } else { |
106 | // c > 1 so 2n < d so we can return false |
107 | return false; |
108 | } |
109 | } |
110 | } |
111 | true |
112 | } |
113 | |
114 | /// If the next `c` bits of randomness all represent heads, consume them, return true |
115 | /// Otherwise return false and consume the number of heads plus one. |
116 | /// Generates new bits of randomness when necessary (in 32 bit chunks) |
117 | /// Has a 1 in 2 to the `c` chance of returning true |
118 | /// `c` must be less than or equal to 32 |
119 | fn flip_c_heads(&mut self, mut c: u32) -> bool { |
120 | debug_assert!(c <= 32); |
121 | // Note that zeros on the left of the chunk represent heads. |
122 | // It needs to be this way round because zeros are filled in when left shifting |
123 | loop { |
124 | let zeros = self.chunk.leading_zeros(); |
125 | |
126 | if zeros < c { |
127 | // The happy path - we found a 1 and can return false |
128 | // Note that because a 1 bit was detected, |
129 | // We cannot have run out of random bits so we don't need to check |
130 | |
131 | // First consume all of the bits read |
132 | // Using shl seems to give worse performance for size-hinted iterators |
133 | self.chunk = self.chunk.wrapping_shl(zeros + 1); |
134 | |
135 | self.chunk_remaining = self.chunk_remaining.saturating_sub(zeros + 1); |
136 | return false; |
137 | } else { |
138 | // The number of zeros is larger than `c` |
139 | // There are two possibilities |
140 | if let Some(new_remaining) = self.chunk_remaining.checked_sub(c) { |
141 | // Those zeroes were all part of our random chunk, |
142 | // throw away `c` bits of randomness and return true |
143 | self.chunk_remaining = new_remaining; |
144 | self.chunk <<= c; |
145 | return true; |
146 | } else { |
147 | // Some of those zeroes were part of the random chunk |
148 | // and some were part of the space behind it |
149 | // We need to take into account only the zeroes that were random |
150 | c -= self.chunk_remaining; |
151 | |
152 | // Generate a new chunk |
153 | self.chunk = self.rng.next_u32(); |
154 | self.chunk_remaining = 32; |
155 | // Go back to start of loop |
156 | } |
157 | } |
158 | } |
159 | } |
160 | } |
161 | |