//! Wait for multiple futures to complete.
| 2 | |
| 3 | use core::future::Future; |
| 4 | use core::mem::MaybeUninit; |
| 5 | use core::pin::Pin; |
| 6 | use core::task::{Context, Poll}; |
| 7 | use core::{fmt, mem}; |
| 8 | |
/// Completion state of one future driven by a join combinator.
///
/// A slot starts as [`MaybeDone::Future`], becomes [`MaybeDone::Done`] once the
/// future has produced its output, and becomes [`MaybeDone::Gone`] after that
/// output has been moved out with [`take_output`](MaybeDone::take_output).
#[derive(Debug)]
enum MaybeDone<Fut: Future> {
    /// A not-yet-completed future
    Future(/* #[pin] */ Fut),
    /// The output of the completed future
    Done(Fut::Output),
    /// The empty variant after the result of a [`MaybeDone`] has been
    /// taken using the [`take_output`](MaybeDone::take_output) method.
    Gone,
}

impl<Fut: Future> MaybeDone<Fut> {
    /// Polls the inner future if it has not completed yet.
    ///
    /// Returns `true` when this slot needs no further polling: either the
    /// future completed during this call (its output is now stored as
    /// [`MaybeDone::Done`]) or it had already completed earlier (`Done` or
    /// `Gone`). Returns `false` while the inner future is still pending.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> bool {
        // SAFETY: `self` is never moved out of. The only pinned data is the
        // `Fut` inside `Self::Future`; it is re-pinned in place below and, once
        // it completes, dropped in place by overwriting `*this`.
        let this = unsafe { self.get_unchecked_mut() };
        match this {
            Self::Future(fut) => {
                // SAFETY: `fut` is stored inside the pinned `self` and is not
                // moved while pinned (pinning is structural for `Future`).
                match unsafe { Pin::new_unchecked(fut) }.poll(cx) {
                    Poll::Ready(res) => {
                        *this = Self::Done(res);
                        true
                    }
                    Poll::Pending => false,
                }
            }
            Self::Done(_) | Self::Gone => true,
        }
    }

    /// Moves the completed output out, leaving [`MaybeDone::Gone`] behind.
    ///
    /// # Panics
    ///
    /// Panics if the future has not completed yet (`Future`) or if its output
    /// was already taken (`Gone`). In that case `self` is left untouched.
    fn take_output(&mut self) -> Fut::Output {
        // Check before `mem::replace` so that a still-running future is not
        // dropped (replaced by `Gone`) on the way to the panic.
        if !matches!(self, Self::Done(_)) {
            panic!("take_output when MaybeDone is not done.");
        }
        match mem::replace(self, Self::Gone) {
            Self::Done(output) => output,
            Self::Future(_) | Self::Gone => unreachable!(),
        }
    }
}

// Only the inner `Fut` needs `Unpin`: `Fut::Output` is never pinned, since it
// is moved out by `take_output`.
impl<Fut: Future + Unpin> Unpin for MaybeDone<Fut> {}
| 48 | |
// For every `(Name, <Fut1, Fut2, ...>)` entry, expands into a public join
// future `Name` holding one `MaybeDone` slot per listed future, plus its
// `Debug` impl, a private `new` constructor and its `Future` impl.
macro_rules! generate {
    ($(
        $(#[$doc:meta])*
        ($Join:ident, <$($Fut:ident),*>),
    )*) => {$(
        $(#[$doc])*
        #[must_use = "futures do nothing unless you `.await` or poll them"]
        // Fields are named after the type parameters (`Fut1`, `Fut2`, ...),
        // which are not snake_case.
        #[allow(non_snake_case)]
        pub struct $Join<$($Fut),*>
        where
            $($Fut: Future,)*
        {
            $($Fut: MaybeDone<$Fut>,)*
        }

        impl<$($Fut),*> fmt::Debug for $Join<$($Fut),*>
        where
            $(
                $Fut: Future + fmt::Debug,
                $Fut::Output: fmt::Debug,
            )*
        {
            fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
                let mut builder = f.debug_struct(stringify!($Join));
                $(builder.field(stringify!($Fut), &self.$Fut);)*
                builder.finish()
            }
        }

        impl<$($Fut),*> $Join<$($Fut),*>
        where
            $($Fut: Future,)*
        {
            // Parameters reuse the type-parameter names, hence the allow.
            #[allow(non_snake_case)]
            fn new($($Fut: $Fut),*) -> Self {
                Self { $($Fut: MaybeDone::Future($Fut),)* }
            }
        }

        impl<$($Fut),*> Future for $Join<$($Fut),*>
        where
            $($Fut: Future,)*
        {
            type Output = ($($Fut::Output),*);

            fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
                // SAFETY: `me` is never moved out of; it is only used to re-pin
                // each `MaybeDone` field in place and, once all of them are
                // done, to take their (unpinned) outputs.
                let me = unsafe { self.get_unchecked_mut() };

                // `&=` does not short-circuit: every field is polled on each call.
                let mut finished = true;
                $(
                    // SAFETY: the field lives inside the pinned join future
                    // and is never moved while pinned.
                    finished &= unsafe { Pin::new_unchecked(&mut me.$Fut) }.poll(cx);
                )*

                if !finished {
                    return Poll::Pending;
                }
                Poll::Ready(($(me.$Fut.take_output()),*))
            }
        }
    )*};
}
| 107 | |
| 108 | generate! { |
| 109 | /// Future for the [`join`](join()) function. |
| 110 | (Join, <Fut1, Fut2>), |
| 111 | |
| 112 | /// Future for the [`join3`] function. |
| 113 | (Join3, <Fut1, Fut2, Fut3>), |
| 114 | |
| 115 | /// Future for the [`join4`] function. |
| 116 | (Join4, <Fut1, Fut2, Fut3, Fut4>), |
| 117 | |
| 118 | /// Future for the [`join5`] function. |
| 119 | (Join5, <Fut1, Fut2, Fut3, Fut4, Fut5>), |
| 120 | } |
| 121 | |
| 122 | /// Joins the result of two futures, waiting for them both to complete. |
| 123 | /// |
| 124 | /// This function will return a new future which awaits both futures to |
| 125 | /// complete. The returned future will finish with a tuple of both results. |
| 126 | /// |
| 127 | /// Note that this function consumes the passed futures and returns a |
| 128 | /// wrapped version of it. |
| 129 | /// |
| 130 | /// # Examples |
| 131 | /// |
| 132 | /// ``` |
| 133 | /// # embassy_futures::block_on(async { |
| 134 | /// |
| 135 | /// let a = async { 1 }; |
| 136 | /// let b = async { 2 }; |
| 137 | /// let pair = embassy_futures::join::join(a, b).await; |
| 138 | /// |
| 139 | /// assert_eq!(pair, (1, 2)); |
| 140 | /// # }); |
| 141 | /// ``` |
| 142 | pub fn join<Fut1, Fut2>(future1: Fut1, future2: Fut2) -> Join<Fut1, Fut2> |
| 143 | where |
| 144 | Fut1: Future, |
| 145 | Fut2: Future, |
| 146 | { |
| 147 | Join::new(Fut1:future1, Fut2:future2) |
| 148 | } |
| 149 | |
| 150 | /// Joins the result of three futures, waiting for them all to complete. |
| 151 | /// |
| 152 | /// This function will return a new future which awaits all futures to |
| 153 | /// complete. The returned future will finish with a tuple of all results. |
| 154 | /// |
| 155 | /// Note that this function consumes the passed futures and returns a |
| 156 | /// wrapped version of it. |
| 157 | /// |
| 158 | /// # Examples |
| 159 | /// |
| 160 | /// ``` |
| 161 | /// # embassy_futures::block_on(async { |
| 162 | /// |
| 163 | /// let a = async { 1 }; |
| 164 | /// let b = async { 2 }; |
| 165 | /// let c = async { 3 }; |
| 166 | /// let res = embassy_futures::join::join3(a, b, c).await; |
| 167 | /// |
| 168 | /// assert_eq!(res, (1, 2, 3)); |
| 169 | /// # }); |
| 170 | /// ``` |
| 171 | pub fn join3<Fut1, Fut2, Fut3>(future1: Fut1, future2: Fut2, future3: Fut3) -> Join3<Fut1, Fut2, Fut3> |
| 172 | where |
| 173 | Fut1: Future, |
| 174 | Fut2: Future, |
| 175 | Fut3: Future, |
| 176 | { |
| 177 | Join3::new(Fut1:future1, Fut2:future2, Fut3:future3) |
| 178 | } |
| 179 | |
| 180 | /// Joins the result of four futures, waiting for them all to complete. |
| 181 | /// |
| 182 | /// This function will return a new future which awaits all futures to |
| 183 | /// complete. The returned future will finish with a tuple of all results. |
| 184 | /// |
| 185 | /// Note that this function consumes the passed futures and returns a |
| 186 | /// wrapped version of it. |
| 187 | /// |
| 188 | /// # Examples |
| 189 | /// |
| 190 | /// ``` |
| 191 | /// # embassy_futures::block_on(async { |
| 192 | /// |
| 193 | /// let a = async { 1 }; |
| 194 | /// let b = async { 2 }; |
| 195 | /// let c = async { 3 }; |
| 196 | /// let d = async { 4 }; |
| 197 | /// let res = embassy_futures::join::join4(a, b, c, d).await; |
| 198 | /// |
| 199 | /// assert_eq!(res, (1, 2, 3, 4)); |
| 200 | /// # }); |
| 201 | /// ``` |
| 202 | pub fn join4<Fut1, Fut2, Fut3, Fut4>( |
| 203 | future1: Fut1, |
| 204 | future2: Fut2, |
| 205 | future3: Fut3, |
| 206 | future4: Fut4, |
| 207 | ) -> Join4<Fut1, Fut2, Fut3, Fut4> |
| 208 | where |
| 209 | Fut1: Future, |
| 210 | Fut2: Future, |
| 211 | Fut3: Future, |
| 212 | Fut4: Future, |
| 213 | { |
| 214 | Join4::new(Fut1:future1, Fut2:future2, Fut3:future3, Fut4:future4) |
| 215 | } |
| 216 | |
| 217 | /// Joins the result of five futures, waiting for them all to complete. |
| 218 | /// |
| 219 | /// This function will return a new future which awaits all futures to |
| 220 | /// complete. The returned future will finish with a tuple of all results. |
| 221 | /// |
| 222 | /// Note that this function consumes the passed futures and returns a |
| 223 | /// wrapped version of it. |
| 224 | /// |
| 225 | /// # Examples |
| 226 | /// |
| 227 | /// ``` |
| 228 | /// # embassy_futures::block_on(async { |
| 229 | /// |
| 230 | /// let a = async { 1 }; |
| 231 | /// let b = async { 2 }; |
| 232 | /// let c = async { 3 }; |
| 233 | /// let d = async { 4 }; |
| 234 | /// let e = async { 5 }; |
| 235 | /// let res = embassy_futures::join::join5(a, b, c, d, e).await; |
| 236 | /// |
| 237 | /// assert_eq!(res, (1, 2, 3, 4, 5)); |
| 238 | /// # }); |
| 239 | /// ``` |
| 240 | pub fn join5<Fut1, Fut2, Fut3, Fut4, Fut5>( |
| 241 | future1: Fut1, |
| 242 | future2: Fut2, |
| 243 | future3: Fut3, |
| 244 | future4: Fut4, |
| 245 | future5: Fut5, |
| 246 | ) -> Join5<Fut1, Fut2, Fut3, Fut4, Fut5> |
| 247 | where |
| 248 | Fut1: Future, |
| 249 | Fut2: Future, |
| 250 | Fut3: Future, |
| 251 | Fut4: Future, |
| 252 | Fut5: Future, |
| 253 | { |
| 254 | Join5::new(Fut1:future1, Fut2:future2, Fut3:future3, Fut4:future4, Fut5:future5) |
| 255 | } |
| 256 | |
// =====================================================
| 258 | |
| 259 | /// Future for the [`join_array`] function. |
| 260 | #[must_use = "futures do nothing unless you `.await` or poll them" ] |
| 261 | pub struct JoinArray<Fut: Future, const N: usize> { |
| 262 | futures: [MaybeDone<Fut>; N], |
| 263 | } |
| 264 | |
| 265 | impl<Fut: Future, const N: usize> fmt::Debug for JoinArray<Fut, N> |
| 266 | where |
| 267 | Fut: Future + fmt::Debug, |
| 268 | Fut::Output: fmt::Debug, |
| 269 | { |
| 270 | fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { |
| 271 | f.debug_struct("JoinArray" ).field(name:"futures" , &self.futures).finish() |
| 272 | } |
| 273 | } |
| 274 | |
| 275 | impl<Fut: Future, const N: usize> Future for JoinArray<Fut, N> { |
| 276 | type Output = [Fut::Output; N]; |
| 277 | fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { |
| 278 | let this: &mut JoinArray = unsafe { self.get_unchecked_mut() }; |
| 279 | let mut all_done: bool = true; |
| 280 | for f: &mut MaybeDone in this.futures.iter_mut() { |
| 281 | all_done &= unsafe { Pin::new_unchecked(pointer:f) }.poll(cx); |
| 282 | } |
| 283 | |
| 284 | if all_done { |
| 285 | let mut array: [MaybeUninit<Fut::Output>; N] = unsafe { MaybeUninit::uninit().assume_init() }; |
| 286 | for i: usize in 0..N { |
| 287 | array[i].write(val:this.futures[i].take_output()); |
| 288 | } |
| 289 | Poll::Ready(unsafe { (&array as *const _ as *const [Fut::Output; N]).read() }) |
| 290 | } else { |
| 291 | Poll::Pending |
| 292 | } |
| 293 | } |
| 294 | } |
| 295 | |
| 296 | /// Joins the result of an array of futures, waiting for them all to complete. |
| 297 | /// |
| 298 | /// This function will return a new future which awaits all futures to |
| 299 | /// complete. The returned future will finish with a tuple of all results. |
| 300 | /// |
| 301 | /// Note that this function consumes the passed futures and returns a |
| 302 | /// wrapped version of it. |
| 303 | /// |
| 304 | /// # Examples |
| 305 | /// |
| 306 | /// ``` |
| 307 | /// # embassy_futures::block_on(async { |
| 308 | /// |
| 309 | /// async fn foo(n: u32) -> u32 { n } |
| 310 | /// let a = foo(1); |
| 311 | /// let b = foo(2); |
| 312 | /// let c = foo(3); |
| 313 | /// let res = embassy_futures::join::join_array([a, b, c]).await; |
| 314 | /// |
| 315 | /// assert_eq!(res, [1, 2, 3]); |
| 316 | /// # }); |
| 317 | /// ``` |
| 318 | pub fn join_array<Fut: Future, const N: usize>(futures: [Fut; N]) -> JoinArray<Fut, N> { |
| 319 | JoinArray { |
| 320 | futures: futures.map(MaybeDone::Future), |
| 321 | } |
| 322 | } |
| 323 | |