//! The futures-rs `join!` and `try_join!` macro implementations.
2
3use proc_macro::TokenStream;
4use proc_macro2::{Span, TokenStream as TokenStream2};
5use quote::{format_ident, quote};
6use syn::parse::{Parse, ParseStream};
7use syn::{Expr, Ident, Token};
8
9#[derive(Default)]
10struct Join {
11 fut_exprs: Vec<Expr>,
12}
13
14impl Parse for Join {
15 fn parse(input: ParseStream<'_>) -> syn::Result<Self> {
16 let mut join = Self::default();
17
18 while !input.is_empty() {
19 join.fut_exprs.push(input.parse::<Expr>()?);
20
21 if !input.is_empty() {
22 input.parse::<Token![,]>()?;
23 }
24 }
25
26 Ok(join)
27 }
28}
29
30fn bind_futures(fut_exprs: Vec<Expr>, span: Span) -> (Vec<TokenStream2>, Vec<Ident>) {
31 let mut future_let_bindings = Vec::with_capacity(fut_exprs.len());
32 let future_names: Vec<_> = fut_exprs
33 .into_iter()
34 .enumerate()
35 .map(|(i, expr)| {
36 let name = format_ident!("_fut{}", i, span = span);
37 future_let_bindings.push(quote! {
38 // Move future into a local so that it is pinned in one place and
39 // is no longer accessible by the end user.
40 let mut #name = __futures_crate::future::maybe_done(#expr);
41 let mut #name = unsafe { __futures_crate::Pin::new_unchecked(&mut #name) };
42 });
43 name
44 })
45 .collect();
46
47 (future_let_bindings, future_names)
48}
49
50/// The `join!` macro.
51pub(crate) fn join(input: TokenStream) -> TokenStream {
52 let parsed = syn::parse_macro_input!(input as Join);
53
54 // should be def_site, but that's unstable
55 let span = Span::call_site();
56
57 let (future_let_bindings, future_names) = bind_futures(parsed.fut_exprs, span);
58
59 let poll_futures = future_names.iter().map(|fut| {
60 quote! {
61 __all_done &= __futures_crate::future::Future::poll(
62 #fut.as_mut(), __cx).is_ready();
63 }
64 });
65 let take_outputs = future_names.iter().map(|fut| {
66 quote! {
67 #fut.as_mut().take_output().unwrap(),
68 }
69 });
70
71 TokenStream::from(quote! { {
72 #( #future_let_bindings )*
73
74 __futures_crate::future::poll_fn(move |__cx: &mut __futures_crate::task::Context<'_>| {
75 let mut __all_done = true;
76 #( #poll_futures )*
77 if __all_done {
78 __futures_crate::task::Poll::Ready((
79 #( #take_outputs )*
80 ))
81 } else {
82 __futures_crate::task::Poll::Pending
83 }
84 }).await
85 } })
86}
87
88/// The `try_join!` macro.
89pub(crate) fn try_join(input: TokenStream) -> TokenStream {
90 let parsed = syn::parse_macro_input!(input as Join);
91
92 // should be def_site, but that's unstable
93 let span = Span::call_site();
94
95 let (future_let_bindings, future_names) = bind_futures(parsed.fut_exprs, span);
96
97 let poll_futures = future_names.iter().map(|fut| {
98 quote! {
99 if __futures_crate::future::Future::poll(
100 #fut.as_mut(), __cx).is_pending()
101 {
102 __all_done = false;
103 } else if #fut.as_mut().output_mut().unwrap().is_err() {
104 // `.err().unwrap()` rather than `.unwrap_err()` so that we don't introduce
105 // a `T: Debug` bound.
106 // Also, for an error type of ! any code after `err().unwrap()` is unreachable.
107 #[allow(unreachable_code)]
108 return __futures_crate::task::Poll::Ready(
109 __futures_crate::Err(
110 #fut.as_mut().take_output().unwrap().err().unwrap()
111 )
112 );
113 }
114 }
115 });
116 let take_outputs = future_names.iter().map(|fut| {
117 quote! {
118 // `.ok().unwrap()` rather than `.unwrap()` so that we don't introduce
119 // an `E: Debug` bound.
120 // Also, for an ok type of ! any code after `ok().unwrap()` is unreachable.
121 #[allow(unreachable_code)]
122 #fut.as_mut().take_output().unwrap().ok().unwrap(),
123 }
124 });
125
126 TokenStream::from(quote! { {
127 #( #future_let_bindings )*
128
129 #[allow(clippy::diverging_sub_expression)]
130 __futures_crate::future::poll_fn(move |__cx: &mut __futures_crate::task::Context<'_>| {
131 let mut __all_done = true;
132 #( #poll_futures )*
133 if __all_done {
134 __futures_crate::task::Poll::Ready(
135 __futures_crate::Ok((
136 #( #take_outputs )*
137 ))
138 )
139 } else {
140 __futures_crate::task::Poll::Pending
141 }
142 }).await
143 } })
144}
145