1 | use proc_macro::{TokenStream, TokenTree}; |
2 | use proc_macro2::Span; |
3 | use quote::quote; |
4 | use syn::{parse::Parser, Ident}; |
5 | |
6 | pub(crate) fn declare_output_enum(input: TokenStream) -> TokenStream { |
7 | // passed in is: `(_ _ _)` with one `_` per branch |
8 | let branches = match input.into_iter().next() { |
9 | Some(TokenTree::Group(group)) => group.stream().into_iter().count(), |
10 | _ => panic!("unexpected macro input" ), |
11 | }; |
12 | |
13 | let variants = (0..branches) |
14 | .map(|num| Ident::new(&format!("_ {}" , num), Span::call_site())) |
15 | .collect::<Vec<_>>(); |
16 | |
17 | // Use a bitfield to track which futures completed |
18 | let mask = Ident::new( |
19 | if branches <= 8 { |
20 | "u8" |
21 | } else if branches <= 16 { |
22 | "u16" |
23 | } else if branches <= 32 { |
24 | "u32" |
25 | } else if branches <= 64 { |
26 | "u64" |
27 | } else { |
28 | panic!("up to 64 branches supported" ); |
29 | }, |
30 | Span::call_site(), |
31 | ); |
32 | |
33 | TokenStream::from(quote! { |
34 | pub(super) enum Out<#( #variants ),*> { |
35 | #( #variants(#variants), )* |
36 | // Include a `Disabled` variant signifying that all select branches |
37 | // failed to resolve. |
38 | Disabled, |
39 | } |
40 | |
41 | pub(super) type Mask = #mask; |
42 | }) |
43 | } |
44 | |
45 | pub(crate) fn clean_pattern_macro(input: TokenStream) -> TokenStream { |
46 | // If this isn't a pattern, we return the token stream as-is. The select! |
47 | // macro is using it in a location requiring a pattern, so an error will be |
48 | // emitted there. |
49 | let mut input: syn::Pat = match syn::Pat::parse_single.parse(tokens:input.clone()) { |
50 | Ok(it: Pat) => it, |
51 | Err(_) => return input, |
52 | }; |
53 | |
54 | clean_pattern(&mut input); |
55 | quote::ToTokens::into_token_stream(self:input).into() |
56 | } |
57 | |
58 | // Removes any occurrences of ref or mut in the provided pattern. |
59 | fn clean_pattern(pat: &mut syn::Pat) { |
60 | match pat { |
61 | syn::Pat::Lit(_literal) => {} |
62 | syn::Pat::Macro(_macro) => {} |
63 | syn::Pat::Path(_path) => {} |
64 | syn::Pat::Range(_range) => {} |
65 | syn::Pat::Rest(_rest) => {} |
66 | syn::Pat::Verbatim(_tokens) => {} |
67 | syn::Pat::Wild(_underscore) => {} |
68 | syn::Pat::Ident(ident) => { |
69 | ident.by_ref = None; |
70 | ident.mutability = None; |
71 | if let Some((_at, pat)) = &mut ident.subpat { |
72 | clean_pattern(&mut *pat); |
73 | } |
74 | } |
75 | syn::Pat::Or(or) => { |
76 | for case in &mut or.cases { |
77 | clean_pattern(case); |
78 | } |
79 | } |
80 | syn::Pat::Slice(slice) => { |
81 | for elem in &mut slice.elems { |
82 | clean_pattern(elem); |
83 | } |
84 | } |
85 | syn::Pat::Struct(struct_pat) => { |
86 | for field in &mut struct_pat.fields { |
87 | clean_pattern(&mut field.pat); |
88 | } |
89 | } |
90 | syn::Pat::Tuple(tuple) => { |
91 | for elem in &mut tuple.elems { |
92 | clean_pattern(elem); |
93 | } |
94 | } |
95 | syn::Pat::TupleStruct(tuple) => { |
96 | for elem in &mut tuple.elems { |
97 | clean_pattern(elem); |
98 | } |
99 | } |
100 | syn::Pat::Reference(reference) => { |
101 | reference.mutability = None; |
102 | clean_pattern(&mut reference.pat); |
103 | } |
104 | syn::Pat::Type(type_pat) => { |
105 | clean_pattern(&mut type_pat.pat); |
106 | } |
107 | _ => {} |
108 | } |
109 | } |
110 | |