/// Waits on multiple concurrent branches, returning when **all** branches
/// complete with `Ok(_)` or on the first `Err(_)`.
///
/// The `try_join!` macro must be used inside of async functions, closures, and
/// blocks.
///
/// Similar to [`join!`], the `try_join!` macro takes a list of async
/// expressions and evaluates them concurrently on the same task. Each async
/// expression evaluates to a future and the futures from each expression are
/// multiplexed on the current task. The `try_join!` macro returns when **all**
/// branches return with `Ok` or when the **first** branch returns with `Err`.
///
/// [`join!`]: macro@join
///
/// # Notes
///
/// The supplied futures are stored inline and do not require allocating a
/// `Vec`.
///
/// ### Runtime characteristics
///
/// By running all async expressions on the current task, the expressions are
/// able to run **concurrently** but not in **parallel**. This means all
/// expressions are run on the same thread and if one branch blocks the thread,
/// all other expressions will be unable to continue. If parallelism is
/// required, spawn each async expression using [`tokio::spawn`] and pass the
/// join handle to `try_join!`.
///
/// [`tokio::spawn`]: crate::spawn
///
/// # Examples
///
/// Basic `try_join!` with two branches.
///
/// ```
/// async fn do_stuff_async() -> Result<(), &'static str> {
///     // async work
/// # Ok(())
/// }
///
/// async fn more_async_work() -> Result<(), &'static str> {
///     // more here
/// # Ok(())
/// }
///
/// #[tokio::main]
/// async fn main() {
///     let res = tokio::try_join!(
///         do_stuff_async(),
///         more_async_work());
///
///     match res {
///         Ok((first, second)) => {
///             // do something with the values
///         }
///         Err(err) => {
///             println!("processing failed; error = {}", err);
///         }
///     }
/// }
/// ```
///
/// Using `try_join!` with spawned tasks.
///
/// ```
/// use tokio::task::JoinHandle;
///
/// async fn do_stuff_async() -> Result<(), &'static str> {
///     // async work
/// # Err("failed")
/// }
///
/// async fn more_async_work() -> Result<(), &'static str> {
///     // more here
/// # Ok(())
/// }
///
/// async fn flatten<T>(handle: JoinHandle<Result<T, &'static str>>) -> Result<T, &'static str> {
///     match handle.await {
///         Ok(Ok(result)) => Ok(result),
///         Ok(Err(err)) => Err(err),
///         Err(_) => Err("handling failed"),
///     }
/// }
///
/// #[tokio::main]
/// async fn main() {
///     let handle1 = tokio::spawn(do_stuff_async());
///     let handle2 = tokio::spawn(more_async_work());
///     match tokio::try_join!(flatten(handle1), flatten(handle2)) {
///         Ok(val) => {
///             // do something with the values
///         }
///         Err(err) => {
///             println!("Failed with {}.", err);
///             # assert_eq!(err, "failed");
///         }
///     }
/// }
/// ```
#[macro_export]
#[cfg_attr(docsrs, doc(cfg(feature = "macros")))]
macro_rules! try_join {
    (@ {
        // One `_` for each branch in the `try_join!` macro. This is not used once
        // normalization is complete.
        ( $($count:tt)* )

        // The expression `0+1+1+ ... +1`, which evaluates to the number of branches.
        ( $($total:tt)* )

        // Normalized try_join! branches. Each branch carries one `_` per earlier
        // branch; below, those `_`s are used as wildcard patterns that skip the
        // preceding tuple fields so that `fut` binds this branch's future.
        $( ( $($skip:tt)* ) $e:expr, )*

    }) => {{
        use $crate::macros::support::{maybe_done, poll_fn, Future, Pin};
        use $crate::macros::support::Poll::{Ready, Pending};

        // Safety: nothing must be moved out of `futures`. This is to satisfy
        // the requirement of `Pin::new_unchecked` called below.
        //
        // We can't use the `pin!` macro for this because `futures` is a tuple
        // and the standard library provides no way to pin-project to the fields
        // of a tuple.
        let mut futures = ( $( maybe_done($e), )* );

        // This assignment makes sure that the `poll_fn` closure only has a
        // reference to the futures, instead of taking ownership of them. This
        // mitigates the issue described in
        // <https://internals.rust-lang.org/t/surprising-soundness-trouble-around-pollfn/17484>
        let mut futures = &mut futures;

        // Each time the future created by `poll_fn` is polled, a different
        // future is polled first, to ensure every future passed to `try_join!`
        // gets a chance to make progress even if one of the futures consumes
        // the whole budget.
        //
        // This is the number of futures that will be skipped in the first loop
        // iteration the next time.
        let mut skip_next_time: u32 = 0;

        poll_fn(move |cx| {
            const COUNT: u32 = $($total)*;

            let mut is_pending = false;

            // Number of branches still to be polled during this call.
            let mut to_run = COUNT;

            // The number of futures that will be skipped in the first loop iteration.
            let mut skip = skip_next_time;

            // Rotate the starting branch for the next call, wrapping back to 0.
            skip_next_time = if skip + 1 == COUNT { 0 } else { skip + 1 };

            // The loop body executes at most twice. The first pass skips the
            // first `skip` branches, and the second pass reaches them. The loop
            // breaks as soon as all `COUNT` branches have been polled once in
            // this call.
            loop {
            $(
                if skip == 0 {
                    if to_run == 0 {
                        // Every future has been polled
                        break;
                    }
                    to_run -= 1;

                    // Extract the future for this branch from the tuple.
                    let ( $($skip,)* fut, .. ) = &mut *futures;

                    // SAFETY: `fut` points into the `futures` tuple, which is
                    // never moved after creation (the closure only holds a
                    // `&mut` reborrow of it), so pinning it in place is sound.
                    let mut fut = unsafe { Pin::new_unchecked(fut) };

                    // Try polling. A branch that completed with `Err`
                    // short-circuits the whole `try_join!`: its error is taken
                    // out and returned immediately.
                    if fut.as_mut().poll(cx).is_pending() {
                        is_pending = true;
                    } else if fut.as_mut().output_mut().expect("expected completed future").is_err() {
                        return Ready(Err(fut
                            .take_output()
                            .expect("expected completed future")
                            .err()
                            .expect("expected Err(_)")))
                    }
                } else {
                    // Future skipped, one less future to skip in the next iteration
                    skip -= 1;
                }
            )*
            }

            if is_pending {
                Pending
            } else {
                // Every branch completed with `Ok`: take each output, in
                // branch order, into the result tuple.
                Ready(Ok(($({
                    // Extract the future for this branch from the tuple.
                    let ( $($skip,)* fut, .. ) = &mut futures;

                    // SAFETY: `fut` points into the `futures` tuple, which is
                    // never moved after creation (the closure only holds a
                    // `&mut` reborrow of it), so pinning it in place is sound.
                    let mut fut = unsafe { Pin::new_unchecked(fut) };

                    fut
                        .take_output()
                        .expect("expected completed future")
                        .ok()
                        .expect("expected Ok(_)")
                },)*)))
            }
        }).await
    }};

    // ===== Normalize =====
    //
    // Consumes one branch expression per step. Each step appends a `_` to the
    // counter list, appends `+ 1` to the total, and emits the branch tagged
    // with the `_` list as it was *before* this step (one `_` per earlier
    // branch).

    (@ { ( $($s:tt)* ) ( $($n:tt)* ) $($t:tt)* } $e:expr, $($r:tt)* ) => {
        $crate::try_join!(@{ ($($s)* _) ($($n)* + 1) $($t)* ($($s)*) $e, } $($r)*)
    };

    // ===== Entry point =====
    //
    // One or more comma-separated expressions, with an optional trailing comma.
    // Normalization starts with an empty `_` list and a total of `0`.

    ( $($e:expr),+ $(,)?) => {
        $crate::try_join!(@{ () (0) } $($e,)*)
    };

    // With no branches, evaluate to `Ok(())`.
    () => { async { Ok(()) }.await }
}
219