//! The futures-rs `select!` macro implementation.
2
3use proc_macro::TokenStream;
4use proc_macro2::Span;
5use quote::{format_ident, quote};
6use syn::parse::{Parse, ParseStream};
7use syn::{parse_quote, Expr, Ident, Pat, Token};
8
9mod kw {
10 syn::custom_keyword!(complete);
11}
12
13struct Select {
14 // span of `complete`, then expression after `=> ...`
15 complete: Option<Expr>,
16 default: Option<Expr>,
17 normal_fut_exprs: Vec<Expr>,
18 normal_fut_handlers: Vec<(Pat, Expr)>,
19}
20
21#[allow(clippy::large_enum_variant)]
22enum CaseKind {
23 Complete,
24 Default,
25 Normal(Pat, Expr),
26}
27
28impl Parse for Select {
29 fn parse(input: ParseStream<'_>) -> syn::Result<Self> {
30 let mut select = Self {
31 complete: None,
32 default: None,
33 normal_fut_exprs: vec![],
34 normal_fut_handlers: vec![],
35 };
36
37 while !input.is_empty() {
38 let case_kind = if input.peek(kw::complete) {
39 // `complete`
40 if select.complete.is_some() {
41 return Err(input.error("multiple `complete` cases found, only one allowed"));
42 }
43 input.parse::<kw::complete>()?;
44 CaseKind::Complete
45 } else if input.peek(Token![default]) {
46 // `default`
47 if select.default.is_some() {
48 return Err(input.error("multiple `default` cases found, only one allowed"));
49 }
50 input.parse::<Ident>()?;
51 CaseKind::Default
52 } else {
53 // `<pat> = <expr>`
54 let pat = Pat::parse_multi_with_leading_vert(input)?;
55 input.parse::<Token![=]>()?;
56 let expr = input.parse()?;
57 CaseKind::Normal(pat, expr)
58 };
59
60 // `=> <expr>`
61 input.parse::<Token![=>]>()?;
62 let expr = input.parse::<Expr>()?;
63
64 // Commas after the expression are only optional if it's a `Block`
65 // or it is the last branch in the `match`.
66 let is_block = match expr {
67 Expr::Block(_) => true,
68 _ => false,
69 };
70 if is_block || input.is_empty() {
71 input.parse::<Option<Token![,]>>()?;
72 } else {
73 input.parse::<Token![,]>()?;
74 }
75
76 match case_kind {
77 CaseKind::Complete => select.complete = Some(expr),
78 CaseKind::Default => select.default = Some(expr),
79 CaseKind::Normal(pat, fut_expr) => {
80 select.normal_fut_exprs.push(fut_expr);
81 select.normal_fut_handlers.push((pat, expr));
82 }
83 }
84 }
85
86 Ok(select)
87 }
88}
89
90// Enum over all the cases in which the `select!` waiting has completed and the result
91// can be processed.
92//
93// `enum __PrivResult<_1, _2, ...> { _1(_1), _2(_2), ..., Complete }`
94fn declare_result_enum(
95 result_ident: Ident,
96 variants: usize,
97 complete: bool,
98 span: Span,
99) -> (Vec<Ident>, syn::ItemEnum) {
100 // "_0", "_1", "_2"
101 let variant_names: Vec<Ident> =
102 (0..variants).map(|num| format_ident!("_{}", num, span = span)).collect();
103
104 let type_parameters = &variant_names;
105 let variants = &variant_names;
106
107 let complete_variant = if complete { Some(quote!(Complete)) } else { None };
108
109 let enum_item = parse_quote! {
110 enum #result_ident<#(#type_parameters,)*> {
111 #(
112 #variants(#type_parameters),
113 )*
114 #complete_variant
115 }
116 };
117
118 (variant_names, enum_item)
119}
120
121/// The `select!` macro.
122pub(crate) fn select(input: TokenStream) -> TokenStream {
123 select_inner(input, true)
124}
125
126/// The `select_biased!` macro.
127pub(crate) fn select_biased(input: TokenStream) -> TokenStream {
128 select_inner(input, false)
129}
130
131fn select_inner(input: TokenStream, random: bool) -> TokenStream {
132 let parsed = syn::parse_macro_input!(input as Select);
133
134 // should be def_site, but that's unstable
135 let span = Span::call_site();
136
137 let enum_ident = Ident::new("__PrivResult", span);
138
139 let (variant_names, enum_item) = declare_result_enum(
140 enum_ident.clone(),
141 parsed.normal_fut_exprs.len(),
142 parsed.complete.is_some(),
143 span,
144 );
145
146 // bind non-`Ident` future exprs w/ `let`
147 let mut future_let_bindings = Vec::with_capacity(parsed.normal_fut_exprs.len());
148 let bound_future_names: Vec<_> = parsed
149 .normal_fut_exprs
150 .into_iter()
151 .zip(variant_names.iter())
152 .map(|(expr, variant_name)| {
153 match expr {
154 syn::Expr::Path(path) => {
155 // Don't bind futures that are already a path.
156 // This prevents creating redundant stack space
157 // for them.
158 // Passing Futures by path requires those Futures to implement Unpin.
159 // We check for this condition here in order to be able to
160 // safely use Pin::new_unchecked(&mut #path) later on.
161 future_let_bindings.push(quote! {
162 __futures_crate::async_await::assert_fused_future(&#path);
163 __futures_crate::async_await::assert_unpin(&#path);
164 });
165 path
166 }
167 _ => {
168 // Bind and pin the resulting Future on the stack. This is
169 // necessary to support direct select! calls on !Unpin
170 // Futures. The Future is not explicitly pinned here with
171 // a Pin call, but assumed as pinned. The actual Pin is
172 // created inside the poll() function below to defer the
173 // creation of the temporary pointer, which would otherwise
174 // increase the size of the generated Future.
175 // Safety: This is safe since the lifetime of the Future
176 // is totally constraint to the lifetime of the select!
177 // expression, and the Future can't get moved inside it
178 // (it is shadowed).
179 future_let_bindings.push(quote! {
180 let mut #variant_name = #expr;
181 });
182 parse_quote! { #variant_name }
183 }
184 }
185 })
186 .collect();
187
188 // For each future, make an `&mut dyn FnMut(&mut Context<'_>) -> Option<Poll<__PrivResult<...>>`
189 // to use for polling that individual future. These will then be put in an array.
190 let poll_functions = bound_future_names.iter().zip(variant_names.iter()).map(
191 |(bound_future_name, variant_name)| {
192 // Below we lazily create the Pin on the Future below.
193 // This is done in order to avoid allocating memory in the generator
194 // for the Pin variable.
195 // Safety: This is safe because one of the following condition applies:
196 // 1. The Future is passed by the caller by name, and we assert that
197 // it implements Unpin.
198 // 2. The Future is created in scope of the select! function and will
199 // not be moved for the duration of it. It is thereby stack-pinned
200 quote! {
201 let mut #variant_name = |__cx: &mut __futures_crate::task::Context<'_>| {
202 let mut #bound_future_name = unsafe {
203 __futures_crate::Pin::new_unchecked(&mut #bound_future_name)
204 };
205 if __futures_crate::future::FusedFuture::is_terminated(&#bound_future_name) {
206 __futures_crate::None
207 } else {
208 __futures_crate::Some(__futures_crate::future::FutureExt::poll_unpin(
209 &mut #bound_future_name,
210 __cx,
211 ).map(#enum_ident::#variant_name))
212 }
213 };
214 let #variant_name: &mut dyn FnMut(
215 &mut __futures_crate::task::Context<'_>
216 ) -> __futures_crate::Option<__futures_crate::task::Poll<_>> = &mut #variant_name;
217 }
218 },
219 );
220
221 let none_polled = if parsed.complete.is_some() {
222 quote! {
223 __futures_crate::task::Poll::Ready(#enum_ident::Complete)
224 }
225 } else {
226 quote! {
227 panic!("all futures in select! were completed,\
228 but no `complete =>` handler was provided")
229 }
230 };
231
232 let branches = parsed.normal_fut_handlers.into_iter().zip(variant_names.iter()).map(
233 |((pat, expr), variant_name)| {
234 quote! {
235 #enum_ident::#variant_name(#pat) => { #expr },
236 }
237 },
238 );
239 let branches = quote! { #( #branches )* };
240
241 let complete_branch = parsed.complete.map(|complete_expr| {
242 quote! {
243 #enum_ident::Complete => { #complete_expr },
244 }
245 });
246
247 let branches = quote! {
248 #branches
249 #complete_branch
250 };
251
252 let await_select_fut = if parsed.default.is_some() {
253 // For select! with default this returns the Poll result
254 quote! {
255 __poll_fn(&mut __futures_crate::task::Context::from_waker(
256 __futures_crate::task::noop_waker_ref()
257 ))
258 }
259 } else {
260 quote! {
261 __futures_crate::future::poll_fn(__poll_fn).await
262 }
263 };
264
265 let execute_result_expr = if let Some(default_expr) = &parsed.default {
266 // For select! with default __select_result is a Poll, otherwise not
267 quote! {
268 match __select_result {
269 __futures_crate::task::Poll::Ready(result) => match result {
270 #branches
271 },
272 _ => #default_expr
273 }
274 }
275 } else {
276 quote! {
277 match __select_result {
278 #branches
279 }
280 }
281 };
282
283 let shuffle = if random {
284 quote! {
285 __futures_crate::async_await::shuffle(&mut __select_arr);
286 }
287 } else {
288 quote!()
289 };
290
291 TokenStream::from(quote! { {
292 #enum_item
293
294 let __select_result = {
295 #( #future_let_bindings )*
296
297 let mut __poll_fn = |__cx: &mut __futures_crate::task::Context<'_>| {
298 let mut __any_polled = false;
299
300 #( #poll_functions )*
301
302 let mut __select_arr = [#( #variant_names ),*];
303 #shuffle
304 for poller in &mut __select_arr {
305 let poller: &mut &mut dyn FnMut(
306 &mut __futures_crate::task::Context<'_>
307 ) -> __futures_crate::Option<__futures_crate::task::Poll<_>> = poller;
308 match poller(__cx) {
309 __futures_crate::Some(x @ __futures_crate::task::Poll::Ready(_)) =>
310 return x,
311 __futures_crate::Some(__futures_crate::task::Poll::Pending) => {
312 __any_polled = true;
313 }
314 __futures_crate::None => {}
315 }
316 }
317
318 if !__any_polled {
319 #none_polled
320 } else {
321 __futures_crate::task::Poll::Pending
322 }
323 };
324
325 #await_select_fut
326 };
327
328 #execute_result_expr
329 } })
330}
331