1use event_listener::{Event, EventListener};
2
3use std::fmt;
4use std::future::Future;
5use std::pin::Pin;
6use std::task::{Context, Poll};
7
8use crate::futures::Lock;
9use crate::Mutex;
10
/// A counter to synchronize multiple tasks at the same time.
///
/// Each call to [`Barrier::wait()`] increments an internal arrival counter. When `n` tasks have
/// arrived, all waiting tasks are woken and the counter is reset, so the barrier can be reused.
#[derive(Debug)]
pub struct Barrier {
    /// Number of tasks that must call `wait()` before all of them are released.
    n: usize,
    /// Arrival count and generation identifier, guarded by a mutex.
    state: Mutex<State>,
    /// Event used to wake the tasks blocked in `wait()` once the last task arrives.
    event: Event,
}
18
/// Mutable state of a [`Barrier`], kept behind the barrier's mutex.
#[derive(Debug)]
struct State {
    /// Number of tasks that have arrived in the current generation; reset to 0 on release.
    count: usize,
    /// Identifier of the current generation; incremented (wrapping) each time the barrier
    /// releases its waiters, which lets woken tasks tell whether their generation completed.
    generation_id: u64,
}
24
25impl Barrier {
26 /// Creates a barrier that can block the given number of tasks.
27 ///
28 /// A barrier will block `n`-1 tasks which call [`wait()`] and then wake up all tasks
29 /// at once when the `n`th task calls [`wait()`].
30 ///
31 /// [`wait()`]: `Barrier::wait()`
32 ///
33 /// # Examples
34 ///
35 /// ```
36 /// use async_lock::Barrier;
37 ///
38 /// let barrier = Barrier::new(5);
39 /// ```
40 pub const fn new(n: usize) -> Barrier {
41 Barrier {
42 n,
43 state: Mutex::new(State {
44 count: 0,
45 generation_id: 0,
46 }),
47 event: Event::new(),
48 }
49 }
50
51 /// Blocks the current task until all tasks reach this point.
52 ///
53 /// Barriers are reusable after all tasks have synchronized, and can be used continuously.
54 ///
55 /// Returns a [`BarrierWaitResult`] indicating whether this task is the "leader", meaning the
56 /// last task to call this method.
57 ///
58 /// # Examples
59 ///
60 /// ```
61 /// use async_lock::Barrier;
62 /// use futures_lite::future;
63 /// use std::sync::Arc;
64 /// use std::thread;
65 ///
66 /// let barrier = Arc::new(Barrier::new(5));
67 ///
68 /// for _ in 0..5 {
69 /// let b = barrier.clone();
70 /// thread::spawn(move || {
71 /// future::block_on(async {
72 /// // The same messages will be printed together.
73 /// // There will NOT be interleaving of "before" and "after".
74 /// println!("before wait");
75 /// b.wait().await;
76 /// println!("after wait");
77 /// });
78 /// });
79 /// }
80 /// ```
81 pub fn wait(&self) -> BarrierWait<'_> {
82 BarrierWait {
83 barrier: self,
84 lock: Some(self.state.lock()),
85 state: WaitState::Initial,
86 }
87 }
88}
89
90/// The future returned by [`Barrier::wait()`].
91pub struct BarrierWait<'a> {
92 /// The barrier to wait on.
93 barrier: &'a Barrier,
94
95 /// The ongoing mutex lock operation we are blocking on.
96 lock: Option<Lock<'a, State>>,
97
98 /// The current state of the future.
99 state: WaitState,
100}
101
102impl fmt::Debug for BarrierWait<'_> {
103 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
104 f.write_str(data:"BarrierWait { .. }")
105 }
106}
107
/// States of the [`BarrierWait`] state machine.
enum WaitState {
    /// We are getting the original values of the state.
    ///
    /// The pending mutex lock is polled; once acquired, this task records its arrival and
    /// either becomes the leader or starts waiting.
    Initial,

    /// We are waiting for the listener to complete.
    ///
    /// `local_gen` is the generation identifier observed when this task arrived.
    Waiting { evl: EventListener, local_gen: u64 },

    /// Waiting to re-acquire the lock to check the state again.
    ///
    /// Carries the generation this task arrived in. If the generation is unchanged once the
    /// lock is held again, the task goes back to waiting; otherwise it completes.
    Reacquiring(u64),
}
118
119impl Future for BarrierWait<'_> {
120 type Output = BarrierWaitResult;
121
122 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
123 let this = self.get_mut();
124
125 loop {
126 match this.state {
127 WaitState::Initial => {
128 // See if the lock is ready yet.
129 let mut state = ready!(Pin::new(this.lock.as_mut().unwrap()).poll(cx));
130 this.lock = None;
131
132 let local_gen = state.generation_id;
133 state.count += 1;
134
135 if state.count < this.barrier.n {
136 // We need to wait for the event.
137 this.state = WaitState::Waiting {
138 evl: this.barrier.event.listen(),
139 local_gen,
140 };
141 } else {
142 // We are the last one.
143 state.count = 0;
144 state.generation_id = state.generation_id.wrapping_add(1);
145 this.barrier.event.notify(std::usize::MAX);
146 return Poll::Ready(BarrierWaitResult { is_leader: true });
147 }
148 }
149
150 WaitState::Waiting {
151 ref mut evl,
152 local_gen,
153 } => {
154 ready!(Pin::new(evl).poll(cx));
155
156 // We are now re-acquiring the mutex.
157 this.lock = Some(this.barrier.state.lock());
158 this.state = WaitState::Reacquiring(local_gen);
159 }
160
161 WaitState::Reacquiring(local_gen) => {
162 // Acquire the local state again.
163 let state = ready!(Pin::new(this.lock.as_mut().unwrap()).poll(cx));
164 this.lock = None;
165
166 if local_gen == state.generation_id && state.count < this.barrier.n {
167 // We need to wait for the event again.
168 this.state = WaitState::Waiting {
169 evl: this.barrier.event.listen(),
170 local_gen,
171 };
172 } else {
173 // We are ready, but not the leader.
174 return Poll::Ready(BarrierWaitResult { is_leader: false });
175 }
176 }
177 }
178 }
179 }
180}
181
/// Returned by [`Barrier::wait()`] when all tasks have called it.
///
/// # Examples
///
/// ```
/// # futures_lite::future::block_on(async {
/// use async_lock::Barrier;
///
/// let barrier = Barrier::new(1);
/// let barrier_wait_result = barrier.wait().await;
/// // With a single participant, the only waiting task is always the leader.
/// assert!(barrier_wait_result.is_leader());
/// # });
/// ```
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct BarrierWaitResult {
    /// Whether this task was the last one to call `wait()` (the "leader").
    is_leader: bool,
}
198
199impl BarrierWaitResult {
200 /// Returns `true` if this task was the last to call to [`Barrier::wait()`].
201 ///
202 /// # Examples
203 ///
204 /// ```
205 /// # futures_lite::future::block_on(async {
206 /// use async_lock::Barrier;
207 /// use futures_lite::future;
208 ///
209 /// let barrier = Barrier::new(2);
210 /// let (a, b) = future::zip(barrier.wait(), barrier.wait()).await;
211 /// assert_eq!(a.is_leader(), false);
212 /// assert_eq!(b.is_leader(), true);
213 /// # });
214 /// ```
215 pub fn is_leader(&self) -> bool {
216 self.is_leader
217 }
218}
219