1use event_listener::{Event, EventListener};
2use event_listener_strategy::{easy_wrapper, EventListenerFuture, Strategy};
3
4use core::fmt;
5use core::pin::Pin;
6use core::task::Poll;
7
8use crate::futures::Lock;
9use crate::Mutex;
10
/// A counter to synchronize multiple tasks at the same time.
///
/// A round completes once `n` tasks have called [`Barrier::wait()`]. The barrier then resets
/// its arrival count, so the same barrier can be reused for the next round.
#[derive(Debug)]
pub struct Barrier {
    // Number of tasks that must call `wait()` before the current round is released.
    n: usize,
    // Arrival count and generation counter; only read or written while this mutex is held.
    state: Mutex<State>,
    // Notified by the last arriving task (the leader) to wake every task waiting in the round.
    event: Event,
}
18
/// Mutable bookkeeping for a [`Barrier`], stored behind the barrier's mutex.
#[derive(Debug)]
struct State {
    // Number of tasks that have arrived in the current round; reset to 0 by the leader.
    count: usize,
    // Advanced (with wrapping) each time a round is released, so woken waiters can tell
    // whether the round they arrived in has actually completed.
    generation_id: u64,
}
24
25impl Barrier {
26 /// Creates a barrier that can block the given number of tasks.
27 ///
28 /// A barrier will block `n`-1 tasks which call [`wait()`] and then wake up all tasks
29 /// at once when the `n`th task calls [`wait()`].
30 ///
31 /// [`wait()`]: `Barrier::wait()`
32 ///
33 /// # Examples
34 ///
35 /// ```
36 /// use async_lock::Barrier;
37 ///
38 /// let barrier = Barrier::new(5);
39 /// ```
40 pub const fn new(n: usize) -> Barrier {
41 Barrier {
42 n,
43 state: Mutex::new(State {
44 count: 0,
45 generation_id: 0,
46 }),
47 event: Event::new(),
48 }
49 }
50
51 /// Blocks the current task until all tasks reach this point.
52 ///
53 /// Barriers are reusable after all tasks have synchronized, and can be used continuously.
54 ///
55 /// Returns a [`BarrierWaitResult`] indicating whether this task is the "leader", meaning the
56 /// last task to call this method.
57 ///
58 /// # Examples
59 ///
60 /// ```
61 /// use async_lock::Barrier;
62 /// use futures_lite::future;
63 /// use std::sync::Arc;
64 /// use std::thread;
65 ///
66 /// let barrier = Arc::new(Barrier::new(5));
67 ///
68 /// for _ in 0..5 {
69 /// let b = barrier.clone();
70 /// thread::spawn(move || {
71 /// future::block_on(async {
72 /// // The same messages will be printed together.
73 /// // There will NOT be interleaving of "before" and "after".
74 /// println!("before wait");
75 /// b.wait().await;
76 /// println!("after wait");
77 /// });
78 /// });
79 /// }
80 /// ```
81 pub fn wait(&self) -> BarrierWait<'_> {
82 BarrierWait::_new(BarrierWaitInner {
83 barrier: self,
84 lock: Some(self.state.lock()),
85 evl: EventListener::new(),
86 state: WaitState::Initial,
87 })
88 }
89
90 /// Blocks the current thread until all tasks reach this point.
91 ///
92 /// Barriers are reusable after all tasks have synchronized, and can be used continuously.
93 ///
94 /// Returns a [`BarrierWaitResult`] indicating whether this task is the "leader", meaning the
95 /// last task to call this method.
96 ///
97 /// # Blocking
98 ///
99 /// Rather than using asynchronous waiting, like the [`wait`][`Barrier::wait`] method,
100 /// this method will block the current thread until the wait is complete.
101 ///
102 /// This method should not be used in an asynchronous context. It is intended to be
103 /// used in a way that a barrier can be used in both asynchronous and synchronous contexts.
104 /// Calling this method in an asynchronous context may result in a deadlock.
105 ///
106 /// # Examples
107 ///
108 /// ```
109 /// use async_lock::Barrier;
110 /// use futures_lite::future;
111 /// use std::sync::Arc;
112 /// use std::thread;
113 ///
114 /// let barrier = Arc::new(Barrier::new(5));
115 ///
116 /// for _ in 0..5 {
117 /// let b = barrier.clone();
118 /// thread::spawn(move || {
119 /// // The same messages will be printed together.
120 /// // There will NOT be interleaving of "before" and "after".
121 /// println!("before wait");
122 /// b.wait_blocking();
123 /// println!("after wait");
124 /// });
125 /// }
126 /// ```
127 #[cfg(all(feature = "std", not(target_family = "wasm")))]
128 pub fn wait_blocking(&self) -> BarrierWaitResult {
129 self.wait().wait()
130 }
131}
132
easy_wrapper! {
    /// The future returned by [`Barrier::wait()`].
    ///
    /// Resolves to a [`BarrierWaitResult`] once every task of the current round has arrived.
    pub struct BarrierWait<'a>(BarrierWaitInner<'a> => BarrierWaitResult);
    // Crate-internal blocking driver; `Barrier::wait_blocking()` calls it.
    #[cfg(all(feature = "std", not(target_family = "wasm")))]
    pub(crate) wait();
}
139
pin_project_lite::pin_project! {
    /// Internal state machine wrapped by [`BarrierWait`].
    ///
    /// It is driven by its `EventListenerFuture::poll_with_strategy` implementation.
    struct BarrierWaitInner<'a> {
        // The barrier to wait on.
        barrier: &'a Barrier,

        // The in-progress acquisition of `barrier.state`. It is `Some` while acquiring the
        // lock (initially, and again after a wakeup) and reset to `None` once the lock is held.
        #[pin]
        lock: Option<Lock<'a, State>>,

        // Listener for `barrier.event`, registered whenever this task has to wait for the leader.
        #[pin]
        evl: EventListener,

        // The current step of the state machine.
        state: WaitState,
    }
}
158
159impl fmt::Debug for BarrierWait<'_> {
160 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
161 f.write_str(data:"BarrierWait { .. }")
162 }
163}
164
/// The steps of the barrier-wait state machine.
enum WaitState {
    /// Acquiring the lock for the first time to record this task's arrival.
    Initial,

    /// Arrived but not last; waiting for the leader to notify the barrier's event.
    ///
    /// `local_gen` is the generation observed when this task arrived.
    Waiting { local_gen: u64 },

    /// Woken by the event; re-acquiring the lock to check whether `local_gen` was released.
    Reacquiring { local_gen: u64 },
}
175
176impl EventListenerFuture for BarrierWaitInner<'_> {
177 type Output = BarrierWaitResult;
178
179 fn poll_with_strategy<'a, S: Strategy<'a>>(
180 self: Pin<&mut Self>,
181 strategy: &mut S,
182 cx: &mut S::Context,
183 ) -> Poll<Self::Output> {
184 let mut this = self.project();
185
186 loop {
187 match this.state {
188 WaitState::Initial => {
189 // See if the lock is ready yet.
190 let mut state = ready!(this
191 .lock
192 .as_mut()
193 .as_pin_mut()
194 .unwrap()
195 .poll_with_strategy(strategy, cx));
196 this.lock.as_mut().set(None);
197
198 let local_gen = state.generation_id;
199 state.count += 1;
200
201 if state.count < this.barrier.n {
202 // We need to wait for the event.
203 this.evl.as_mut().listen(&this.barrier.event);
204 *this.state = WaitState::Waiting { local_gen };
205 } else {
206 // We are the last one.
207 state.count = 0;
208 state.generation_id = state.generation_id.wrapping_add(1);
209 this.barrier.event.notify(core::usize::MAX);
210 return Poll::Ready(BarrierWaitResult { is_leader: true });
211 }
212 }
213
214 WaitState::Waiting { local_gen } => {
215 ready!(strategy.poll(this.evl.as_mut(), cx));
216
217 // We are now re-acquiring the mutex.
218 this.lock.as_mut().set(Some(this.barrier.state.lock()));
219 *this.state = WaitState::Reacquiring {
220 local_gen: *local_gen,
221 };
222 }
223
224 WaitState::Reacquiring { local_gen } => {
225 // Acquire the local state again.
226 let state = ready!(this
227 .lock
228 .as_mut()
229 .as_pin_mut()
230 .unwrap()
231 .poll_with_strategy(strategy, cx));
232 this.lock.set(None);
233
234 if *local_gen == state.generation_id && state.count < this.barrier.n {
235 // We need to wait for the event again.
236 this.evl.as_mut().listen(&this.barrier.event);
237 *this.state = WaitState::Waiting {
238 local_gen: *local_gen,
239 };
240 } else {
241 // We are ready, but not the leader.
242 return Poll::Ready(BarrierWaitResult { is_leader: false });
243 }
244 }
245 }
246 }
247 }
248}
249
/// Returned by [`Barrier::wait()`] when all tasks have called it.
///
/// Call [`BarrierWaitResult::is_leader()`] to learn whether this task was the one whose
/// arrival completed the round.
///
/// # Examples
///
/// ```
/// # futures_lite::future::block_on(async {
/// use async_lock::Barrier;
///
/// let barrier = Barrier::new(1);
/// let barrier_wait_result = barrier.wait().await;
/// # });
/// ```
#[derive(Debug, Clone)]
pub struct BarrierWaitResult {
    // `true` only for the task whose arrival completed the round.
    is_leader: bool,
}
266
267impl BarrierWaitResult {
268 /// Returns `true` if this task was the last to call to [`Barrier::wait()`].
269 ///
270 /// # Examples
271 ///
272 /// ```
273 /// # futures_lite::future::block_on(async {
274 /// use async_lock::Barrier;
275 /// use futures_lite::future;
276 ///
277 /// let barrier = Barrier::new(2);
278 /// let (a, b) = future::zip(barrier.wait(), barrier.wait()).await;
279 /// assert_eq!(a.is_leader(), false);
280 /// assert_eq!(b.is_leader(), true);
281 /// # });
282 /// ```
283 pub fn is_leader(&self) -> bool {
284 self.is_leader
285 }
286}
287