1 | //! Wait for multiple futures to complete. |
2 | |
3 | use core::future::Future; |
4 | use core::mem::MaybeUninit; |
5 | use core::pin::Pin; |
6 | use core::task::{Context, Poll}; |
7 | use core::{fmt, mem}; |
8 | |
/// Holds a future until it finishes, then holds its output until it is claimed.
#[derive(Debug)]
enum MaybeDone<Fut: Future> {
    /// Still running. This field is structurally pinned: it is only polled through a
    /// `Pin`, and it is dropped in place when the variant is overwritten with `Done`.
    Future(Fut),
    /// Finished; the output waits here until [`take_output`](MaybeDone::take_output).
    Done(Fut::Output),
    /// The output has already been moved out by [`take_output`](MaybeDone::take_output).
    Gone,
}

impl<Fut: Future> MaybeDone<Fut> {
    /// Polls the wrapped future if it is still running.
    ///
    /// Returns `true` once an output is available (or has already been taken) and
    /// `false` while the future is pending. A finished future is never polled again.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> bool {
        // SAFETY: nothing is moved out of `self` here. The inner future is re-pinned in
        // place, and it is only ever dropped in place by the assignment below.
        let state = unsafe { self.get_unchecked_mut() };

        let output = match state {
            Self::Done(_) | Self::Gone => return true,
            Self::Future(inner) => {
                // SAFETY: `inner` lives inside the pinned `MaybeDone`, so it is pinned too.
                let pinned = unsafe { Pin::new_unchecked(inner) };
                match pinned.poll(cx) {
                    Poll::Pending => return false,
                    Poll::Ready(value) => value,
                }
            }
        };

        // Overwriting the variant drops the finished future in place and stores its output.
        *state = Self::Done(output);
        true
    }

    /// Moves the stored output out, leaving `Gone` behind.
    ///
    /// # Panics
    ///
    /// Panics if the future has not finished yet or the output was already taken;
    /// in that case `self` is left unchanged.
    fn take_output(&mut self) -> Fut::Output {
        if !matches!(self, Self::Done(_)) {
            panic!("take_output when MaybeDone is not done.");
        }
        match mem::replace(self, Self::Gone) {
            Self::Done(output) => output,
            Self::Future(_) | Self::Gone => unreachable!(),
        }
    }
}

// Only the running future is ever accessed through a pin (the output is moved out
// freely), so the wrapper may be `Unpin` whenever that future is.
impl<Fut> Unpin for MaybeDone<Fut> where Fut: Future + Unpin {}
48 | |
// Stamps out one fixed-arity join future per entry: the public struct, its `Debug`
// impl, a private constructor, and the `Future` impl. Each entry has the form
// `(Name, <F1, F2, ...>),`, optionally preceded by attributes (typically doc comments)
// that are forwarded onto the generated struct.
macro_rules! generate {
    ($(
        $(#[$attr:meta])*
        ($name:ident, <$($F:ident),*>),
    )*) => {$(
        $(#[$attr])*
        #[must_use = "futures do nothing unless you `.await` or poll them"]
        #[allow(non_snake_case)]
        pub struct $name<$($F: Future),*> {
            $($F: MaybeDone<$F>,)*
        }

        impl<$($F),*> fmt::Debug for $name<$($F),*>
        where
            $($F: Future + fmt::Debug, $F::Output: fmt::Debug,)*
        {
            fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
                let mut builder = f.debug_struct(stringify!($name));
                $(builder.field(stringify!($F), &self.$F);)*
                builder.finish()
            }
        }

        impl<$($F: Future),*> $name<$($F),*> {
            // Parameters share their names with the type parameters, hence the allow.
            #[allow(non_snake_case)]
            fn new($($F: $F),*) -> Self {
                Self { $($F: MaybeDone::Future($F),)* }
            }
        }

        impl<$($F: Future),*> Future for $name<$($F),*> {
            type Output = ($($F::Output),*);

            fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
                // SAFETY: no field is moved out while pinned; each one is re-pinned in
                // place below (structural pinning), and the struct has no `Drop` impl.
                let this = unsafe { self.get_unchecked_mut() };

                // Poll every sub-future in declaration order; `&=` does not short-circuit,
                // so later futures are still polled when an earlier one is pending.
                let mut ready = true;
                $(
                    // SAFETY: this field lives inside the pinned join future.
                    ready &= unsafe { Pin::new_unchecked(&mut this.$F) }.poll(cx);
                )*

                if !ready {
                    return Poll::Pending;
                }
                Poll::Ready(($(this.$F.take_output()),*))
            }
        }
    )*};
}
107 | |
108 | generate! { |
109 | /// Future for the [`join`](join()) function. |
110 | (Join, <Fut1, Fut2>), |
111 | |
112 | /// Future for the [`join3`] function. |
113 | (Join3, <Fut1, Fut2, Fut3>), |
114 | |
115 | /// Future for the [`join4`] function. |
116 | (Join4, <Fut1, Fut2, Fut3, Fut4>), |
117 | |
118 | /// Future for the [`join5`] function. |
119 | (Join5, <Fut1, Fut2, Fut3, Fut4, Fut5>), |
120 | } |
121 | |
122 | /// Joins the result of two futures, waiting for them both to complete. |
123 | /// |
124 | /// This function will return a new future which awaits both futures to |
125 | /// complete. The returned future will finish with a tuple of both results. |
126 | /// |
127 | /// Note that this function consumes the passed futures and returns a |
128 | /// wrapped version of it. |
129 | /// |
130 | /// # Examples |
131 | /// |
132 | /// ``` |
133 | /// # embassy_futures::block_on(async { |
134 | /// |
135 | /// let a = async { 1 }; |
136 | /// let b = async { 2 }; |
137 | /// let pair = embassy_futures::join::join(a, b).await; |
138 | /// |
139 | /// assert_eq!(pair, (1, 2)); |
140 | /// # }); |
141 | /// ``` |
142 | pub fn join<Fut1, Fut2>(future1: Fut1, future2: Fut2) -> Join<Fut1, Fut2> |
143 | where |
144 | Fut1: Future, |
145 | Fut2: Future, |
146 | { |
147 | Join::new(Fut1:future1, Fut2:future2) |
148 | } |
149 | |
150 | /// Joins the result of three futures, waiting for them all to complete. |
151 | /// |
152 | /// This function will return a new future which awaits all futures to |
153 | /// complete. The returned future will finish with a tuple of all results. |
154 | /// |
155 | /// Note that this function consumes the passed futures and returns a |
156 | /// wrapped version of it. |
157 | /// |
158 | /// # Examples |
159 | /// |
160 | /// ``` |
161 | /// # embassy_futures::block_on(async { |
162 | /// |
163 | /// let a = async { 1 }; |
164 | /// let b = async { 2 }; |
165 | /// let c = async { 3 }; |
166 | /// let res = embassy_futures::join::join3(a, b, c).await; |
167 | /// |
168 | /// assert_eq!(res, (1, 2, 3)); |
169 | /// # }); |
170 | /// ``` |
171 | pub fn join3<Fut1, Fut2, Fut3>(future1: Fut1, future2: Fut2, future3: Fut3) -> Join3<Fut1, Fut2, Fut3> |
172 | where |
173 | Fut1: Future, |
174 | Fut2: Future, |
175 | Fut3: Future, |
176 | { |
177 | Join3::new(Fut1:future1, Fut2:future2, Fut3:future3) |
178 | } |
179 | |
180 | /// Joins the result of four futures, waiting for them all to complete. |
181 | /// |
182 | /// This function will return a new future which awaits all futures to |
183 | /// complete. The returned future will finish with a tuple of all results. |
184 | /// |
185 | /// Note that this function consumes the passed futures and returns a |
186 | /// wrapped version of it. |
187 | /// |
188 | /// # Examples |
189 | /// |
190 | /// ``` |
191 | /// # embassy_futures::block_on(async { |
192 | /// |
193 | /// let a = async { 1 }; |
194 | /// let b = async { 2 }; |
195 | /// let c = async { 3 }; |
196 | /// let d = async { 4 }; |
197 | /// let res = embassy_futures::join::join4(a, b, c, d).await; |
198 | /// |
199 | /// assert_eq!(res, (1, 2, 3, 4)); |
200 | /// # }); |
201 | /// ``` |
202 | pub fn join4<Fut1, Fut2, Fut3, Fut4>( |
203 | future1: Fut1, |
204 | future2: Fut2, |
205 | future3: Fut3, |
206 | future4: Fut4, |
207 | ) -> Join4<Fut1, Fut2, Fut3, Fut4> |
208 | where |
209 | Fut1: Future, |
210 | Fut2: Future, |
211 | Fut3: Future, |
212 | Fut4: Future, |
213 | { |
214 | Join4::new(Fut1:future1, Fut2:future2, Fut3:future3, Fut4:future4) |
215 | } |
216 | |
217 | /// Joins the result of five futures, waiting for them all to complete. |
218 | /// |
219 | /// This function will return a new future which awaits all futures to |
220 | /// complete. The returned future will finish with a tuple of all results. |
221 | /// |
222 | /// Note that this function consumes the passed futures and returns a |
223 | /// wrapped version of it. |
224 | /// |
225 | /// # Examples |
226 | /// |
227 | /// ``` |
228 | /// # embassy_futures::block_on(async { |
229 | /// |
230 | /// let a = async { 1 }; |
231 | /// let b = async { 2 }; |
232 | /// let c = async { 3 }; |
233 | /// let d = async { 4 }; |
234 | /// let e = async { 5 }; |
235 | /// let res = embassy_futures::join::join5(a, b, c, d, e).await; |
236 | /// |
237 | /// assert_eq!(res, (1, 2, 3, 4, 5)); |
238 | /// # }); |
239 | /// ``` |
240 | pub fn join5<Fut1, Fut2, Fut3, Fut4, Fut5>( |
241 | future1: Fut1, |
242 | future2: Fut2, |
243 | future3: Fut3, |
244 | future4: Fut4, |
245 | future5: Fut5, |
246 | ) -> Join5<Fut1, Fut2, Fut3, Fut4, Fut5> |
247 | where |
248 | Fut1: Future, |
249 | Fut2: Future, |
250 | Fut3: Future, |
251 | Fut4: Future, |
252 | Fut5: Future, |
253 | { |
254 | Join5::new(Fut1:future1, Fut2:future2, Fut3:future3, Fut4:future4, Fut5:future5) |
255 | } |
256 | |
257 | // ===================================================== |
258 | |
259 | /// Future for the [`join_array`] function. |
260 | #[must_use = "futures do nothing unless you `.await` or poll them" ] |
261 | pub struct JoinArray<Fut: Future, const N: usize> { |
262 | futures: [MaybeDone<Fut>; N], |
263 | } |
264 | |
265 | impl<Fut: Future, const N: usize> fmt::Debug for JoinArray<Fut, N> |
266 | where |
267 | Fut: Future + fmt::Debug, |
268 | Fut::Output: fmt::Debug, |
269 | { |
270 | fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { |
271 | f.debug_struct("JoinArray" ).field(name:"futures" , &self.futures).finish() |
272 | } |
273 | } |
274 | |
275 | impl<Fut: Future, const N: usize> Future for JoinArray<Fut, N> { |
276 | type Output = [Fut::Output; N]; |
277 | fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { |
278 | let this: &mut JoinArray = unsafe { self.get_unchecked_mut() }; |
279 | let mut all_done: bool = true; |
280 | for f: &mut MaybeDone in this.futures.iter_mut() { |
281 | all_done &= unsafe { Pin::new_unchecked(pointer:f) }.poll(cx); |
282 | } |
283 | |
284 | if all_done { |
285 | let mut array: [MaybeUninit<Fut::Output>; N] = unsafe { MaybeUninit::uninit().assume_init() }; |
286 | for i: usize in 0..N { |
287 | array[i].write(val:this.futures[i].take_output()); |
288 | } |
289 | Poll::Ready(unsafe { (&array as *const _ as *const [Fut::Output; N]).read() }) |
290 | } else { |
291 | Poll::Pending |
292 | } |
293 | } |
294 | } |
295 | |
296 | /// Joins the result of an array of futures, waiting for them all to complete. |
297 | /// |
298 | /// This function will return a new future which awaits all futures to |
299 | /// complete. The returned future will finish with a tuple of all results. |
300 | /// |
301 | /// Note that this function consumes the passed futures and returns a |
302 | /// wrapped version of it. |
303 | /// |
304 | /// # Examples |
305 | /// |
306 | /// ``` |
307 | /// # embassy_futures::block_on(async { |
308 | /// |
309 | /// async fn foo(n: u32) -> u32 { n } |
310 | /// let a = foo(1); |
311 | /// let b = foo(2); |
312 | /// let c = foo(3); |
313 | /// let res = embassy_futures::join::join_array([a, b, c]).await; |
314 | /// |
315 | /// assert_eq!(res, [1, 2, 3]); |
316 | /// # }); |
317 | /// ``` |
318 | pub fn join_array<Fut: Future, const N: usize>(futures: [Fut; N]) -> JoinArray<Fut, N> { |
319 | JoinArray { |
320 | futures: futures.map(MaybeDone::Future), |
321 | } |
322 | } |
323 | |