use crate::loom::sync::Mutex;
use crate::sync::watch;
#[cfg(all(tokio_unstable, feature = "tracing"))]
use crate::util::trace;

/// A barrier enables multiple tasks to synchronize the beginning of some computation.
///
/// ```
/// # #[tokio::main]
/// # async fn main() {
/// use tokio::sync::Barrier;
/// use std::sync::Arc;
///
/// let mut handles = Vec::with_capacity(10);
/// let barrier = Arc::new(Barrier::new(10));
/// for _ in 0..10 {
///     let c = barrier.clone();
///     // The same messages will be printed together.
///     // You will NOT see any interleaving.
///     handles.push(tokio::spawn(async move {
///         println!("before wait");
///         let wait_result = c.wait().await;
///         println!("after wait");
///         wait_result
///     }));
/// }
///
/// // Will not resolve until all "after wait" messages have been printed
/// let mut num_leaders = 0;
/// for handle in handles {
///     let wait_result = handle.await.unwrap();
///     if wait_result.is_leader() {
///         num_leaders += 1;
///     }
/// }
///
/// // Exactly one barrier will resolve as the "leader"
/// assert_eq!(num_leaders, 1);
/// # }
/// ```
#[derive(Debug)]
pub struct Barrier {
    state: Mutex<BarrierState>,
    wait: watch::Receiver<usize>,
    n: usize,
    #[cfg(all(tokio_unstable, feature = "tracing"))]
    resource_span: tracing::Span,
}

#[derive(Debug)]
struct BarrierState {
    /// Wakes all waiting tasks when a generation completes.
    waker: watch::Sender<usize>,
    /// Number of tasks that have arrived in the current generation.
    arrived: usize,
    /// The current generation, starting at 1.
    generation: usize,
}

impl Barrier {
    /// Creates a new barrier that can block a given number of tasks.
    ///
    /// A barrier will block `n - 1` tasks that call [`Barrier::wait`] and then wake up all
    /// tasks at once when the `n`th task calls `wait`.
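    ///
    /// # Examples
    ///
    /// A minimal sketch; note that passing `0` is treated the same as passing `1`,
    /// so every call to `wait` would resolve immediately:
    ///
    /// ```
    /// use tokio::sync::Barrier;
    ///
    /// // A barrier that releases waiting tasks in groups of three.
    /// let barrier = Barrier::new(3);
    /// ```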
    #[track_caller]
    pub fn new(mut n: usize) -> Barrier {
        let (waker, wait) = crate::sync::watch::channel(0);

        if n == 0 {
            // if n is 0, it's not clear what behavior the user wants.
            // in std::sync::Barrier, an n of 0 exhibits the same behavior as n == 1, where every
            // .wait() immediately unblocks, so we adopt that here as well.
            n = 1;
        }

        #[cfg(all(tokio_unstable, feature = "tracing"))]
        let resource_span = {
            let location = std::panic::Location::caller();
            let resource_span = tracing::trace_span!(
                "runtime.resource",
                concrete_type = "Barrier",
                kind = "Sync",
                loc.file = location.file(),
                loc.line = location.line(),
                loc.col = location.column(),
            );

            resource_span.in_scope(|| {
                tracing::trace!(
                    target: "runtime::resource::state_update",
                    size = n,
                );

                tracing::trace!(
                    target: "runtime::resource::state_update",
                    arrived = 0,
                )
            });
            resource_span
        };

        Barrier {
            state: Mutex::new(BarrierState {
                waker,
                arrived: 0,
                generation: 1,
            }),
            n,
            wait,
            #[cfg(all(tokio_unstable, feature = "tracing"))]
            resource_span,
        }
    }

    /// Does not resolve until all tasks have rendezvoused here.
    ///
    /// Barriers are re-usable after all tasks have rendezvoused once, and can
    /// be used continuously.
    ///
    /// A single (arbitrary) future will receive a [`BarrierWaitResult`] that returns `true` from
    /// [`BarrierWaitResult::is_leader`] when returning from this function, and all other tasks
    /// will receive a result that will return `false` from `is_leader`.
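    ///
    /// # Examples
    ///
    /// A small sketch of a two-task rendezvous; exactly one of the two
    /// results reports itself as the leader:
    ///
    /// ```
    /// # #[tokio::main]
    /// # async fn main() {
    /// use tokio::sync::Barrier;
    /// use std::sync::Arc;
    ///
    /// let barrier = Arc::new(Barrier::new(2));
    /// let b = barrier.clone();
    /// let task = tokio::spawn(async move { b.wait().await });
    ///
    /// let r1 = barrier.wait().await;
    /// let r2 = task.await.unwrap();
    /// // Exactly one of the two wait results is the leader.
    /// assert!(r1.is_leader() ^ r2.is_leader());
    /// # }
    /// ```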
    pub async fn wait(&self) -> BarrierWaitResult {
        #[cfg(all(tokio_unstable, feature = "tracing"))]
        return trace::async_op(
            || self.wait_internal(),
            self.resource_span.clone(),
            "Barrier::wait",
            "poll",
            false,
        )
        .await;

        #[cfg(any(not(tokio_unstable), not(feature = "tracing")))]
        return self.wait_internal().await;
    }
    async fn wait_internal(&self) -> BarrierWaitResult {
        crate::trace::async_trace_leaf().await;

        // NOTE: we are taking a _synchronous_ lock here.
        // It is okay to do so because the critical section is fast and never yields, so it cannot
        // deadlock even if another future is concurrently holding the lock.
        // It is _desirable_ to do so as synchronous Mutexes are, at least in theory, faster than
        // the asynchronous counterparts, so we should use them where possible [citation needed].
        // NOTE: the extra scope here is so that the compiler doesn't think `state` is held across
        // a yield point, and thus marks the returned future as !Send.
        let generation = {
            let mut state = self.state.lock();
            let generation = state.generation;
            state.arrived += 1;
            #[cfg(all(tokio_unstable, feature = "tracing"))]
            tracing::trace!(
                target: "runtime::resource::state_update",
                arrived = 1,
                arrived.op = "add",
            );
            #[cfg(all(tokio_unstable, feature = "tracing"))]
            tracing::trace!(
                target: "runtime::resource::async_op::state_update",
                arrived = true,
            );
            if state.arrived == self.n {
                #[cfg(all(tokio_unstable, feature = "tracing"))]
                tracing::trace!(
                    target: "runtime::resource::async_op::state_update",
                    is_leader = true,
                );
                // we are the leader for this generation
                // wake everyone, increment the generation, and return
                state
                    .waker
                    .send(state.generation)
                    .expect("there is at least one receiver");
                state.arrived = 0;
                state.generation += 1;
                return BarrierWaitResult(true);
            }

            generation
        };

        // we're going to have to wait for the last of the generation to arrive
        let mut wait = self.wait.clone();

        loop {
            let _ = wait.changed().await;

            // note that the first time through the loop, this _will_ yield a generation
            // immediately, since we cloned a receiver that has never seen any values.
            if *wait.borrow() >= generation {
                break;
            }
        }

        BarrierWaitResult(false)
    }
}

/// A `BarrierWaitResult` is returned by `wait` when all tasks in the `Barrier` have rendezvoused.
#[derive(Debug, Clone)]
pub struct BarrierWaitResult(bool);

impl BarrierWaitResult {
    /// Returns `true` if this task was the "leader task" for this call to [`Barrier::wait`].
    ///
    /// Exactly one task will have `true` returned from its result; all other tasks will have
    /// `false` returned.
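    ///
    /// # Examples
    ///
    /// A minimal sketch: with a barrier of size one, the sole caller is
    /// always the leader.
    ///
    /// ```
    /// # #[tokio::main]
    /// # async fn main() {
    /// use tokio::sync::Barrier;
    ///
    /// let barrier = Barrier::new(1);
    /// // With `n == 1`, `wait` resolves immediately and this task leads.
    /// assert!(barrier.wait().await.is_leader());
    /// # }
    /// ```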
    pub fn is_leader(&self) -> bool {
        self.0
    }
}
209 | |