1 | use event_listener::{Event, EventListener}; |
2 | |
3 | use std::fmt; |
4 | use std::future::Future; |
5 | use std::pin::Pin; |
6 | use std::task::{Context, Poll}; |
7 | |
8 | use crate::futures::Lock; |
9 | use crate::Mutex; |
10 | |
11 | /// A counter to synchronize multiple tasks at the same time. |
12 | #[derive (Debug)] |
13 | pub struct Barrier { |
14 | n: usize, |
15 | state: Mutex<State>, |
16 | event: Event, |
17 | } |
18 | |
/// Bookkeeping shared by every waiter, only touched while the barrier's mutex is held.
#[derive(Debug)]
struct State {
    /// Number of tasks that have arrived in the current generation.
    count: usize,
    /// Bumped (with wrapping) each time a group is released, so a woken task can tell
    /// whether the generation it arrived in has completed.
    generation_id: u64,
}
24 | |
25 | impl Barrier { |
26 | /// Creates a barrier that can block the given number of tasks. |
27 | /// |
28 | /// A barrier will block `n`-1 tasks which call [`wait()`] and then wake up all tasks |
29 | /// at once when the `n`th task calls [`wait()`]. |
30 | /// |
31 | /// [`wait()`]: `Barrier::wait()` |
32 | /// |
33 | /// # Examples |
34 | /// |
35 | /// ``` |
36 | /// use async_lock::Barrier; |
37 | /// |
38 | /// let barrier = Barrier::new(5); |
39 | /// ``` |
40 | pub const fn new(n: usize) -> Barrier { |
41 | Barrier { |
42 | n, |
43 | state: Mutex::new(State { |
44 | count: 0, |
45 | generation_id: 0, |
46 | }), |
47 | event: Event::new(), |
48 | } |
49 | } |
50 | |
51 | /// Blocks the current task until all tasks reach this point. |
52 | /// |
53 | /// Barriers are reusable after all tasks have synchronized, and can be used continuously. |
54 | /// |
55 | /// Returns a [`BarrierWaitResult`] indicating whether this task is the "leader", meaning the |
56 | /// last task to call this method. |
57 | /// |
58 | /// # Examples |
59 | /// |
60 | /// ``` |
61 | /// use async_lock::Barrier; |
62 | /// use futures_lite::future; |
63 | /// use std::sync::Arc; |
64 | /// use std::thread; |
65 | /// |
66 | /// let barrier = Arc::new(Barrier::new(5)); |
67 | /// |
68 | /// for _ in 0..5 { |
69 | /// let b = barrier.clone(); |
70 | /// thread::spawn(move || { |
71 | /// future::block_on(async { |
72 | /// // The same messages will be printed together. |
73 | /// // There will NOT be interleaving of "before" and "after". |
74 | /// println!("before wait" ); |
75 | /// b.wait().await; |
76 | /// println!("after wait" ); |
77 | /// }); |
78 | /// }); |
79 | /// } |
80 | /// ``` |
81 | pub fn wait(&self) -> BarrierWait<'_> { |
82 | BarrierWait { |
83 | barrier: self, |
84 | lock: Some(self.state.lock()), |
85 | state: WaitState::Initial, |
86 | } |
87 | } |
88 | } |
89 | |
90 | /// The future returned by [`Barrier::wait()`]. |
91 | pub struct BarrierWait<'a> { |
92 | /// The barrier to wait on. |
93 | barrier: &'a Barrier, |
94 | |
95 | /// The ongoing mutex lock operation we are blocking on. |
96 | lock: Option<Lock<'a, State>>, |
97 | |
98 | /// The current state of the future. |
99 | state: WaitState, |
100 | } |
101 | |
102 | impl fmt::Debug for BarrierWait<'_> { |
103 | fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { |
104 | f.write_str(data:"BarrierWait { .. }" ) |
105 | } |
106 | } |
107 | |
108 | enum WaitState { |
109 | /// We are getting the original values of the state. |
110 | Initial, |
111 | |
112 | /// We are waiting for the listener to complete. |
113 | Waiting { evl: EventListener, local_gen: u64 }, |
114 | |
115 | /// Waiting to re-acquire the lock to check the state again. |
116 | Reacquiring(u64), |
117 | } |
118 | |
119 | impl Future for BarrierWait<'_> { |
120 | type Output = BarrierWaitResult; |
121 | |
122 | fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { |
123 | let this = self.get_mut(); |
124 | |
125 | loop { |
126 | match this.state { |
127 | WaitState::Initial => { |
128 | // See if the lock is ready yet. |
129 | let mut state = ready!(Pin::new(this.lock.as_mut().unwrap()).poll(cx)); |
130 | this.lock = None; |
131 | |
132 | let local_gen = state.generation_id; |
133 | state.count += 1; |
134 | |
135 | if state.count < this.barrier.n { |
136 | // We need to wait for the event. |
137 | this.state = WaitState::Waiting { |
138 | evl: this.barrier.event.listen(), |
139 | local_gen, |
140 | }; |
141 | } else { |
142 | // We are the last one. |
143 | state.count = 0; |
144 | state.generation_id = state.generation_id.wrapping_add(1); |
145 | this.barrier.event.notify(std::usize::MAX); |
146 | return Poll::Ready(BarrierWaitResult { is_leader: true }); |
147 | } |
148 | } |
149 | |
150 | WaitState::Waiting { |
151 | ref mut evl, |
152 | local_gen, |
153 | } => { |
154 | ready!(Pin::new(evl).poll(cx)); |
155 | |
156 | // We are now re-acquiring the mutex. |
157 | this.lock = Some(this.barrier.state.lock()); |
158 | this.state = WaitState::Reacquiring(local_gen); |
159 | } |
160 | |
161 | WaitState::Reacquiring(local_gen) => { |
162 | // Acquire the local state again. |
163 | let state = ready!(Pin::new(this.lock.as_mut().unwrap()).poll(cx)); |
164 | this.lock = None; |
165 | |
166 | if local_gen == state.generation_id && state.count < this.barrier.n { |
167 | // We need to wait for the event again. |
168 | this.state = WaitState::Waiting { |
169 | evl: this.barrier.event.listen(), |
170 | local_gen, |
171 | }; |
172 | } else { |
173 | // We are ready, but not the leader. |
174 | return Poll::Ready(BarrierWaitResult { is_leader: false }); |
175 | } |
176 | } |
177 | } |
178 | } |
179 | } |
180 | } |
181 | |
/// The value produced by [`Barrier::wait()`] once every task in the group has arrived.
///
/// # Examples
///
/// ```
/// # futures_lite::future::block_on(async {
/// use async_lock::Barrier;
///
/// let barrier = Barrier::new(1);
/// let barrier_wait_result = barrier.wait().await;
/// # });
/// ```
#[derive(Debug, Clone)]
pub struct BarrierWaitResult {
    /// Whether the receiving task was the one whose arrival completed the group.
    is_leader: bool,
}
198 | |
199 | impl BarrierWaitResult { |
200 | /// Returns `true` if this task was the last to call to [`Barrier::wait()`]. |
201 | /// |
202 | /// # Examples |
203 | /// |
204 | /// ``` |
205 | /// # futures_lite::future::block_on(async { |
206 | /// use async_lock::Barrier; |
207 | /// use futures_lite::future; |
208 | /// |
209 | /// let barrier = Barrier::new(2); |
210 | /// let (a, b) = future::zip(barrier.wait(), barrier.wait()).await; |
211 | /// assert_eq!(a.is_leader(), false); |
212 | /// assert_eq!(b.is_leader(), true); |
213 | /// # }); |
214 | /// ``` |
215 | pub fn is_leader(&self) -> bool { |
216 | self.is_leader |
217 | } |
218 | } |
219 | |