1 | use event_listener::{Event, EventListener}; |
2 | use event_listener_strategy::{easy_wrapper, EventListenerFuture, Strategy}; |
3 | |
4 | use core::fmt; |
5 | use core::pin::Pin; |
6 | use core::task::Poll; |
7 | |
8 | use crate::futures::Lock; |
9 | use crate::Mutex; |
10 | |
11 | /// A counter to synchronize multiple tasks at the same time. |
12 | #[derive (Debug)] |
13 | pub struct Barrier { |
14 | n: usize, |
15 | state: Mutex<State>, |
16 | event: Event, |
17 | } |
18 | |
/// Mutable barrier bookkeeping; it lives inside the barrier's mutex and is only touched while
/// that lock is held.
#[derive(Debug)]
struct State {
    /// Number of tasks that have arrived during the current generation.
    count: usize,
    /// Bumped (with wrapping) each time the barrier releases its waiters, so a woken task can
    /// tell whether the round it joined has completed.
    generation_id: u64,
}
24 | |
25 | impl Barrier { |
26 | /// Creates a barrier that can block the given number of tasks. |
27 | /// |
28 | /// A barrier will block `n`-1 tasks which call [`wait()`] and then wake up all tasks |
29 | /// at once when the `n`th task calls [`wait()`]. |
30 | /// |
31 | /// [`wait()`]: `Barrier::wait()` |
32 | /// |
33 | /// # Examples |
34 | /// |
35 | /// ``` |
36 | /// use async_lock::Barrier; |
37 | /// |
38 | /// let barrier = Barrier::new(5); |
39 | /// ``` |
40 | pub const fn new(n: usize) -> Barrier { |
41 | Barrier { |
42 | n, |
43 | state: Mutex::new(State { |
44 | count: 0, |
45 | generation_id: 0, |
46 | }), |
47 | event: Event::new(), |
48 | } |
49 | } |
50 | |
51 | /// Blocks the current task until all tasks reach this point. |
52 | /// |
53 | /// Barriers are reusable after all tasks have synchronized, and can be used continuously. |
54 | /// |
55 | /// Returns a [`BarrierWaitResult`] indicating whether this task is the "leader", meaning the |
56 | /// last task to call this method. |
57 | /// |
58 | /// # Examples |
59 | /// |
60 | /// ``` |
61 | /// use async_lock::Barrier; |
62 | /// use futures_lite::future; |
63 | /// use std::sync::Arc; |
64 | /// use std::thread; |
65 | /// |
66 | /// let barrier = Arc::new(Barrier::new(5)); |
67 | /// |
68 | /// for _ in 0..5 { |
69 | /// let b = barrier.clone(); |
70 | /// thread::spawn(move || { |
71 | /// future::block_on(async { |
72 | /// // The same messages will be printed together. |
73 | /// // There will NOT be interleaving of "before" and "after". |
74 | /// println!("before wait" ); |
75 | /// b.wait().await; |
76 | /// println!("after wait" ); |
77 | /// }); |
78 | /// }); |
79 | /// } |
80 | /// ``` |
81 | pub fn wait(&self) -> BarrierWait<'_> { |
82 | BarrierWait::_new(BarrierWaitInner { |
83 | barrier: self, |
84 | lock: Some(self.state.lock()), |
85 | evl: EventListener::new(), |
86 | state: WaitState::Initial, |
87 | }) |
88 | } |
89 | |
90 | /// Blocks the current thread until all tasks reach this point. |
91 | /// |
92 | /// Barriers are reusable after all tasks have synchronized, and can be used continuously. |
93 | /// |
94 | /// Returns a [`BarrierWaitResult`] indicating whether this task is the "leader", meaning the |
95 | /// last task to call this method. |
96 | /// |
97 | /// # Blocking |
98 | /// |
99 | /// Rather than using asynchronous waiting, like the [`wait`][`Barrier::wait`] method, |
100 | /// this method will block the current thread until the wait is complete. |
101 | /// |
102 | /// This method should not be used in an asynchronous context. It is intended to be |
103 | /// used in a way that a barrier can be used in both asynchronous and synchronous contexts. |
104 | /// Calling this method in an asynchronous context may result in a deadlock. |
105 | /// |
106 | /// # Examples |
107 | /// |
108 | /// ``` |
109 | /// use async_lock::Barrier; |
110 | /// use futures_lite::future; |
111 | /// use std::sync::Arc; |
112 | /// use std::thread; |
113 | /// |
114 | /// let barrier = Arc::new(Barrier::new(5)); |
115 | /// |
116 | /// for _ in 0..5 { |
117 | /// let b = barrier.clone(); |
118 | /// thread::spawn(move || { |
119 | /// // The same messages will be printed together. |
120 | /// // There will NOT be interleaving of "before" and "after". |
121 | /// println!("before wait" ); |
122 | /// b.wait_blocking(); |
123 | /// println!("after wait" ); |
124 | /// }); |
125 | /// } |
126 | /// ``` |
127 | #[cfg (all(feature = "std" , not(target_family = "wasm" )))] |
128 | pub fn wait_blocking(&self) -> BarrierWaitResult { |
129 | self.wait().wait() |
130 | } |
131 | } |
132 | |
133 | easy_wrapper! { |
134 | /// The future returned by [`Barrier::wait()`]. |
135 | pub struct BarrierWait<'a>(BarrierWaitInner<'a> => BarrierWaitResult); |
136 | #[cfg (all(feature = "std" , not(target_family = "wasm" )))] |
137 | pub(crate) wait(); |
138 | } |
139 | |
140 | pin_project_lite::pin_project! { |
141 | /// The future returned by [`Barrier::wait()`]. |
142 | struct BarrierWaitInner<'a> { |
143 | // The barrier to wait on. |
144 | barrier: &'a Barrier, |
145 | |
146 | // The ongoing mutex lock operation we are blocking on. |
147 | #[pin] |
148 | lock: Option<Lock<'a, State>>, |
149 | |
150 | // An event listener for the `barrier.event` event. |
151 | #[pin] |
152 | evl: EventListener, |
153 | |
154 | // The current state of the future. |
155 | state: WaitState, |
156 | } |
157 | } |
158 | |
159 | impl fmt::Debug for BarrierWait<'_> { |
160 | fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { |
161 | f.write_str(data:"BarrierWait { .. }" ) |
162 | } |
163 | } |
164 | |
/// Tracks how far a barrier wait has progressed.
enum WaitState {
    /// Acquiring the barrier's lock for the first time, in order to register this task's
    /// arrival.
    Initial,

    /// Arrival registered; waiting for the listener on the barrier's event to fire.
    /// `local_gen` is the generation this task joined.
    Waiting { local_gen: u64 },

    /// Woken up; re-acquiring the lock to check whether generation `local_gen` has finished.
    Reacquiring { local_gen: u64 },
}
175 | |
176 | impl EventListenerFuture for BarrierWaitInner<'_> { |
177 | type Output = BarrierWaitResult; |
178 | |
179 | fn poll_with_strategy<'a, S: Strategy<'a>>( |
180 | self: Pin<&mut Self>, |
181 | strategy: &mut S, |
182 | cx: &mut S::Context, |
183 | ) -> Poll<Self::Output> { |
184 | let mut this = self.project(); |
185 | |
186 | loop { |
187 | match this.state { |
188 | WaitState::Initial => { |
189 | // See if the lock is ready yet. |
190 | let mut state = ready!(this |
191 | .lock |
192 | .as_mut() |
193 | .as_pin_mut() |
194 | .unwrap() |
195 | .poll_with_strategy(strategy, cx)); |
196 | this.lock.as_mut().set(None); |
197 | |
198 | let local_gen = state.generation_id; |
199 | state.count += 1; |
200 | |
201 | if state.count < this.barrier.n { |
202 | // We need to wait for the event. |
203 | this.evl.as_mut().listen(&this.barrier.event); |
204 | *this.state = WaitState::Waiting { local_gen }; |
205 | } else { |
206 | // We are the last one. |
207 | state.count = 0; |
208 | state.generation_id = state.generation_id.wrapping_add(1); |
209 | this.barrier.event.notify(core::usize::MAX); |
210 | return Poll::Ready(BarrierWaitResult { is_leader: true }); |
211 | } |
212 | } |
213 | |
214 | WaitState::Waiting { local_gen } => { |
215 | ready!(strategy.poll(this.evl.as_mut(), cx)); |
216 | |
217 | // We are now re-acquiring the mutex. |
218 | this.lock.as_mut().set(Some(this.barrier.state.lock())); |
219 | *this.state = WaitState::Reacquiring { |
220 | local_gen: *local_gen, |
221 | }; |
222 | } |
223 | |
224 | WaitState::Reacquiring { local_gen } => { |
225 | // Acquire the local state again. |
226 | let state = ready!(this |
227 | .lock |
228 | .as_mut() |
229 | .as_pin_mut() |
230 | .unwrap() |
231 | .poll_with_strategy(strategy, cx)); |
232 | this.lock.set(None); |
233 | |
234 | if *local_gen == state.generation_id && state.count < this.barrier.n { |
235 | // We need to wait for the event again. |
236 | this.evl.as_mut().listen(&this.barrier.event); |
237 | *this.state = WaitState::Waiting { |
238 | local_gen: *local_gen, |
239 | }; |
240 | } else { |
241 | // We are ready, but not the leader. |
242 | return Poll::Ready(BarrierWaitResult { is_leader: false }); |
243 | } |
244 | } |
245 | } |
246 | } |
247 | } |
248 | } |
249 | |
/// Returned by [`Barrier::wait()`] when all tasks have called it.
///
/// Among the tasks released together, exactly one receives a result whose
/// [`is_leader()`](BarrierWaitResult::is_leader) is `true`: the last one to arrive.
///
/// # Examples
///
/// ```
/// # futures_lite::future::block_on(async {
/// use async_lock::Barrier;
///
/// let barrier = Barrier::new(1);
/// let barrier_wait_result = barrier.wait().await;
///
/// // With a single participant, that participant is always the leader.
/// assert!(barrier_wait_result.is_leader());
/// # });
/// ```
#[derive(Debug, Clone)]
pub struct BarrierWaitResult {
    /// Whether this task was the last to arrive at the barrier.
    is_leader: bool,
}
266 | |
267 | impl BarrierWaitResult { |
268 | /// Returns `true` if this task was the last to call to [`Barrier::wait()`]. |
269 | /// |
270 | /// # Examples |
271 | /// |
272 | /// ``` |
273 | /// # futures_lite::future::block_on(async { |
274 | /// use async_lock::Barrier; |
275 | /// use futures_lite::future; |
276 | /// |
277 | /// let barrier = Barrier::new(2); |
278 | /// let (a, b) = future::zip(barrier.wait(), barrier.wait()).await; |
279 | /// assert_eq!(a.is_leader(), false); |
280 | /// assert_eq!(b.is_leader(), true); |
281 | /// # }); |
282 | /// ``` |
283 | pub fn is_leader(&self) -> bool { |
284 | self.is_leader |
285 | } |
286 | } |
287 | |