1 | use crate::bound::{has_bound, InferredBound, Supertraits}; |
2 | use crate::lifetime::{AddLifetimeToImplTrait, CollectLifetimes}; |
3 | use crate::parse::Item; |
4 | use crate::receiver::{has_self_in_block, has_self_in_sig, mut_pat, ReplaceSelf}; |
5 | use crate::verbatim::VerbatimFn; |
6 | use proc_macro2::{Span, TokenStream}; |
7 | use quote::{format_ident, quote, quote_spanned, ToTokens}; |
8 | use std::collections::BTreeSet as Set; |
9 | use std::mem; |
10 | use syn::punctuated::Punctuated; |
11 | use syn::visit_mut::{self, VisitMut}; |
12 | use syn::{ |
13 | parse_quote, parse_quote_spanned, Attribute, Block, FnArg, GenericArgument, GenericParam, |
14 | Generics, Ident, ImplItem, Lifetime, LifetimeParam, Pat, PatIdent, PathArguments, Receiver, |
15 | ReturnType, Signature, Token, TraitItem, Type, TypeInfer, TypePath, WhereClause, |
16 | }; |
17 | |
18 | impl ToTokens for Item { |
19 | fn to_tokens(&self, tokens: &mut TokenStream) { |
20 | match self { |
21 | Item::Trait(item: &ItemTrait) => item.to_tokens(tokens), |
22 | Item::Impl(item: &ItemImpl) => item.to_tokens(tokens), |
23 | } |
24 | } |
25 | } |
26 | |
27 | #[derive (Clone, Copy)] |
28 | enum Context<'a> { |
29 | Trait { |
30 | generics: &'a Generics, |
31 | supertraits: &'a Supertraits, |
32 | }, |
33 | Impl { |
34 | impl_generics: &'a Generics, |
35 | associated_type_impl_traits: &'a Set<Ident>, |
36 | }, |
37 | } |
38 | |
39 | impl Context<'_> { |
40 | fn lifetimes<'a>(&'a self, used: &'a [Lifetime]) -> impl Iterator<Item = &'a LifetimeParam> { |
41 | let generics: &&Generics = match self { |
42 | Context::Trait { generics: &&Generics, .. } => generics, |
43 | Context::Impl { impl_generics: &&Generics, .. } => impl_generics, |
44 | }; |
45 | generics.params.iter().filter_map(move |param: &GenericParam| { |
46 | if let GenericParam::Lifetime(param: &LifetimeParam) = param { |
47 | if used.contains(¶m.lifetime) { |
48 | return Some(param); |
49 | } |
50 | } |
51 | None |
52 | }) |
53 | } |
54 | } |
55 | |
/// Rewrites every `async fn` found in a trait definition or an impl block
/// into a plain `fn` whose body/return type is produced by `transform_block`
/// and `transform_sig`.
///
/// `is_local` is forwarded unchanged to `transform_sig`, where it controls
/// whether `Send` bounds are emitted.
pub fn expand(input: &mut Item, is_local: bool) {
    match input {
        Item::Trait(input) => {
            let context = Context::Trait {
                generics: &input.generics,
                supertraits: &input.supertraits,
            };
            for inner in &mut input.items {
                if let TraitItem::Fn(method) = inner {
                    let sig = &mut method.sig;
                    if sig.asyncness.is_some() {
                        let block = &mut method.default;
                        let mut has_self = has_self_in_sig(sig);
                        // The rewritten method returns a future; presumably
                        // marked #[must_use] because an unpolled future does nothing.
                        method.attrs.push(parse_quote!(#[must_use]));
                        if let Some(block) = block {
                            // A default body is also inspected for `Self` usage
                            // (has_self_in_block is defined elsewhere).
                            has_self |= has_self_in_block(block);
                            // Must run before transform_sig: transform_block reads
                            // sig.output, which transform_sig overwrites.
                            transform_block(context, sig, block);
                            method.attrs.push(lint_suppress_with_body());
                        } else {
                            method.attrs.push(lint_suppress_without_body());
                        }
                        let has_default = method.default.is_some();
                        transform_sig(context, sig, has_self, has_default, is_local);
                    }
                }
            }
        }
        Item::Impl(input) => {
            // Collect associated types written as `type X = impl Trait;`.
            let mut associated_type_impl_traits = Set::new();
            for inner in &input.items {
                if let ImplItem::Type(assoc) = inner {
                    if let Type::ImplTrait(_) = assoc.ty {
                        associated_type_impl_traits.insert(assoc.ident.clone());
                    }
                }
            }

            let context = Context::Impl {
                impl_generics: &input.generics,
                associated_type_impl_traits: &associated_type_impl_traits,
            };
            for inner in &mut input.items {
                match inner {
                    ImplItem::Fn(method) if method.sig.asyncness.is_some() => {
                        let sig = &mut method.sig;
                        let block = &mut method.block;
                        let has_self = has_self_in_sig(sig);
                        // Body first: it reads sig.output before transform_sig replaces it.
                        transform_block(context, sig, block);
                        transform_sig(context, sig, has_self, false, is_local);
                        method.attrs.push(lint_suppress_with_body());
                    }
                    ImplItem::Verbatim(tokens) => {
                        // Raw token items are re-parsed as a VerbatimFn; only async
                        // ones are touched, and only their signature is rewritten
                        // (no transform_block here).
                        let mut method = match syn::parse2::<VerbatimFn>(tokens.clone()) {
                            Ok(method) if method.sig.asyncness.is_some() => method,
                            _ => continue,
                        };
                        let sig = &mut method.sig;
                        let has_self = has_self_in_sig(sig);
                        transform_sig(context, sig, has_self, false, is_local);
                        method.attrs.push(lint_suppress_with_body());
                        *tokens = quote!(#method);
                    }
                    _ => {}
                }
            }
        }
    }
}
124 | |
125 | fn lint_suppress_with_body() -> Attribute { |
126 | parse_quote! { |
127 | #[allow( |
128 | elided_named_lifetimes, |
129 | clippy::async_yields_async, |
130 | clippy::diverging_sub_expression, |
131 | clippy::let_unit_value, |
132 | clippy::needless_arbitrary_self_type, |
133 | clippy::no_effect_underscore_binding, |
134 | clippy::shadow_same, |
135 | clippy::type_complexity, |
136 | clippy::type_repetition_in_bounds, |
137 | clippy::used_underscore_binding |
138 | )] |
139 | } |
140 | } |
141 | |
142 | fn lint_suppress_without_body() -> Attribute { |
143 | parse_quote! { |
144 | #[allow( |
145 | elided_named_lifetimes, |
146 | clippy::type_complexity, |
147 | clippy::type_repetition_in_bounds |
148 | )] |
149 | } |
150 | } |
151 | |
152 | // Input: |
153 | // async fn f<T>(&self, x: &T) -> Ret; |
154 | // |
155 | // Output: |
156 | // fn f<'life0, 'life1, 'async_trait, T>( |
157 | // &'life0 self, |
158 | // x: &'life1 T, |
159 | // ) -> Pin<Box<dyn Future<Output = Ret> + Send + 'async_trait>> |
160 | // where |
161 | // 'life0: 'async_trait, |
162 | // 'life1: 'async_trait, |
163 | // T: 'async_trait, |
164 | // Self: Sync + 'async_trait; |
165 | fn transform_sig( |
166 | context: Context, |
167 | sig: &mut Signature, |
168 | has_self: bool, |
169 | has_default: bool, |
170 | is_local: bool, |
171 | ) { |
172 | sig.fn_token.span = sig.asyncness.take().unwrap().span; |
173 | |
174 | let (ret_arrow, ret) = match &sig.output { |
175 | ReturnType::Default => (quote!(->), quote!(())), |
176 | ReturnType::Type(arrow, ret) => (quote!(#arrow), quote!(#ret)), |
177 | }; |
178 | |
179 | let mut lifetimes = CollectLifetimes::new(); |
180 | for arg in &mut sig.inputs { |
181 | match arg { |
182 | FnArg::Receiver(arg) => lifetimes.visit_receiver_mut(arg), |
183 | FnArg::Typed(arg) => lifetimes.visit_type_mut(&mut arg.ty), |
184 | } |
185 | } |
186 | |
187 | for param in &mut sig.generics.params { |
188 | match param { |
189 | GenericParam::Type(param) => { |
190 | let param_name = ¶m.ident; |
191 | let span = match param.colon_token.take() { |
192 | Some(colon_token) => colon_token.span, |
193 | None => param_name.span(), |
194 | }; |
195 | if param.attrs.is_empty() { |
196 | let bounds = mem::take(&mut param.bounds); |
197 | where_clause_or_default(&mut sig.generics.where_clause) |
198 | .predicates |
199 | .push(parse_quote_spanned!(span=> #param_name: 'async_trait + #bounds)); |
200 | } else { |
201 | param.bounds.push(parse_quote!('async_trait)); |
202 | } |
203 | } |
204 | GenericParam::Lifetime(param) => { |
205 | let param_name = ¶m.lifetime; |
206 | let span = match param.colon_token.take() { |
207 | Some(colon_token) => colon_token.span, |
208 | None => param_name.span(), |
209 | }; |
210 | if param.attrs.is_empty() { |
211 | let bounds = mem::take(&mut param.bounds); |
212 | where_clause_or_default(&mut sig.generics.where_clause) |
213 | .predicates |
214 | .push(parse_quote_spanned!(span=> #param: 'async_trait + #bounds)); |
215 | } else { |
216 | param.bounds.push(parse_quote!('async_trait)); |
217 | } |
218 | } |
219 | GenericParam::Const(_) => {} |
220 | } |
221 | } |
222 | |
223 | for param in context.lifetimes(&lifetimes.explicit) { |
224 | let param = ¶m.lifetime; |
225 | let span = param.span(); |
226 | where_clause_or_default(&mut sig.generics.where_clause) |
227 | .predicates |
228 | .push(parse_quote_spanned!(span=> #param: 'async_trait)); |
229 | } |
230 | |
231 | if sig.generics.lt_token.is_none() { |
232 | sig.generics.lt_token = Some(Token)); |
233 | } |
234 | if sig.generics.gt_token.is_none() { |
235 | sig.generics.gt_token = Some(Token)); |
236 | } |
237 | |
238 | for elided in lifetimes.elided { |
239 | sig.generics.params.push(parse_quote!(#elided)); |
240 | where_clause_or_default(&mut sig.generics.where_clause) |
241 | .predicates |
242 | .push(parse_quote_spanned!(elided.span()=> #elided: 'async_trait)); |
243 | } |
244 | |
245 | sig.generics.params.push(parse_quote!('async_trait)); |
246 | |
247 | if has_self { |
248 | let bounds: &[InferredBound] = if is_local { |
249 | &[] |
250 | } else if let Some(receiver) = sig.receiver() { |
251 | match receiver.ty.as_ref() { |
252 | // self: &Self |
253 | Type::Reference(ty) if ty.mutability.is_none() => &[InferredBound::Sync], |
254 | // self: Arc<Self> |
255 | Type::Path(ty) |
256 | if { |
257 | let segment = ty.path.segments.last().unwrap(); |
258 | segment.ident == "Arc" |
259 | && match &segment.arguments { |
260 | PathArguments::AngleBracketed(arguments) => { |
261 | arguments.args.len() == 1 |
262 | && match &arguments.args[0] { |
263 | GenericArgument::Type(Type::Path(arg)) => { |
264 | arg.path.is_ident("Self" ) |
265 | } |
266 | _ => false, |
267 | } |
268 | } |
269 | _ => false, |
270 | } |
271 | } => |
272 | { |
273 | &[InferredBound::Sync, InferredBound::Send] |
274 | } |
275 | _ => &[InferredBound::Send], |
276 | } |
277 | } else { |
278 | &[InferredBound::Send] |
279 | }; |
280 | |
281 | let bounds = bounds.iter().filter(|bound| match context { |
282 | Context::Trait { supertraits, .. } => has_default && !has_bound(supertraits, bound), |
283 | Context::Impl { .. } => false, |
284 | }); |
285 | |
286 | where_clause_or_default(&mut sig.generics.where_clause) |
287 | .predicates |
288 | .push(parse_quote! { |
289 | Self: #(#bounds +)* 'async_trait |
290 | }); |
291 | } |
292 | |
293 | for (i, arg) in sig.inputs.iter_mut().enumerate() { |
294 | match arg { |
295 | FnArg::Receiver(receiver) => { |
296 | if receiver.reference.is_none() { |
297 | receiver.mutability = None; |
298 | } |
299 | } |
300 | FnArg::Typed(arg) => { |
301 | if match *arg.ty { |
302 | Type::Reference(_) => false, |
303 | _ => true, |
304 | } { |
305 | if let Pat::Ident(pat) = &mut *arg.pat { |
306 | pat.by_ref = None; |
307 | pat.mutability = None; |
308 | } else { |
309 | let positional = positional_arg(i, &arg.pat); |
310 | let m = mut_pat(&mut arg.pat); |
311 | arg.pat = parse_quote!(#m #positional); |
312 | } |
313 | } |
314 | AddLifetimeToImplTrait.visit_type_mut(&mut arg.ty); |
315 | } |
316 | } |
317 | } |
318 | |
319 | let bounds = if is_local { |
320 | quote!('async_trait) |
321 | } else { |
322 | quote!(::core::marker::Send + 'async_trait) |
323 | }; |
324 | sig.output = parse_quote! { |
325 | #ret_arrow ::core::pin::Pin<Box< |
326 | dyn ::core::future::Future<Output = #ret> + #bounds |
327 | >> |
328 | }; |
329 | } |
330 | |
// Wraps the body of an `async fn` in `Box::pin(async move { ... })`, re-binding
// by-value parameters inside the async block. The signature itself is
// rewritten separately by transform_sig.
//
// Input:
//     async fn f<T>(&self, x: &T, (a, b): (A, B)) -> Ret {
//         self + x + a + b
//     }
//
// Output (simplified; `#[allow(unreachable_code)]` attributes omitted):
//     Box::pin(async move {
//         if let Option::Some(__ret) = Option::None::<Ret> {
//             return __ret;
//         }
//         let __self = self;
//         let (a, b) = {
//             let __arg2 = __arg2;
//             __arg2
//         };
//         let __ret: Ret = {
//             __self + x + a + b
//         };
//         __ret
//     })
//
// Reference-typed parameters such as `x: &T` are not re-bound. Positional
// names come from the parameter's index in `sig.inputs`, counting the
// receiver, hence `__arg2` above.
fn transform_block(context: Context, sig: &mut Signature, block: &mut Block) {
    // Set when the signature has a receiver; the body's `self` uses are then
    // rewritten by ReplaceSelf (defined elsewhere) to match `__self` below.
    let mut replace_self = false;
    // One `let` (possibly empty) per parameter, emitted at the top of the async
    // block; presumably so every by-value argument is moved into the future.
    let decls = sig
        .inputs
        .iter()
        .enumerate()
        .map(|(i, arg)| match arg {
            FnArg::Receiver(Receiver {
                self_token,
                mutability,
                ..
            }) => {
                replace_self = true;
                let ident = Ident::new("__self", self_token.span);
                quote!(let #mutability #ident = #self_token;)
            }
            FnArg::Typed(arg) => {
                // If there is a #[cfg(...)] attribute that selectively enables
                // the parameter, forward it to the variable.
                //
                // This is currently not applied to the `self` parameter.
                let attrs = arg.attrs.iter().filter(|attr| attr.path().is_ident("cfg"));

                if let Type::Reference(_) = *arg.ty {
                    // Reference parameters are used as-is.
                    quote!()
                } else if let Pat::Ident(PatIdent {
                    ident, mutability, ..
                }) = &*arg.pat
                {
                    // `name: T` / `mut name: T`: re-bind under the same name,
                    // restoring the `mut` that transform_sig strips.
                    quote! {
                        #(#attrs)*
                        let #mutability #ident = #ident;
                    }
                } else {
                    // Any other pattern: the signature calls this argument
                    // `__arg{i}`, so move it into a local and destructure that.
                    let pat = &arg.pat;
                    let ident = positional_arg(i, pat);
                    if let Pat::Wild(_) = **pat {
                        // `_` binds nothing; just move the positional argument.
                        quote! {
                            #(#attrs)*
                            let #ident = #ident;
                        }
                    } else {
                        quote! {
                            #(#attrs)*
                            let #pat = {
                                let #ident = #ident;
                                #ident
                            };
                        }
                    }
                }
            }
        })
        .collect::<Vec<_>>();

    if replace_self {
        ReplaceSelf.visit_block_mut(block);
    }

    let stmts = &block.stmts;
    let let_ret = match &mut sig.output {
        // No declared return type: `let () = ...` requires the body to yield `()`.
        ReturnType::Default => quote_spanned! {block.brace_token.span=>
            #(#decls)*
            let () = { #(#stmts)* };
        },
        ReturnType::Type(_, ret) => {
            if contains_associated_type_impl_trait(context, ret) {
                // The return type names `Self::X` where `X` is an impl-level
                // `impl Trait` associated type: emit the body without the
                // `__ret` type annotation.
                if decls.is_empty() {
                    quote!(#(#stmts)*)
                } else {
                    quote!(#(#decls)* { #(#stmts)* })
                }
            } else {
                // `impl Trait` is not allowed in a `let` type annotation, so it
                // is replaced with `_` for the annotations below.
                let mut ret = ret.clone();
                replace_impl_trait_with_infer(&mut ret);
                // The `if let` never matches (its scrutinee is `None`); its
                // `return` appears to exist only to tie the async block's output
                // type to the declared return type.
                quote! {
                    if let ::core::option::Option::Some(__ret) = ::core::option::Option::None::<#ret> {
                        #[allow(unreachable_code)]
                        return __ret;
                    }
                    #(#decls)*
                    let __ret: #ret = { #(#stmts)* };
                    #[allow(unreachable_code)]
                    __ret
                }
            }
        }
    };
    // Spanned on the original braces so errors in generated code point at the body.
    let box_pin = quote_spanned!(block.brace_token.span=>
        Box::pin(async move { #let_ret })
    );
    block.stmts = parse_quote!(#box_pin);
}
441 | |
442 | fn positional_arg(i: usize, pat: &Pat) -> Ident { |
443 | let span: Span = syn::spanned::Spanned::span(self:pat).resolved_at(Span::mixed_site()); |
444 | format_ident!("__arg {}" , i, span = span) |
445 | } |
446 | |
447 | fn contains_associated_type_impl_trait(context: Context, ret: &mut Type) -> bool { |
448 | struct AssociatedTypeImplTraits<'a> { |
449 | set: &'a Set<Ident>, |
450 | contains: bool, |
451 | } |
452 | |
453 | impl<'a> VisitMut for AssociatedTypeImplTraits<'a> { |
454 | fn visit_type_path_mut(&mut self, ty: &mut TypePath) { |
455 | if ty.qself.is_none() |
456 | && ty.path.segments.len() == 2 |
457 | && ty.path.segments[0].ident == "Self" |
458 | && self.set.contains(&ty.path.segments[1].ident) |
459 | { |
460 | self.contains = true; |
461 | } |
462 | visit_mut::visit_type_path_mut(self, ty); |
463 | } |
464 | } |
465 | |
466 | match context { |
467 | Context::Trait { .. } => false, |
468 | Context::Impl { |
469 | associated_type_impl_traits, |
470 | .. |
471 | } => { |
472 | let mut visit = AssociatedTypeImplTraits { |
473 | set: associated_type_impl_traits, |
474 | contains: false, |
475 | }; |
476 | visit.visit_type_mut(ret); |
477 | visit.contains |
478 | } |
479 | } |
480 | } |
481 | |
482 | fn where_clause_or_default(clause: &mut Option<WhereClause>) -> &mut WhereClause { |
483 | clause.get_or_insert_with(|| WhereClause { |
484 | where_token: Default::default(), |
485 | predicates: Punctuated::new(), |
486 | }) |
487 | } |
488 | |
489 | fn replace_impl_trait_with_infer(ty: &mut Type) { |
490 | struct ReplaceImplTraitWithInfer; |
491 | |
492 | impl VisitMut for ReplaceImplTraitWithInfer { |
493 | fn visit_type_mut(&mut self, ty: &mut Type) { |
494 | if let Type::ImplTrait(impl_trait: &mut TypeImplTrait) = ty { |
495 | *ty = Type::Infer(TypeInfer { |
496 | underscore_token: Token, |
497 | }); |
498 | } |
499 | visit_mut::visit_type_mut(self, node:ty); |
500 | } |
501 | } |
502 | |
503 | ReplaceImplTraitWithInfer.visit_type_mut(ty); |
504 | } |
505 | |