| 1 | use proc_macro::{TokenStream, TokenTree}; |
| 2 | use proc_macro2::Span; |
| 3 | use quote::quote; |
| 4 | use syn::{parse::Parser, Ident}; |
| 5 | |
| 6 | pub(crate) fn declare_output_enum(input: TokenStream) -> TokenStream { |
| 7 | // passed in is: `(_ _ _)` with one `_` per branch |
| 8 | let branches = match input.into_iter().next() { |
| 9 | Some(TokenTree::Group(group)) => group.stream().into_iter().count(), |
| 10 | _ => panic!("unexpected macro input" ), |
| 11 | }; |
| 12 | |
| 13 | let variants = (0..branches) |
| 14 | .map(|num| Ident::new(&format!("_ {num}" ), Span::call_site())) |
| 15 | .collect::<Vec<_>>(); |
| 16 | |
| 17 | // Use a bitfield to track which futures completed |
| 18 | let mask = Ident::new( |
| 19 | if branches <= 8 { |
| 20 | "u8" |
| 21 | } else if branches <= 16 { |
| 22 | "u16" |
| 23 | } else if branches <= 32 { |
| 24 | "u32" |
| 25 | } else if branches <= 64 { |
| 26 | "u64" |
| 27 | } else { |
| 28 | panic!("up to 64 branches supported" ); |
| 29 | }, |
| 30 | Span::call_site(), |
| 31 | ); |
| 32 | |
| 33 | TokenStream::from(quote! { |
| 34 | pub(super) enum Out<#( #variants ),*> { |
| 35 | #( #variants(#variants), )* |
| 36 | // Include a `Disabled` variant signifying that all select branches |
| 37 | // failed to resolve. |
| 38 | Disabled, |
| 39 | } |
| 40 | |
| 41 | pub(super) type Mask = #mask; |
| 42 | }) |
| 43 | } |
| 44 | |
| 45 | pub(crate) fn clean_pattern_macro(input: TokenStream) -> TokenStream { |
| 46 | // If this isn't a pattern, we return the token stream as-is. The select! |
| 47 | // macro is using it in a location requiring a pattern, so an error will be |
| 48 | // emitted there. |
| 49 | let mut input: syn::Pat = match syn::Pat::parse_single.parse(tokens:input.clone()) { |
| 50 | Ok(it: Pat) => it, |
| 51 | Err(_) => return input, |
| 52 | }; |
| 53 | |
| 54 | clean_pattern(&mut input); |
| 55 | quote::ToTokens::into_token_stream(self:input).into() |
| 56 | } |
| 57 | |
| 58 | // Removes any occurrences of ref or mut in the provided pattern. |
| 59 | fn clean_pattern(pat: &mut syn::Pat) { |
| 60 | match pat { |
| 61 | syn::Pat::Lit(_literal) => {} |
| 62 | syn::Pat::Macro(_macro) => {} |
| 63 | syn::Pat::Path(_path) => {} |
| 64 | syn::Pat::Range(_range) => {} |
| 65 | syn::Pat::Rest(_rest) => {} |
| 66 | syn::Pat::Verbatim(_tokens) => {} |
| 67 | syn::Pat::Wild(_underscore) => {} |
| 68 | syn::Pat::Ident(ident) => { |
| 69 | ident.by_ref = None; |
| 70 | ident.mutability = None; |
| 71 | if let Some((_at, pat)) = &mut ident.subpat { |
| 72 | clean_pattern(&mut *pat); |
| 73 | } |
| 74 | } |
| 75 | syn::Pat::Or(or) => { |
| 76 | for case in &mut or.cases { |
| 77 | clean_pattern(case); |
| 78 | } |
| 79 | } |
| 80 | syn::Pat::Slice(slice) => { |
| 81 | for elem in &mut slice.elems { |
| 82 | clean_pattern(elem); |
| 83 | } |
| 84 | } |
| 85 | syn::Pat::Struct(struct_pat) => { |
| 86 | for field in &mut struct_pat.fields { |
| 87 | clean_pattern(&mut field.pat); |
| 88 | } |
| 89 | } |
| 90 | syn::Pat::Tuple(tuple) => { |
| 91 | for elem in &mut tuple.elems { |
| 92 | clean_pattern(elem); |
| 93 | } |
| 94 | } |
| 95 | syn::Pat::TupleStruct(tuple) => { |
| 96 | for elem in &mut tuple.elems { |
| 97 | clean_pattern(elem); |
| 98 | } |
| 99 | } |
| 100 | syn::Pat::Reference(reference) => { |
| 101 | reference.mutability = None; |
| 102 | clean_pattern(&mut reference.pat); |
| 103 | } |
| 104 | syn::Pat::Type(type_pat) => { |
| 105 | clean_pattern(&mut type_pat.pat); |
| 106 | } |
| 107 | _ => {} |
| 108 | } |
| 109 | } |
| 110 | |