1 | use proc_macro2::{Span, TokenStream};
|
2 | use quote::{quote, ToTokens};
|
3 | use syn::{
|
4 | parse_quote, punctuated::Punctuated, visit_mut::VisitMut, Block, Lifetime, Receiver,
|
5 | ReturnType, Signature, TypeReference, WhereClause,
|
6 | };
|
7 |
|
8 | use crate::parse::{AsyncItem, RecursionArgs};
|
9 |
|
10 | impl ToTokens for AsyncItem {
|
11 | fn to_tokens(&self, tokens: &mut TokenStream) {
|
12 | self.0.to_tokens(tokens);
|
13 | }
|
14 | }
|
15 |
|
16 | pub fn expand(item: &mut AsyncItem, args: &RecursionArgs) {
|
17 | item.0.attrs.push(parse_quote!(#[must_use]));
|
18 | transform_sig(&mut item.0.sig, args);
|
19 | transform_block(&mut item.0.block);
|
20 | }
|
21 |
|
22 | fn transform_block(block: &mut Block) {
|
23 | let brace: Brace = block.brace_token;
|
24 | *block = parse_quote!({
|
25 | Box::pin(async move #block)
|
26 | });
|
27 | block.brace_token = brace;
|
28 | }
|
29 |
|
30 | enum ArgLifetime {
|
31 | New(Lifetime),
|
32 | Existing(Lifetime),
|
33 | }
|
34 |
|
35 | impl ArgLifetime {
|
36 | pub fn lifetime(self) -> Lifetime {
|
37 | match self {
|
38 | ArgLifetime::New(lt: Lifetime) | ArgLifetime::Existing(lt: Lifetime) => lt,
|
39 | }
|
40 | }
|
41 | }
|
42 |
|
43 | #[derive (Default)]
|
44 | struct ReferenceVisitor {
|
45 | counter: usize,
|
46 | lifetimes: Vec<ArgLifetime>,
|
47 | self_receiver: bool,
|
48 | self_receiver_new_lifetime: bool,
|
49 | self_lifetime: Option<Lifetime>,
|
50 | }
|
51 |
|
52 | impl VisitMut for ReferenceVisitor {
|
53 | fn visit_receiver_mut(&mut self, receiver: &mut Receiver) {
|
54 | self.self_lifetime = Some(if let Some((_, lt)) = &mut receiver.reference {
|
55 | self.self_receiver = true;
|
56 |
|
57 | if let Some(lt) = lt {
|
58 | lt.clone()
|
59 | } else {
|
60 | // Use 'life_self to avoid collisions with 'life<count> lifetimes.
|
61 | let new_lifetime: Lifetime = parse_quote!('life_self);
|
62 | lt.replace(new_lifetime.clone());
|
63 |
|
64 | self.self_receiver_new_lifetime = true;
|
65 |
|
66 | new_lifetime
|
67 | }
|
68 | } else {
|
69 | return;
|
70 | });
|
71 | }
|
72 |
|
73 | fn visit_type_reference_mut(&mut self, argument: &mut TypeReference) {
|
74 | if argument.lifetime.is_none() {
|
75 | // If this reference doesn't have a lifetime (e.g. &T), then give it one.
|
76 | let lt = Lifetime::new(&format!("'life {}" , self.counter), Span::call_site());
|
77 | self.lifetimes.push(ArgLifetime::New(parse_quote!(#lt)));
|
78 | argument.lifetime = Some(lt);
|
79 | self.counter += 1;
|
80 | } else {
|
81 | // If it does (e.g. &'life T), then keep track of it.
|
82 | let lt = argument.lifetime.as_ref().cloned().unwrap();
|
83 |
|
84 | // Check that this lifetime isn't already in our vector
|
85 | let ident_matches = |x: &ArgLifetime| {
|
86 | if let ArgLifetime::Existing(elt) = x {
|
87 | elt.ident == lt.ident
|
88 | } else {
|
89 | false
|
90 | }
|
91 | };
|
92 |
|
93 | if !self.lifetimes.iter().any(ident_matches) {
|
94 | self.lifetimes.push(ArgLifetime::Existing(lt));
|
95 | }
|
96 | }
|
97 | }
|
98 | }
|
99 |
|
100 | // Input:
|
101 | // async fn f<S, T>(x : S, y : &T) -> Ret;
|
102 | //
|
103 | // Output:
|
104 | // fn f<S, T>(x : S, y : &T) -> Pin<Box<dyn Future<Output = Ret> + Send>
|
105 | fn transform_sig(sig: &mut Signature, args: &RecursionArgs) {
|
106 | // Determine the original return type
|
107 | let ret = match &sig.output {
|
108 | ReturnType::Default => quote!(()),
|
109 | ReturnType::Type(_, ret) => quote!(#ret),
|
110 | };
|
111 |
|
112 | // Remove the asyncness of this function
|
113 | sig.asyncness = None;
|
114 |
|
115 | // Find and update any references in the input arguments
|
116 | let mut v = ReferenceVisitor::default();
|
117 | for input in &mut sig.inputs {
|
118 | v.visit_fn_arg_mut(input);
|
119 | }
|
120 |
|
121 | // Does this expansion require `async_recursion to be added to the output?
|
122 | let mut requires_lifetime = false;
|
123 | let mut where_clause_lifetimes = vec![];
|
124 | let mut where_clause_generics = vec![];
|
125 |
|
126 | // 'async_recursion lifetime
|
127 | let asr: Lifetime = parse_quote!('async_recursion);
|
128 |
|
129 | // Add an S : 'async_recursion bound to any generic parameter
|
130 | for param in sig.generics.type_params() {
|
131 | let ident = param.ident.clone();
|
132 | where_clause_generics.push(ident);
|
133 | requires_lifetime = true;
|
134 | }
|
135 |
|
136 | // Add an 'a : 'async_recursion bound to any lifetimes 'a appearing in the function
|
137 | if !v.lifetimes.is_empty() {
|
138 | requires_lifetime = true;
|
139 | for alt in v.lifetimes {
|
140 | if let ArgLifetime::New(lt) = &alt {
|
141 | // If this is a new argument,
|
142 | sig.generics.params.push(parse_quote!(#lt));
|
143 | }
|
144 |
|
145 | // Add a bound to the where clause
|
146 | let lt = alt.lifetime();
|
147 | where_clause_lifetimes.push(lt);
|
148 | }
|
149 | }
|
150 |
|
151 | // If our function accepts &self, then we modify this to the explicit lifetime &'life_self,
|
152 | // and add the bound &'life_self : 'async_recursion
|
153 | if v.self_receiver {
|
154 | if v.self_receiver_new_lifetime {
|
155 | sig.generics.params.push(parse_quote!('life_self));
|
156 | }
|
157 | where_clause_lifetimes.extend(v.self_lifetime);
|
158 | requires_lifetime = true;
|
159 | }
|
160 |
|
161 | let box_lifetime: TokenStream = if requires_lifetime {
|
162 | // Add 'async_recursion to our generic parameters
|
163 | sig.generics.params.push(parse_quote!('async_recursion));
|
164 |
|
165 | quote!(+ #asr)
|
166 | } else {
|
167 | quote!()
|
168 | };
|
169 |
|
170 | let send_bound: TokenStream = if args.send_bound {
|
171 | quote!(+ ::core::marker::Send)
|
172 | } else {
|
173 | quote!()
|
174 | };
|
175 |
|
176 | let sync_bound: TokenStream = if args.sync_bound {
|
177 | quote!(+ ::core::marker::Sync)
|
178 | } else {
|
179 | quote!()
|
180 | };
|
181 |
|
182 | let where_clause = sig
|
183 | .generics
|
184 | .where_clause
|
185 | .get_or_insert_with(|| WhereClause {
|
186 | where_token: Default::default(),
|
187 | predicates: Punctuated::new(),
|
188 | });
|
189 |
|
190 | // Add our S : 'async_recursion bounds
|
191 | for generic_ident in where_clause_generics {
|
192 | where_clause
|
193 | .predicates
|
194 | .push(parse_quote!(#generic_ident : #asr));
|
195 | }
|
196 |
|
197 | // Add our 'a : 'async_recursion bounds
|
198 | for lifetime in where_clause_lifetimes {
|
199 | where_clause.predicates.push(parse_quote!(#lifetime : #asr));
|
200 | }
|
201 |
|
202 | // Modify the return type
|
203 | sig.output = parse_quote! {
|
204 | -> ::core::pin::Pin<Box<
|
205 | dyn ::core::future::Future<Output = #ret> #box_lifetime #send_bound #sync_bound>>
|
206 | };
|
207 | }
|
208 | |