1 | //! A `Barrier` that provides `wait_timeout`. |
2 | //! |
3 | //! This implementation mirrors that of the Rust standard library. |
4 | |
5 | use crate::loom::sync::{Condvar, Mutex}; |
6 | use std::fmt; |
7 | use std::time::{Duration, Instant}; |
8 | |
9 | /// A barrier enables multiple threads to synchronize the beginning |
10 | /// of some computation. |
11 | /// |
12 | /// # Examples |
13 | /// |
14 | /// ``` |
15 | /// use std::sync::{Arc, Barrier}; |
16 | /// use std::thread; |
17 | /// |
18 | /// let mut handles = Vec::with_capacity(10); |
19 | /// let barrier = Arc::new(Barrier::new(10)); |
20 | /// for _ in 0..10 { |
21 | /// let c = Arc::clone(&barrier); |
22 | /// // The same messages will be printed together. |
23 | /// // You will NOT see any interleaving. |
24 | /// handles.push(thread::spawn(move|| { |
25 | /// println!("before wait" ); |
26 | /// c.wait(); |
27 | /// println!("after wait" ); |
28 | /// })); |
29 | /// } |
30 | /// // Wait for other threads to finish. |
31 | /// for handle in handles { |
32 | /// handle.join().unwrap(); |
33 | /// } |
34 | /// ``` |
35 | pub(crate) struct Barrier { |
36 | lock: Mutex<BarrierState>, |
37 | cvar: Condvar, |
38 | num_threads: usize, |
39 | } |
40 | |
/// Mutable bookkeeping shared by every thread using the barrier; it is only
/// accessed while the barrier's mutex is held.
struct BarrierState {
    /// Number of threads that have arrived in the current generation.
    count: usize,
    /// Bumped (with wrap-around) each time the barrier trips, so that waiters
    /// can tell a genuine release apart from a spurious wakeup.
    generation_id: usize,
}
46 | |
/// The value handed back by [`Barrier::wait()`] once every thread taking part
/// in the [`Barrier`] has arrived.
///
/// The wrapped flag records whether the receiving thread was picked as the
/// leader of that rendezvous.
///
/// # Examples
///
/// ```
/// use std::sync::Barrier;
///
/// let barrier = Barrier::new(1);
/// let barrier_wait_result = barrier.wait();
/// ```
pub(crate) struct BarrierWaitResult(bool);
59 | |
60 | impl fmt::Debug for Barrier { |
61 | fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { |
62 | f.debug_struct(name:"Barrier" ).finish_non_exhaustive() |
63 | } |
64 | } |
65 | |
66 | impl Barrier { |
67 | /// Creates a new barrier that can block a given number of threads. |
68 | /// |
69 | /// A barrier will block `n`-1 threads which call [`wait()`] and then wake |
70 | /// up all threads at once when the `n`th thread calls [`wait()`]. |
71 | /// |
72 | /// [`wait()`]: Barrier::wait |
73 | /// |
74 | /// # Examples |
75 | /// |
76 | /// ``` |
77 | /// use std::sync::Barrier; |
78 | /// |
79 | /// let barrier = Barrier::new(10); |
80 | /// ``` |
81 | #[must_use ] |
82 | pub(crate) fn new(n: usize) -> Barrier { |
83 | Barrier { |
84 | lock: Mutex::new(BarrierState { |
85 | count: 0, |
86 | generation_id: 0, |
87 | }), |
88 | cvar: Condvar::new(), |
89 | num_threads: n, |
90 | } |
91 | } |
92 | |
93 | /// Blocks the current thread until all threads have rendezvoused here. |
94 | /// |
95 | /// Barriers are re-usable after all threads have rendezvoused once, and can |
96 | /// be used continuously. |
97 | /// |
98 | /// A single (arbitrary) thread will receive a [`BarrierWaitResult`] that |
99 | /// returns `true` from [`BarrierWaitResult::is_leader()`] when returning |
100 | /// from this function, and all other threads will receive a result that |
101 | /// will return `false` from [`BarrierWaitResult::is_leader()`]. |
102 | /// |
103 | /// # Examples |
104 | /// |
105 | /// ``` |
106 | /// use std::sync::{Arc, Barrier}; |
107 | /// use std::thread; |
108 | /// |
109 | /// let mut handles = Vec::with_capacity(10); |
110 | /// let barrier = Arc::new(Barrier::new(10)); |
111 | /// for _ in 0..10 { |
112 | /// let c = Arc::clone(&barrier); |
113 | /// // The same messages will be printed together. |
114 | /// // You will NOT see any interleaving. |
115 | /// handles.push(thread::spawn(move|| { |
116 | /// println!("before wait" ); |
117 | /// c.wait(); |
118 | /// println!("after wait" ); |
119 | /// })); |
120 | /// } |
121 | /// // Wait for other threads to finish. |
122 | /// for handle in handles { |
123 | /// handle.join().unwrap(); |
124 | /// } |
125 | /// ``` |
126 | pub(crate) fn wait(&self) -> BarrierWaitResult { |
127 | let mut lock = self.lock.lock(); |
128 | let local_gen = lock.generation_id; |
129 | lock.count += 1; |
130 | if lock.count < self.num_threads { |
131 | // We need a while loop to guard against spurious wakeups. |
132 | // https://en.wikipedia.org/wiki/Spurious_wakeup |
133 | while local_gen == lock.generation_id { |
134 | lock = self.cvar.wait(lock).unwrap(); |
135 | } |
136 | BarrierWaitResult(false) |
137 | } else { |
138 | lock.count = 0; |
139 | lock.generation_id = lock.generation_id.wrapping_add(1); |
140 | self.cvar.notify_all(); |
141 | BarrierWaitResult(true) |
142 | } |
143 | } |
144 | |
145 | /// Blocks the current thread until all threads have rendezvoused here for |
146 | /// at most `timeout` duration. |
147 | pub(crate) fn wait_timeout(&self, timeout: Duration) -> Option<BarrierWaitResult> { |
148 | // This implementation mirrors `wait`, but with each blocking operation |
149 | // replaced by a timeout-amenable alternative. |
150 | |
151 | let deadline = Instant::now() + timeout; |
152 | |
153 | // Acquire `self.lock` with at most `timeout` duration. |
154 | let mut lock = loop { |
155 | if let Some(guard) = self.lock.try_lock() { |
156 | break guard; |
157 | } else if Instant::now() > deadline { |
158 | return None; |
159 | } else { |
160 | std::thread::yield_now(); |
161 | } |
162 | }; |
163 | |
164 | // Shrink the `timeout` to account for the time taken to acquire `lock`. |
165 | let timeout = deadline.saturating_duration_since(Instant::now()); |
166 | |
167 | let local_gen = lock.generation_id; |
168 | lock.count += 1; |
169 | if lock.count < self.num_threads { |
170 | // We need a while loop to guard against spurious wakeups. |
171 | // https://en.wikipedia.org/wiki/Spurious_wakeup |
172 | while local_gen == lock.generation_id { |
173 | let (guard, timeout_result) = self.cvar.wait_timeout(lock, timeout).unwrap(); |
174 | lock = guard; |
175 | if timeout_result.timed_out() { |
176 | return None; |
177 | } |
178 | } |
179 | Some(BarrierWaitResult(false)) |
180 | } else { |
181 | lock.count = 0; |
182 | lock.generation_id = lock.generation_id.wrapping_add(1); |
183 | self.cvar.notify_all(); |
184 | Some(BarrierWaitResult(true)) |
185 | } |
186 | } |
187 | } |
188 | |
189 | impl fmt::Debug for BarrierWaitResult { |
190 | fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { |
191 | f&mut DebugStruct<'_, '_>.debug_struct("BarrierWaitResult" ) |
192 | .field(name:"is_leader" , &self.is_leader()) |
193 | .finish() |
194 | } |
195 | } |
196 | |
197 | impl BarrierWaitResult { |
198 | /// Returns `true` if this thread is the "leader thread" for the call to |
199 | /// [`Barrier::wait()`]. |
200 | /// |
201 | /// Only one thread will have `true` returned from their result, all other |
202 | /// threads will have `false` returned. |
203 | /// |
204 | /// # Examples |
205 | /// |
206 | /// ``` |
207 | /// use std::sync::Barrier; |
208 | /// |
209 | /// let barrier = Barrier::new(1); |
210 | /// let barrier_wait_result = barrier.wait(); |
211 | /// println!("{:?}" , barrier_wait_result.is_leader()); |
212 | /// ``` |
213 | #[must_use ] |
214 | pub(crate) fn is_leader(&self) -> bool { |
215 | self.0 |
216 | } |
217 | } |
218 | |