/// Waits on multiple concurrent branches, returning when **all** branches
/// complete with `Ok(_)` or on the first `Err(_)`.
///
/// The `try_join!` macro must be used inside of async functions, closures, and
/// blocks.
///
/// Similar to [`join!`], the `try_join!` macro takes a list of async
/// expressions and evaluates them concurrently on the same task. Each async
/// expression evaluates to a future and the futures from each expression are
/// multiplexed on the current task. The `try_join!` macro returns when **all**
/// branches return with `Ok` or when the **first** branch returns with `Err`.
///
/// On success, the `Ok` values are returned as a tuple in the same order as
/// the branches. On failure, the first `Err` observed is returned and the
/// remaining branches are dropped without being polled again. Invoking the
/// macro with no branches, `try_join!()`, evaluates to `Ok(())`.
///
/// [`join!`]: macro@join
///
/// # Notes
///
/// The supplied futures are stored inline and do not require allocating a
/// `Vec`.
///
/// ### Runtime characteristics
///
/// By running all async expressions on the current task, the expressions are
/// able to run **concurrently** but not in **parallel**. This means all
/// expressions are run on the same thread and if one branch blocks the thread,
/// all other expressions will be unable to continue. If parallelism is
/// required, spawn each async expression using [`tokio::spawn`] and pass the
/// join handles to `try_join!`. Awaiting a join handle yields a `Result` of
/// its own, so the task outputs usually need to be flattened first, as shown
/// in the second example below.
///
/// To keep one branch from starving the others, the branch that is polled
/// first rotates each time the `try_join!` future is polled.
///
/// [`tokio::spawn`]: crate::spawn
///
/// # Examples
///
/// Basic `try_join` with two branches.
///
/// ```
/// async fn do_stuff_async() -> Result<(), &'static str> {
///     // async work
/// # Ok(())
/// }
///
/// async fn more_async_work() -> Result<(), &'static str> {
///     // more here
/// # Ok(())
/// }
///
/// #[tokio::main]
/// async fn main() {
///     let res = tokio::try_join!(
///         do_stuff_async(),
///         more_async_work());
///
///     match res {
///         Ok((first, second)) => {
///             // do something with the values
///         }
///         Err(err) => {
///             println!("processing failed; error = {}", err);
///         }
///     }
/// }
/// ```
///
/// Using `try_join!` with spawned tasks.
///
/// ```
/// use tokio::task::JoinHandle;
///
/// async fn do_stuff_async() -> Result<(), &'static str> {
///     // async work
/// # Err("failed")
/// }
///
/// async fn more_async_work() -> Result<(), &'static str> {
///     // more here
/// # Ok(())
/// }
///
/// async fn flatten<T>(handle: JoinHandle<Result<T, &'static str>>) -> Result<T, &'static str> {
///     match handle.await {
///         Ok(Ok(result)) => Ok(result),
///         Ok(Err(err)) => Err(err),
///         Err(_) => Err("handling failed"),
///     }
/// }
///
/// #[tokio::main]
/// async fn main() {
///     let handle1 = tokio::spawn(do_stuff_async());
///     let handle2 = tokio::spawn(more_async_work());
///     match tokio::try_join!(flatten(handle1), flatten(handle2)) {
///         Ok(val) => {
///             // do something with the values
///         }
///         Err(err) => {
///             println!("Failed with {}.", err);
///             # assert_eq!(err, "failed");
///         }
///     }
/// }
/// ```
#[macro_export]
#[cfg_attr(docsrs, doc(cfg(feature = "macros")))]
macro_rules! try_join {
    // ===== Expansion =====
    //
    // Reached once every branch has been normalized. Drives all branches
    // concurrently on the current task and short-circuits on the first `Err`.
    (@ {
        // One `_` for each branch in the `try_join!` macro. This is not used once
        // normalization is complete.
        ( $($count:tt)* )

        // The expression `0+1+1+ ... +1`, which evaluates to the number of
        // branches.
        ( $($total:tt)* )

        // Normalized try_join! branches. Each branch is prefixed with one `_`
        // per preceding branch; those placeholders are used below to
        // destructure the tuple of futures down to this branch's position.
        $( ( $($skip:tt)* ) $e:expr, )*

    }) => {{
        use $crate::macros::support::{maybe_done, poll_fn, Future, Pin};
        use $crate::macros::support::Poll::{Ready, Pending};

        // Safety: nothing must be moved out of `futures`. This is to satisfy
        // the requirement of `Pin::new_unchecked` called below.
        //
        // We can't use the `pin!` macro for this because `futures` is a tuple
        // and the standard library provides no way to pin-project to the fields
        // of a tuple.
        let mut futures = ( $( maybe_done($e), )* );

        // This assignment makes sure that the `poll_fn` closure only has a
        // reference to the futures, instead of taking ownership of them. This
        // mitigates the issue described in
        // <https://internals.rust-lang.org/t/surprising-soundness-trouble-around-pollfn/17484>
        //
        // Shadowing also means the tuple itself can no longer be named (and
        // therefore moved) for the rest of this block.
        let futures = &mut futures;

        // Each time the future created by `poll_fn` is polled, a different
        // branch is polled first. This ensures every future passed to
        // `try_join!` gets a chance to make progress even if one of them
        // consumes the whole budget.
        //
        // This is the number of branches that will be skipped in the first
        // pass of the loop the next time the closure runs.
        let mut skip_next_time: u32 = 0;

        poll_fn(move |cx| {
            const COUNT: u32 = $($total)*;

            // Set when at least one branch is still pending after this call.
            let mut is_pending = false;

            // Number of branches still to be polled during this call.
            let mut to_run = COUNT;

            // The number of branches skipped in the first pass of the loop.
            let mut skip = skip_next_time;

            // Rotate the starting branch for the next call.
            skip_next_time = if skip + 1 == COUNT { 0 } else { skip + 1 };

            // The loop makes at most two passes over the branches: the first
            // pass skips the first `skip` branches, and the second pass polls
            // them. It breaks out as soon as all `COUNT` branches have been
            // polled exactly once.
            loop {
                $(
                    if skip == 0 {
                        if to_run == 0 {
                            // Every branch has been polled during this call.
                            break;
                        }
                        to_run -= 1;

                        // Extract the future for this branch from the tuple.
                        let ( $($skip,)* fut, .. ) = &mut *futures;

                        // Safety: the tuple is only accessed through the
                        // `futures` reference and is never moved, so its
                        // fields stay in place once polling has started.
                        let mut fut = unsafe { Pin::new_unchecked(fut) };

                        if fut.as_mut().poll(cx).is_pending() {
                            is_pending = true;
                        } else if fut
                            .as_mut()
                            .output_mut()
                            .expect("expected completed future")
                            .is_err()
                        {
                            // Short-circuit: hand back the first error observed.
                            return Ready(Err(fut
                                .take_output()
                                .expect("expected completed future")
                                .err()
                                .unwrap()));
                        }
                    } else {
                        // Branch skipped; one less branch to skip in this pass.
                        skip -= 1;
                    }
                )*
            }

            if is_pending {
                Pending
            } else {
                // Every branch completed without an `Err`; collect the
                // outputs in branch order.
                Ready(Ok(($({
                    // Extract the future for this branch from the tuple.
                    let ( $($skip,)* fut, .. ) = &mut *futures;

                    // Safety: the tuple is only accessed through the
                    // `futures` reference and is never moved.
                    let fut = unsafe { Pin::new_unchecked(fut) };

                    fut
                        .take_output()
                        .expect("expected completed future")
                        .ok()
                        .expect("expected Ok(_)")
                },)*)))
            }
        }).await
    }};

    // ===== Normalize =====
    //
    // Moves one branch at a time from the unprocessed tail into the normalized
    // list, tagging it with the `_` placeholders accumulated so far and adding
    // `+ 1` to the branch count.
    (@ { ( $($s:tt)* ) ( $($n:tt)* ) $($t:tt)* } $e:expr, $($r:tt)* ) => {
        $crate::try_join!(@{ ($($s)* _) ($($n)* + 1) $($t)* ($($s)*) $e, } $($r)*)
    };

    // ===== Entry point =====

    ( $($e:expr),+ $(,)?) => {
        $crate::try_join!(@{ () (0) } $($e,)*)
    };

    // No branches: trivially succeeds.
    () => { async { Ok(()) }.await }
}
219 | |