1 | use event_listener::{Event, EventListener}; |
2 | use event_listener_strategy::{easy_wrapper, EventListenerFuture, Strategy}; |
3 | |
4 | use core::fmt; |
5 | use core::pin::Pin; |
6 | use core::task::Poll; |
7 | |
8 | use crate::futures::Lock; |
9 | use crate::Mutex; |
10 | |
11 | /// A counter to synchronize multiple tasks at the same time. |
12 | #[derive (Debug)] |
13 | pub struct Barrier { |
14 | n: usize, |
15 | state: Mutex<State>, |
16 | event: Event, |
17 | } |
18 | |
/// Bookkeeping shared by every task waiting on a [`Barrier`].
#[derive(Debug)]
struct State {
    /// Tasks that have arrived in the current round; reset to zero when the round completes.
    count: usize,

    /// Identifies the current round. It is advanced (wrapping) each time the barrier releases
    /// its waiters, which lets a woken task tell whether its own round has finished.
    generation_id: u64,
}
24 | |
25 | impl Barrier { |
26 | const_fn! { |
27 | const_if: #[cfg(not(loom))]; |
28 | /// Creates a barrier that can block the given number of tasks. |
29 | /// |
30 | /// A barrier will block `n`-1 tasks which call [`wait()`] and then wake up all tasks |
31 | /// at once when the `n`th task calls [`wait()`]. |
32 | /// |
33 | /// [`wait()`]: `Barrier::wait()` |
34 | /// |
35 | /// # Examples |
36 | /// |
37 | /// ``` |
38 | /// use async_lock::Barrier; |
39 | /// |
40 | /// let barrier = Barrier::new(5); |
41 | /// ``` |
42 | pub const fn new(n: usize) -> Barrier { |
43 | Barrier { |
44 | n, |
45 | state: Mutex::new(State { |
46 | count: 0, |
47 | generation_id: 0, |
48 | }), |
49 | event: Event::new(), |
50 | } |
51 | } |
52 | } |
53 | |
54 | /// Blocks the current task until all tasks reach this point. |
55 | /// |
56 | /// Barriers are reusable after all tasks have synchronized, and can be used continuously. |
57 | /// |
58 | /// Returns a [`BarrierWaitResult`] indicating whether this task is the "leader", meaning the |
59 | /// last task to call this method. |
60 | /// |
61 | /// # Examples |
62 | /// |
63 | /// ``` |
64 | /// use async_lock::Barrier; |
65 | /// use futures_lite::future; |
66 | /// use std::sync::Arc; |
67 | /// use std::thread; |
68 | /// |
69 | /// let barrier = Arc::new(Barrier::new(5)); |
70 | /// |
71 | /// for _ in 0..5 { |
72 | /// let b = barrier.clone(); |
73 | /// thread::spawn(move || { |
74 | /// future::block_on(async { |
75 | /// // The same messages will be printed together. |
76 | /// // There will NOT be interleaving of "before" and "after". |
77 | /// println!("before wait" ); |
78 | /// b.wait().await; |
79 | /// println!("after wait" ); |
80 | /// }); |
81 | /// }); |
82 | /// } |
83 | /// ``` |
84 | pub fn wait(&self) -> BarrierWait<'_> { |
85 | BarrierWait::_new(BarrierWaitInner { |
86 | barrier: self, |
87 | lock: Some(self.state.lock()), |
88 | evl: None, |
89 | state: WaitState::Initial, |
90 | }) |
91 | } |
92 | |
93 | /// Blocks the current thread until all tasks reach this point. |
94 | /// |
95 | /// Barriers are reusable after all tasks have synchronized, and can be used continuously. |
96 | /// |
97 | /// Returns a [`BarrierWaitResult`] indicating whether this task is the "leader", meaning the |
98 | /// last task to call this method. |
99 | /// |
100 | /// # Blocking |
101 | /// |
102 | /// Rather than using asynchronous waiting, like the [`wait`][`Barrier::wait`] method, |
103 | /// this method will block the current thread until the wait is complete. |
104 | /// |
105 | /// This method should not be used in an asynchronous context. It is intended to be |
106 | /// used in a way that a barrier can be used in both asynchronous and synchronous contexts. |
107 | /// Calling this method in an asynchronous context may result in a deadlock. |
108 | /// |
109 | /// # Examples |
110 | /// |
111 | /// ``` |
112 | /// use async_lock::Barrier; |
113 | /// use futures_lite::future; |
114 | /// use std::sync::Arc; |
115 | /// use std::thread; |
116 | /// |
117 | /// let barrier = Arc::new(Barrier::new(5)); |
118 | /// |
119 | /// for _ in 0..5 { |
120 | /// let b = barrier.clone(); |
121 | /// thread::spawn(move || { |
122 | /// // The same messages will be printed together. |
123 | /// // There will NOT be interleaving of "before" and "after". |
124 | /// println!("before wait" ); |
125 | /// b.wait_blocking(); |
126 | /// println!("after wait" ); |
127 | /// }); |
128 | /// } |
129 | /// # // Wait for threads to stop. |
130 | /// # std::thread::sleep(std::time::Duration::from_secs(1)); |
131 | /// ``` |
132 | #[cfg (all(feature = "std" , not(target_family = "wasm" )))] |
133 | pub fn wait_blocking(&self) -> BarrierWaitResult { |
134 | self.wait().wait() |
135 | } |
136 | } |
137 | |
easy_wrapper! {
    /// The future returned by [`Barrier::wait()`].
    ///
    /// Resolves to a [`BarrierWaitResult`] once every task in the current round has arrived.
    pub struct BarrierWait<'a>(BarrierWaitInner<'a> => BarrierWaitResult);
    // A blocking `wait()` method is generated only when the `std` feature is enabled and the
    // target is not wasm; `Barrier::wait_blocking()` is built on top of it.
    #[cfg(all(feature = "std", not(target_family = "wasm")))]
    pub(crate) wait();
}
144 | |
pin_project_lite::pin_project! {
    /// The state machine driving [`BarrierWait`], the future returned by [`Barrier::wait()`].
    struct BarrierWaitInner<'a> {
        // The barrier to wait on.
        barrier: &'a Barrier,

        // The in-progress mutex lock operation. It is `Some` while the barrier state is being
        // acquired: initially, and again when re-checking after a wakeup. It is reset to
        // `None` as soon as the lock is obtained.
        #[pin]
        lock: Option<Lock<'a, State>>,

        // A listener for `barrier.event`, registered while this task waits for the rest of
        // its round to arrive.
        evl: Option<EventListener>,

        // The step of the wait this future is currently on.
        state: WaitState,
    }
}
162 | |
163 | impl fmt::Debug for BarrierWait<'_> { |
164 | fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { |
165 | f.write_str(data:"BarrierWait { .. }" ) |
166 | } |
167 | } |
168 | |
/// Progress of a single barrier wait through its state machine.
enum WaitState {
    /// The barrier state has not been locked yet, so this task has not been counted.
    Initial,

    /// This task has been counted and is parked on its event listener.
    ///
    /// `local_gen` is the round (generation) that was current when the task arrived.
    Waiting { local_gen: u64 },

    /// The listener fired, and the barrier state is being locked again to check whether the
    /// round `local_gen` has really completed.
    Reacquiring { local_gen: u64 },
}
179 | |
180 | impl EventListenerFuture for BarrierWaitInner<'_> { |
181 | type Output = BarrierWaitResult; |
182 | |
183 | fn poll_with_strategy<'a, S: Strategy<'a>>( |
184 | self: Pin<&mut Self>, |
185 | strategy: &mut S, |
186 | cx: &mut S::Context, |
187 | ) -> Poll<Self::Output> { |
188 | let mut this = self.project(); |
189 | |
190 | loop { |
191 | match this.state { |
192 | WaitState::Initial => { |
193 | // See if the lock is ready yet. |
194 | let mut state = ready!(this |
195 | .lock |
196 | .as_mut() |
197 | .as_pin_mut() |
198 | .unwrap() |
199 | .poll_with_strategy(strategy, cx)); |
200 | this.lock.as_mut().set(None); |
201 | |
202 | let local_gen = state.generation_id; |
203 | state.count += 1; |
204 | |
205 | if state.count < this.barrier.n { |
206 | // We need to wait for the event. |
207 | *this.evl = Some(this.barrier.event.listen()); |
208 | *this.state = WaitState::Waiting { local_gen }; |
209 | } else { |
210 | // We are the last one. |
211 | state.count = 0; |
212 | state.generation_id = state.generation_id.wrapping_add(1); |
213 | this.barrier.event.notify(core::usize::MAX); |
214 | return Poll::Ready(BarrierWaitResult { is_leader: true }); |
215 | } |
216 | } |
217 | |
218 | WaitState::Waiting { local_gen } => { |
219 | ready!(strategy.poll(this.evl, cx)); |
220 | |
221 | // We are now re-acquiring the mutex. |
222 | this.lock.as_mut().set(Some(this.barrier.state.lock())); |
223 | *this.state = WaitState::Reacquiring { |
224 | local_gen: *local_gen, |
225 | }; |
226 | } |
227 | |
228 | WaitState::Reacquiring { local_gen } => { |
229 | // Acquire the local state again. |
230 | let state = ready!(this |
231 | .lock |
232 | .as_mut() |
233 | .as_pin_mut() |
234 | .unwrap() |
235 | .poll_with_strategy(strategy, cx)); |
236 | this.lock.set(None); |
237 | |
238 | if *local_gen == state.generation_id && state.count < this.barrier.n { |
239 | // We need to wait for the event again. |
240 | *this.evl = Some(this.barrier.event.listen()); |
241 | *this.state = WaitState::Waiting { |
242 | local_gen: *local_gen, |
243 | }; |
244 | } else { |
245 | // We are ready, but not the leader. |
246 | return Poll::Ready(BarrierWaitResult { is_leader: false }); |
247 | } |
248 | } |
249 | } |
250 | } |
251 | } |
252 | } |
253 | |
/// The outcome of [`Barrier::wait()`], produced once every task in the round has arrived.
///
/// Call [`BarrierWaitResult::is_leader()`] to find out whether this task was the one that
/// completed the round.
///
/// # Examples
///
/// ```
/// # futures_lite::future::block_on(async {
/// use async_lock::Barrier;
///
/// let barrier = Barrier::new(1);
/// let barrier_wait_result = barrier.wait().await;
/// // With a single participant, that participant always completes the round.
/// assert!(barrier_wait_result.is_leader());
/// # });
/// ```
#[derive(Debug, Clone)]
pub struct BarrierWaitResult {
    /// Whether this task was the last to arrive, i.e. the one that released the others.
    is_leader: bool,
}
270 | |
271 | impl BarrierWaitResult { |
272 | /// Returns `true` if this task was the last to call to [`Barrier::wait()`]. |
273 | /// |
274 | /// # Examples |
275 | /// |
276 | /// ``` |
277 | /// # futures_lite::future::block_on(async { |
278 | /// use async_lock::Barrier; |
279 | /// use futures_lite::future; |
280 | /// |
281 | /// let barrier = Barrier::new(2); |
282 | /// let (a, b) = future::zip(barrier.wait(), barrier.wait()).await; |
283 | /// assert_eq!(a.is_leader(), false); |
284 | /// assert_eq!(b.is_leader(), true); |
285 | /// # }); |
286 | /// ``` |
287 | pub fn is_leader(&self) -> bool { |
288 | self.is_leader |
289 | } |
290 | } |
291 | |