1 | use crate::loom::sync::Mutex; |
2 | use crate::sync::watch; |
3 | #[cfg (all(tokio_unstable, feature = "tracing" ))] |
4 | use crate::util::trace; |
5 | |
/// A barrier enables multiple tasks to synchronize the beginning of some computation.
///
/// Once all `n` tasks have rendezvoused, the barrier resets itself and can be waited on again
/// by the next group of `n` tasks.
///
/// ```
/// # #[tokio::main]
/// # async fn main() {
/// use tokio::sync::Barrier;
/// use std::sync::Arc;
///
/// let mut handles = Vec::with_capacity(10);
/// let barrier = Arc::new(Barrier::new(10));
/// for _ in 0..10 {
///     let c = barrier.clone();
///     // The same messages will be printed together.
///     // You will NOT see any interleaving.
///     handles.push(tokio::spawn(async move {
///         println!("before wait");
///         let wait_result = c.wait().await;
///         println!("after wait");
///         wait_result
///     }));
/// }
///
/// // Will not resolve until all "after wait" messages have been printed
/// let mut num_leaders = 0;
/// for handle in handles {
///     let wait_result = handle.await.unwrap();
///     if wait_result.is_leader() {
///         num_leaders += 1;
///     }
/// }
///
/// // Exactly one task will resolve as the "leader"
/// assert_eq!(num_leaders, 1);
/// # }
/// ```
#[derive(Debug)]
pub struct Barrier {
    /// Arrival count, current generation and the wake-up sender. Guarded by a synchronous
    /// lock that `wait` only holds inside a scope with no `.await`.
    state: Mutex<BarrierState>,
    /// Receiver that each non-leader waiter clones to observe completed generations.
    wait: watch::Receiver<usize>,
    /// Number of tasks that must call `wait` before they are all released. `new` never stores 0.
    n: usize,
    /// Tracing span representing this barrier as a runtime resource.
    #[cfg(all(tokio_unstable, feature = "tracing"))]
    resource_span: tracing::Span,
}
49 | |
/// Mutable bookkeeping shared by every caller of [`Barrier::wait`], stored behind
/// `Barrier::state`.
#[derive(Debug)]
struct BarrierState {
    /// Publishes the number of each generation as it completes, waking the waiting tasks.
    waker: watch::Sender<usize>,
    /// Number of tasks that have arrived in the current generation.
    /// The leader resets it to 0 when it releases the barrier.
    arrived: usize,
    /// Number of the current generation. Starts at 1 and is incremented each time the
    /// barrier releases.
    generation: usize,
}
56 | |
57 | impl Barrier { |
58 | /// Creates a new barrier that can block a given number of tasks. |
59 | /// |
60 | /// A barrier will block `n`-1 tasks which call [`Barrier::wait`] and then wake up all |
61 | /// tasks at once when the `n`th task calls `wait`. |
62 | #[track_caller ] |
63 | pub fn new(mut n: usize) -> Barrier { |
64 | let (waker, wait) = crate::sync::watch::channel(0); |
65 | |
66 | if n == 0 { |
67 | // if n is 0, it's not clear what behavior the user wants. |
68 | // in std::sync::Barrier, an n of 0 exhibits the same behavior as n == 1, where every |
69 | // .wait() immediately unblocks, so we adopt that here as well. |
70 | n = 1; |
71 | } |
72 | |
73 | #[cfg (all(tokio_unstable, feature = "tracing" ))] |
74 | let resource_span = { |
75 | let location = std::panic::Location::caller(); |
76 | let resource_span = tracing::trace_span!( |
77 | parent: None, |
78 | "runtime.resource" , |
79 | concrete_type = "Barrier" , |
80 | kind = "Sync" , |
81 | loc.file = location.file(), |
82 | loc.line = location.line(), |
83 | loc.col = location.column(), |
84 | ); |
85 | |
86 | resource_span.in_scope(|| { |
87 | tracing::trace!( |
88 | target: "runtime::resource::state_update" , |
89 | size = n, |
90 | ); |
91 | |
92 | tracing::trace!( |
93 | target: "runtime::resource::state_update" , |
94 | arrived = 0, |
95 | ) |
96 | }); |
97 | resource_span |
98 | }; |
99 | |
100 | Barrier { |
101 | state: Mutex::new(BarrierState { |
102 | waker, |
103 | arrived: 0, |
104 | generation: 1, |
105 | }), |
106 | n, |
107 | wait, |
108 | #[cfg (all(tokio_unstable, feature = "tracing" ))] |
109 | resource_span, |
110 | } |
111 | } |
112 | |
113 | /// Does not resolve until all tasks have rendezvoused here. |
114 | /// |
115 | /// Barriers are re-usable after all tasks have rendezvoused once, and can |
116 | /// be used continuously. |
117 | /// |
118 | /// A single (arbitrary) future will receive a [`BarrierWaitResult`] that returns `true` from |
119 | /// [`BarrierWaitResult::is_leader`] when returning from this function, and all other tasks |
120 | /// will receive a result that will return `false` from `is_leader`. |
121 | pub async fn wait(&self) -> BarrierWaitResult { |
122 | #[cfg (all(tokio_unstable, feature = "tracing" ))] |
123 | return trace::async_op( |
124 | || self.wait_internal(), |
125 | self.resource_span.clone(), |
126 | "Barrier::wait" , |
127 | "poll" , |
128 | false, |
129 | ) |
130 | .await; |
131 | |
132 | #[cfg (any(not(tokio_unstable), not(feature = "tracing" )))] |
133 | return self.wait_internal().await; |
134 | } |
135 | async fn wait_internal(&self) -> BarrierWaitResult { |
136 | crate::trace::async_trace_leaf().await; |
137 | |
138 | // NOTE: we are taking a _synchronous_ lock here. |
139 | // It is okay to do so because the critical section is fast and never yields, so it cannot |
140 | // deadlock even if another future is concurrently holding the lock. |
141 | // It is _desirable_ to do so as synchronous Mutexes are, at least in theory, faster than |
142 | // the asynchronous counter-parts, so we should use them where possible [citation needed]. |
143 | // NOTE: the extra scope here is so that the compiler doesn't think `state` is held across |
144 | // a yield point, and thus marks the returned future as !Send. |
145 | let generation = { |
146 | let mut state = self.state.lock(); |
147 | let generation = state.generation; |
148 | state.arrived += 1; |
149 | #[cfg (all(tokio_unstable, feature = "tracing" ))] |
150 | tracing::trace!( |
151 | target: "runtime::resource::state_update" , |
152 | arrived = 1, |
153 | arrived.op = "add" , |
154 | ); |
155 | #[cfg (all(tokio_unstable, feature = "tracing" ))] |
156 | tracing::trace!( |
157 | target: "runtime::resource::async_op::state_update" , |
158 | arrived = true, |
159 | ); |
160 | if state.arrived == self.n { |
161 | #[cfg (all(tokio_unstable, feature = "tracing" ))] |
162 | tracing::trace!( |
163 | target: "runtime::resource::async_op::state_update" , |
164 | is_leader = true, |
165 | ); |
166 | // we are the leader for this generation |
167 | // wake everyone, increment the generation, and return |
168 | state |
169 | .waker |
170 | .send(state.generation) |
171 | .expect("there is at least one receiver" ); |
172 | state.arrived = 0; |
173 | state.generation += 1; |
174 | return BarrierWaitResult(true); |
175 | } |
176 | |
177 | generation |
178 | }; |
179 | |
180 | // we're going to have to wait for the last of the generation to arrive |
181 | let mut wait = self.wait.clone(); |
182 | |
183 | loop { |
184 | let _ = wait.changed().await; |
185 | |
186 | // note that the first time through the loop, this _will_ yield a generation |
187 | // immediately, since we cloned a receiver that has never seen any values. |
188 | if *wait.borrow() >= generation { |
189 | break; |
190 | } |
191 | } |
192 | |
193 | BarrierWaitResult(false) |
194 | } |
195 | } |
196 | |
/// A `BarrierWaitResult` is returned by [`Barrier::wait`] when all tasks in the `Barrier` have
/// rendezvoused.
///
/// It records whether the receiving task was the leader of its generation. Query this with
/// [`BarrierWaitResult::is_leader`].
#[derive(Debug, Clone)]
pub struct BarrierWaitResult(bool);
200 | |
201 | impl BarrierWaitResult { |
202 | /// Returns `true` if this task from wait is the "leader task". |
203 | /// |
204 | /// Only one task will have `true` returned from their result, all other tasks will have |
205 | /// `false` returned. |
206 | pub fn is_leader(&self) -> bool { |
207 | self.0 |
208 | } |
209 | } |
210 | |