1use crate::loom::sync::Mutex;
2use crate::sync::watch;
3#[cfg(all(tokio_unstable, feature = "tracing"))]
4use crate::util::trace;
5
/// A barrier enables multiple tasks to synchronize the beginning of some computation.
///
/// A `Barrier` is created for a fixed number of tasks `n`. Each call to
/// [`Barrier::wait`] counts as one arrival; when the `n`th task arrives, every task
/// waiting in that round is released at once. The barrier is then reset and can be
/// reused: the next `n` calls to `wait` form a new round.
///
/// ```
/// # #[tokio::main]
/// # async fn main() {
/// use tokio::sync::Barrier;
/// use std::sync::Arc;
///
/// let mut handles = Vec::with_capacity(10);
/// let barrier = Arc::new(Barrier::new(10));
/// for _ in 0..10 {
///     let c = barrier.clone();
///     // All ten "before wait" lines are printed before any "after wait" line:
///     // no task gets past `wait` until all ten tasks have called it.
///     handles.push(tokio::spawn(async move {
///         println!("before wait");
///         let wait_result = c.wait().await;
///         println!("after wait");
///         wait_result
///     }));
/// }
///
/// // This loop does not finish until every task has returned, i.e. until all
/// // "after wait" messages have been printed.
/// let mut num_leaders = 0;
/// for handle in handles {
///     let wait_result = handle.await.unwrap();
///     if wait_result.is_leader() {
///         num_leaders += 1;
///     }
/// }
///
/// // Exactly one of the ten tasks is told that it is the "leader".
/// assert_eq!(num_leaders, 1);
/// # }
/// ```
#[derive(Debug)]
pub struct Barrier {
    /// Arrival count and generation bookkeeping, guarded by a synchronous mutex.
    state: Mutex<BarrierState>,
    /// Receives the number of the most recently released generation; each waiter clones
    /// this receiver to learn when its own generation has been released.
    wait: watch::Receiver<usize>,
    /// Number of tasks that must call `wait` to release a generation (always at least 1).
    n: usize,
    /// Tracing span that represents this barrier as a runtime resource.
    #[cfg(all(tokio_unstable, feature = "tracing"))]
    resource_span: tracing::Span,
}
49
/// Mutable bookkeeping for a [`Barrier`], stored behind `Barrier::state`.
#[derive(Debug)]
struct BarrierState {
    /// Broadcasts the number of each generation at the moment it is released.
    waker: watch::Sender<usize>,
    /// Number of tasks that have called `wait` in the current generation.
    arrived: usize,
    /// Number of the current (not yet released) generation; the first one is 1.
    generation: usize,
}
56
impl Barrier {
    /// Creates a new barrier that can block a given number of tasks.
    ///
    /// A barrier will block `n`-1 tasks which call [`Barrier::wait`] and then wake up all
    /// tasks at once when the `n`th task calls `wait`.
    ///
    /// An `n` of 0 is treated the same as an `n` of 1: every call to [`Barrier::wait`]
    /// returns without waiting for any other task and reports itself as the leader.
    #[track_caller]
    pub fn new(mut n: usize) -> Barrier {
        // The watch channel carries the number of the most recently released generation.
        // Its initial value 0 is never a real generation (they start at 1 below), so it
        // can never satisfy a waiter.
        let (waker, wait) = crate::sync::watch::channel(0);

        if n == 0 {
            // if n is 0, it's not clear what behavior the user wants.
            // in std::sync::Barrier, an n of 0 exhibits the same behavior as n == 1, where every
            // .wait() immediately unblocks, so we adopt that here as well.
            n = 1;
        }

        // Register the barrier as a traced runtime resource. Because `new` is
        // `#[track_caller]`, `Location::caller()` reports the caller's source location.
        #[cfg(all(tokio_unstable, feature = "tracing"))]
        let resource_span = {
            let location = std::panic::Location::caller();
            let resource_span = tracing::trace_span!(
                "runtime.resource",
                concrete_type = "Barrier",
                kind = "Sync",
                loc.file = location.file(),
                loc.line = location.line(),
                loc.col = location.column(),
            );

            // Record the initial state: the (already normalized) size and zero arrivals.
            resource_span.in_scope(|| {
                tracing::trace!(
                    target: "runtime::resource::state_update",
                    size = n,
                );

                tracing::trace!(
                    target: "runtime::resource::state_update",
                    arrived = 0,
                )
            });
            resource_span
        };

        Barrier {
            state: Mutex::new(BarrierState {
                waker,
                arrived: 0,
                // Must be greater than the watch channel's initial value (0) so that
                // waiters of the first generation are not released prematurely.
                generation: 1,
            }),
            n,
            wait,
            #[cfg(all(tokio_unstable, feature = "tracing"))]
            resource_span,
        }
    }

    /// Does not resolve until all tasks have rendezvoused here.
    ///
    /// Each call counts as one arrival in the barrier's current generation; when the `n`th
    /// arrival happens, every call waiting in that generation is released.
    ///
    /// Barriers are re-usable after all tasks have rendezvoused once, and can
    /// be used continuously.
    ///
    /// A single (arbitrary) future will receive a [`BarrierWaitResult`] that returns `true` from
    /// [`BarrierWaitResult::is_leader`] when returning from this function, and all other tasks
    /// will receive a result that will return `false` from `is_leader`.
    pub async fn wait(&self) -> BarrierWaitResult {
        // Exactly one of the two `return`s below is compiled in (the cfg predicates are
        // complements). With tracing enabled, the wait is wrapped by `trace::async_op`
        // using this barrier's resource span.
        #[cfg(all(tokio_unstable, feature = "tracing"))]
        return trace::async_op(
            || self.wait_internal(),
            self.resource_span.clone(),
            "Barrier::wait",
            "poll",
            false,
        )
        .await;

        #[cfg(any(not(tokio_unstable), not(feature = "tracing")))]
        return self.wait_internal().await;
    }

    /// Untraced implementation of [`Barrier::wait`].
    async fn wait_internal(&self) -> BarrierWaitResult {
        crate::trace::async_trace_leaf().await;

        // NOTE: we are taking a _synchronous_ lock here.
        // It is okay to do so because the critical section is fast and never yields, so it cannot
        // deadlock even if another future is concurrently holding the lock.
        // It is _desirable_ to do so as synchronous Mutexes are, at least in theory, faster than
        // the asynchronous counter-parts, so we should use them where possible [citation needed].
        // NOTE: the extra scope here is so that the compiler doesn't think `state` is held across
        // a yield point, and thus marks the returned future as !Send.
        let generation = {
            let mut state = self.state.lock();
            // Capture the generation this call belongs to before counting our arrival.
            let generation = state.generation;
            state.arrived += 1;
            #[cfg(all(tokio_unstable, feature = "tracing"))]
            tracing::trace!(
                target: "runtime::resource::state_update",
                arrived = 1,
                arrived.op = "add",
            );
            #[cfg(all(tokio_unstable, feature = "tracing"))]
            tracing::trace!(
                target: "runtime::resource::async_op::state_update",
                arrived = true,
            );
            if state.arrived == self.n {
                #[cfg(all(tokio_unstable, feature = "tracing"))]
                tracing::trace!(
                    target: "runtime::resource::async_op::state_update",
                    is_leader = true,
                );
                // we are the leader for this generation
                // wake everyone, increment the generation, and return
                // `send` cannot fail for lack of receivers: `self.wait` keeps one alive for
                // as long as the barrier exists.
                state
                    .waker
                    .send(state.generation)
                    .expect("there is at least one receiver");
                state.arrived = 0;
                state.generation += 1;
                return BarrierWaitResult(true);
            }

            generation
        };

        // NOTE(review): our arrival has now been counted. Nothing decrements `arrived` if
        // this future is dropped before the generation is released, so a cancelled waiter
        // still counts toward the current generation — confirm this is intended.

        // we're going to have to wait for the last of the generation to arrive
        let mut wait = self.wait.clone();

        loop {
            // Per the `watch` docs, `changed` only fails once the sender is dropped. The
            // sender lives in `self.state`, which outlives this `&self` borrow, so the
            // result can be ignored.
            let _ = wait.changed().await;

            // `self.wait` is only ever cloned (it is behind `&self`), so it never marks a
            // value as seen. Once any generation has been released, the first `changed()`
            // on a fresh clone therefore returns immediately, possibly with an older
            // generation; before the first release it waits for the first `send`.
            // `>=` rather than `==`: the channel only holds the latest value, so later
            // generations may already have been released by the time we look.
            if *wait.borrow() >= generation {
                break;
            }
        }

        BarrierWaitResult(false)
    }
}
195
/// The outcome of [`Barrier::wait`], handed back once every task taking part in the
/// current round of the `Barrier` has rendezvoused.
#[derive(Debug, Clone)]
pub struct BarrierWaitResult(bool);

impl BarrierWaitResult {
    /// Reports whether this task was picked as the "leader task" for its round.
    ///
    /// In each round exactly one task sees `true` here; every other task in that
    /// round sees `false`.
    pub fn is_leader(&self) -> bool {
        let Self(leader) = self;
        *leader
    }
}
209