| 1 | use crate::loom::sync::Mutex; |
| 2 | use crate::sync::watch; |
| 3 | #[cfg (all(tokio_unstable, feature = "tracing" ))] |
| 4 | use crate::util::trace; |
| 5 | |
/// A barrier enables multiple tasks to synchronize the beginning of some computation.
///
/// ```
/// # #[tokio::main]
/// # async fn main() {
/// use tokio::sync::Barrier;
/// use std::sync::Arc;
///
/// let mut handles = Vec::with_capacity(10);
/// let barrier = Arc::new(Barrier::new(10));
/// for _ in 0..10 {
///     let c = barrier.clone();
///     // Every "before wait" line is printed before any "after wait" line:
///     // the two groups of messages never interleave.
///     handles.push(tokio::spawn(async move {
///         println!("before wait");
///         let wait_result = c.wait().await;
///         println!("after wait");
///         wait_result
///     }));
/// }
///
/// // Once every handle has been awaited, all "after wait" messages have been printed.
/// let mut num_leaders = 0;
/// for handle in handles {
///     let wait_result = handle.await.unwrap();
///     if wait_result.is_leader() {
///         num_leaders += 1;
///     }
/// }
///
/// // Exactly one of the `wait` calls resolves as the "leader".
/// assert_eq!(num_leaders, 1);
/// # }
/// ```
#[derive(Debug)]
pub struct Barrier {
    /// Mutable bookkeeping (arrival count, current generation, channel sender), guarded by a
    /// synchronous lock.
    state: Mutex<BarrierState>,
    /// Receiving end of the generation channel; each waiting call clones it. Because the
    /// barrier owns this receiver, the channel keeps at least one receiver for as long as the
    /// barrier exists.
    wait: watch::Receiver<usize>,
    /// Number of tasks that must arrive to release one generation (`new` maps 0 to 1).
    n: usize,
    /// Tracing span representing this barrier as a runtime resource.
    #[cfg(all(tokio_unstable, feature = "tracing"))]
    resource_span: tracing::Span,
}
| 49 | |
/// Mutable state shared by all waiters of a [`Barrier`]; it is stored behind the barrier's
/// `state` lock.
#[derive(Debug)]
struct BarrierState {
    /// Sending end of the generation channel. When the last task of a generation arrives, it
    /// sends that generation's number through this sender to release the other waiters.
    waker: watch::Sender<usize>,
    /// Number of `wait` calls counted toward the current generation; reset to 0 when it
    /// reaches the barrier's `n`.
    arrived: usize,
    /// Number of the generation currently being assembled. It starts at 1 while the channel
    /// starts at 0, so the channel's initial value never satisfies a waiter's
    /// `published >= generation` check. It is incremented each time a generation is released.
    generation: usize,
}
| 56 | |
| 57 | impl Barrier { |
| 58 | /// Creates a new barrier that can block a given number of tasks. |
| 59 | /// |
| 60 | /// A barrier will block `n`-1 tasks which call [`Barrier::wait`] and then wake up all |
| 61 | /// tasks at once when the `n`th task calls `wait`. |
| 62 | #[track_caller ] |
| 63 | pub fn new(mut n: usize) -> Barrier { |
| 64 | let (waker, wait) = crate::sync::watch::channel(0); |
| 65 | |
| 66 | if n == 0 { |
| 67 | // if n is 0, it's not clear what behavior the user wants. |
| 68 | // in std::sync::Barrier, an n of 0 exhibits the same behavior as n == 1, where every |
| 69 | // .wait() immediately unblocks, so we adopt that here as well. |
| 70 | n = 1; |
| 71 | } |
| 72 | |
| 73 | #[cfg (all(tokio_unstable, feature = "tracing" ))] |
| 74 | let resource_span = { |
| 75 | let location = std::panic::Location::caller(); |
| 76 | let resource_span = tracing::trace_span!( |
| 77 | parent: None, |
| 78 | "runtime.resource" , |
| 79 | concrete_type = "Barrier" , |
| 80 | kind = "Sync" , |
| 81 | loc.file = location.file(), |
| 82 | loc.line = location.line(), |
| 83 | loc.col = location.column(), |
| 84 | ); |
| 85 | |
| 86 | resource_span.in_scope(|| { |
| 87 | tracing::trace!( |
| 88 | target: "runtime::resource::state_update" , |
| 89 | size = n, |
| 90 | ); |
| 91 | |
| 92 | tracing::trace!( |
| 93 | target: "runtime::resource::state_update" , |
| 94 | arrived = 0, |
| 95 | ) |
| 96 | }); |
| 97 | resource_span |
| 98 | }; |
| 99 | |
| 100 | Barrier { |
| 101 | state: Mutex::new(BarrierState { |
| 102 | waker, |
| 103 | arrived: 0, |
| 104 | generation: 1, |
| 105 | }), |
| 106 | n, |
| 107 | wait, |
| 108 | #[cfg (all(tokio_unstable, feature = "tracing" ))] |
| 109 | resource_span, |
| 110 | } |
| 111 | } |
| 112 | |
| 113 | /// Does not resolve until all tasks have rendezvoused here. |
| 114 | /// |
| 115 | /// Barriers are re-usable after all tasks have rendezvoused once, and can |
| 116 | /// be used continuously. |
| 117 | /// |
| 118 | /// A single (arbitrary) future will receive a [`BarrierWaitResult`] that returns `true` from |
| 119 | /// [`BarrierWaitResult::is_leader`] when returning from this function, and all other tasks |
| 120 | /// will receive a result that will return `false` from `is_leader`. |
| 121 | /// |
| 122 | /// # Cancel safety |
| 123 | /// |
| 124 | /// This method is not cancel safe. |
| 125 | pub async fn wait(&self) -> BarrierWaitResult { |
| 126 | #[cfg (all(tokio_unstable, feature = "tracing" ))] |
| 127 | return trace::async_op( |
| 128 | || self.wait_internal(), |
| 129 | self.resource_span.clone(), |
| 130 | "Barrier::wait" , |
| 131 | "poll" , |
| 132 | false, |
| 133 | ) |
| 134 | .await; |
| 135 | |
| 136 | #[cfg (any(not(tokio_unstable), not(feature = "tracing" )))] |
| 137 | return self.wait_internal().await; |
| 138 | } |
| 139 | async fn wait_internal(&self) -> BarrierWaitResult { |
| 140 | crate::trace::async_trace_leaf().await; |
| 141 | |
| 142 | // NOTE: we are taking a _synchronous_ lock here. |
| 143 | // It is okay to do so because the critical section is fast and never yields, so it cannot |
| 144 | // deadlock even if another future is concurrently holding the lock. |
| 145 | // It is _desirable_ to do so as synchronous Mutexes are, at least in theory, faster than |
| 146 | // the asynchronous counter-parts, so we should use them where possible [citation needed]. |
| 147 | // NOTE: the extra scope here is so that the compiler doesn't think `state` is held across |
| 148 | // a yield point, and thus marks the returned future as !Send. |
| 149 | let generation = { |
| 150 | let mut state = self.state.lock(); |
| 151 | let generation = state.generation; |
| 152 | state.arrived += 1; |
| 153 | #[cfg (all(tokio_unstable, feature = "tracing" ))] |
| 154 | tracing::trace!( |
| 155 | target: "runtime::resource::state_update" , |
| 156 | arrived = 1, |
| 157 | arrived.op = "add" , |
| 158 | ); |
| 159 | #[cfg (all(tokio_unstable, feature = "tracing" ))] |
| 160 | tracing::trace!( |
| 161 | target: "runtime::resource::async_op::state_update" , |
| 162 | arrived = true, |
| 163 | ); |
| 164 | if state.arrived == self.n { |
| 165 | #[cfg (all(tokio_unstable, feature = "tracing" ))] |
| 166 | tracing::trace!( |
| 167 | target: "runtime::resource::async_op::state_update" , |
| 168 | is_leader = true, |
| 169 | ); |
| 170 | // we are the leader for this generation |
| 171 | // wake everyone, increment the generation, and return |
| 172 | state |
| 173 | .waker |
| 174 | .send(state.generation) |
| 175 | .expect("there is at least one receiver" ); |
| 176 | state.arrived = 0; |
| 177 | state.generation += 1; |
| 178 | return BarrierWaitResult(true); |
| 179 | } |
| 180 | |
| 181 | generation |
| 182 | }; |
| 183 | |
| 184 | // we're going to have to wait for the last of the generation to arrive |
| 185 | let mut wait = self.wait.clone(); |
| 186 | |
| 187 | loop { |
| 188 | let _ = wait.changed().await; |
| 189 | |
| 190 | // note that the first time through the loop, this _will_ yield a generation |
| 191 | // immediately, since we cloned a receiver that has never seen any values. |
| 192 | if *wait.borrow() >= generation { |
| 193 | break; |
| 194 | } |
| 195 | } |
| 196 | |
| 197 | BarrierWaitResult(false) |
| 198 | } |
| 199 | } |
| 200 | |
/// The outcome of [`Barrier::wait`], produced once every task of the barrier's current
/// group has arrived.
#[derive(Debug, Clone)]
pub struct BarrierWaitResult(bool);

impl BarrierWaitResult {
    /// Reports whether this particular `wait` call was chosen as its group's "leader task".
    ///
    /// In every group released by the barrier, exactly one call gets `true` here and all the
    /// others get `false`.
    pub fn is_leader(&self) -> bool {
        let BarrierWaitResult(leader) = *self;
        leader
    }
}
| 214 | |