1use crate::bound::{has_bound, InferredBound, Supertraits};
2use crate::lifetime::{AddLifetimeToImplTrait, CollectLifetimes};
3use crate::parse::Item;
4use crate::receiver::{has_self_in_block, has_self_in_sig, mut_pat, ReplaceSelf};
5use crate::verbatim::VerbatimFn;
6use proc_macro2::{Span, TokenStream};
7use quote::{format_ident, quote, quote_spanned, ToTokens};
8use std::collections::BTreeSet as Set;
9use std::mem;
10use syn::punctuated::Punctuated;
11use syn::visit_mut::{self, VisitMut};
12use syn::{
13 parse_quote, parse_quote_spanned, Attribute, Block, FnArg, GenericArgument, GenericParam,
14 Generics, Ident, ImplItem, Lifetime, LifetimeParam, Pat, PatIdent, PathArguments, Receiver,
15 ReturnType, Signature, Token, TraitItem, Type, TypePath, WhereClause,
16};
17
18impl ToTokens for Item {
19 fn to_tokens(&self, tokens: &mut TokenStream) {
20 match self {
21 Item::Trait(item: &ItemTrait) => item.to_tokens(tokens),
22 Item::Impl(item: &ItemImpl) => item.to_tokens(tokens),
23 }
24 }
25}
26
27#[derive(Clone, Copy)]
28enum Context<'a> {
29 Trait {
30 generics: &'a Generics,
31 supertraits: &'a Supertraits,
32 },
33 Impl {
34 impl_generics: &'a Generics,
35 associated_type_impl_traits: &'a Set<Ident>,
36 },
37}
38
39impl Context<'_> {
40 fn lifetimes<'a>(&'a self, used: &'a [Lifetime]) -> impl Iterator<Item = &'a LifetimeParam> {
41 let generics: &&Generics = match self {
42 Context::Trait { generics: &&Generics, .. } => generics,
43 Context::Impl { impl_generics: &&Generics, .. } => impl_generics,
44 };
45 generics.params.iter().filter_map(move |param: &GenericParam| {
46 if let GenericParam::Lifetime(param: &LifetimeParam) = param {
47 if used.contains(&param.lifetime) {
48 return Some(param);
49 }
50 }
51 None
52 })
53 }
54}
55
56pub fn expand(input: &mut Item, is_local: bool) {
57 match input {
58 Item::Trait(input) => {
59 let context = Context::Trait {
60 generics: &input.generics,
61 supertraits: &input.supertraits,
62 };
63 for inner in &mut input.items {
64 if let TraitItem::Fn(method) = inner {
65 let sig = &mut method.sig;
66 if sig.asyncness.is_some() {
67 let block = &mut method.default;
68 let mut has_self = has_self_in_sig(sig);
69 method.attrs.push(parse_quote!(#[must_use]));
70 if let Some(block) = block {
71 has_self |= has_self_in_block(block);
72 transform_block(context, sig, block);
73 method.attrs.push(lint_suppress_with_body());
74 } else {
75 method.attrs.push(lint_suppress_without_body());
76 }
77 let has_default = method.default.is_some();
78 transform_sig(context, sig, has_self, has_default, is_local);
79 }
80 }
81 }
82 }
83 Item::Impl(input) => {
84 let mut associated_type_impl_traits = Set::new();
85 for inner in &input.items {
86 if let ImplItem::Type(assoc) = inner {
87 if let Type::ImplTrait(_) = assoc.ty {
88 associated_type_impl_traits.insert(assoc.ident.clone());
89 }
90 }
91 }
92
93 let context = Context::Impl {
94 impl_generics: &input.generics,
95 associated_type_impl_traits: &associated_type_impl_traits,
96 };
97 for inner in &mut input.items {
98 match inner {
99 ImplItem::Fn(method) if method.sig.asyncness.is_some() => {
100 let sig = &mut method.sig;
101 let block = &mut method.block;
102 let has_self = has_self_in_sig(sig) || has_self_in_block(block);
103 transform_block(context, sig, block);
104 transform_sig(context, sig, has_self, false, is_local);
105 method.attrs.push(lint_suppress_with_body());
106 }
107 ImplItem::Verbatim(tokens) => {
108 let mut method = match syn::parse2::<VerbatimFn>(tokens.clone()) {
109 Ok(method) if method.sig.asyncness.is_some() => method,
110 _ => continue,
111 };
112 let sig = &mut method.sig;
113 let has_self = has_self_in_sig(sig);
114 transform_sig(context, sig, has_self, false, is_local);
115 method.attrs.push(lint_suppress_with_body());
116 *tokens = quote!(#method);
117 }
118 _ => {}
119 }
120 }
121 }
122 }
123}
124
125fn lint_suppress_with_body() -> Attribute {
126 parse_quote! {
127 #[allow(
128 clippy::async_yields_async,
129 clippy::diverging_sub_expression,
130 clippy::let_unit_value,
131 clippy::no_effect_underscore_binding,
132 clippy::shadow_same,
133 clippy::type_complexity,
134 clippy::type_repetition_in_bounds,
135 clippy::used_underscore_binding
136 )]
137 }
138}
139
140fn lint_suppress_without_body() -> Attribute {
141 parse_quote! {
142 #[allow(
143 clippy::type_complexity,
144 clippy::type_repetition_in_bounds
145 )]
146 }
147}
148
149// Input:
150// async fn f<T>(&self, x: &T) -> Ret;
151//
152// Output:
153// fn f<'life0, 'life1, 'async_trait, T>(
154// &'life0 self,
155// x: &'life1 T,
156// ) -> Pin<Box<dyn Future<Output = Ret> + Send + 'async_trait>>
157// where
158// 'life0: 'async_trait,
159// 'life1: 'async_trait,
160// T: 'async_trait,
161// Self: Sync + 'async_trait;
162fn transform_sig(
163 context: Context,
164 sig: &mut Signature,
165 has_self: bool,
166 has_default: bool,
167 is_local: bool,
168) {
169 sig.fn_token.span = sig.asyncness.take().unwrap().span;
170
171 let (ret_arrow, ret) = match &sig.output {
172 ReturnType::Default => (Token![->](Span::call_site()), quote!(())),
173 ReturnType::Type(arrow, ret) => (*arrow, quote!(#ret)),
174 };
175
176 let mut lifetimes = CollectLifetimes::new();
177 for arg in &mut sig.inputs {
178 match arg {
179 FnArg::Receiver(arg) => lifetimes.visit_receiver_mut(arg),
180 FnArg::Typed(arg) => lifetimes.visit_type_mut(&mut arg.ty),
181 }
182 }
183
184 for param in &mut sig.generics.params {
185 match param {
186 GenericParam::Type(param) => {
187 let param_name = &param.ident;
188 let span = match param.colon_token.take() {
189 Some(colon_token) => colon_token.span,
190 None => param_name.span(),
191 };
192 let bounds = mem::replace(&mut param.bounds, Punctuated::new());
193 where_clause_or_default(&mut sig.generics.where_clause)
194 .predicates
195 .push(parse_quote_spanned!(span=> #param_name: 'async_trait + #bounds));
196 }
197 GenericParam::Lifetime(param) => {
198 let param_name = &param.lifetime;
199 let span = match param.colon_token.take() {
200 Some(colon_token) => colon_token.span,
201 None => param_name.span(),
202 };
203 let bounds = mem::replace(&mut param.bounds, Punctuated::new());
204 where_clause_or_default(&mut sig.generics.where_clause)
205 .predicates
206 .push(parse_quote_spanned!(span=> #param: 'async_trait + #bounds));
207 }
208 GenericParam::Const(_) => {}
209 }
210 }
211
212 for param in context.lifetimes(&lifetimes.explicit) {
213 let param = &param.lifetime;
214 let span = param.span();
215 where_clause_or_default(&mut sig.generics.where_clause)
216 .predicates
217 .push(parse_quote_spanned!(span=> #param: 'async_trait));
218 }
219
220 if sig.generics.lt_token.is_none() {
221 sig.generics.lt_token = Some(Token![<](sig.ident.span()));
222 }
223 if sig.generics.gt_token.is_none() {
224 sig.generics.gt_token = Some(Token![>](sig.paren_token.span.join()));
225 }
226
227 for elided in lifetimes.elided {
228 sig.generics.params.push(parse_quote!(#elided));
229 where_clause_or_default(&mut sig.generics.where_clause)
230 .predicates
231 .push(parse_quote_spanned!(elided.span()=> #elided: 'async_trait));
232 }
233
234 sig.generics.params.push(parse_quote!('async_trait));
235
236 if has_self {
237 let bounds: &[InferredBound] = if is_local {
238 &[]
239 } else if let Some(receiver) = sig.receiver() {
240 match receiver.ty.as_ref() {
241 // self: &Self
242 Type::Reference(ty) if ty.mutability.is_none() => &[InferredBound::Sync],
243 // self: Arc<Self>
244 Type::Path(ty)
245 if {
246 let segment = ty.path.segments.last().unwrap();
247 segment.ident == "Arc"
248 && match &segment.arguments {
249 PathArguments::AngleBracketed(arguments) => {
250 arguments.args.len() == 1
251 && match &arguments.args[0] {
252 GenericArgument::Type(Type::Path(arg)) => {
253 arg.path.is_ident("Self")
254 }
255 _ => false,
256 }
257 }
258 _ => false,
259 }
260 } =>
261 {
262 &[InferredBound::Sync, InferredBound::Send]
263 }
264 _ => &[InferredBound::Send],
265 }
266 } else {
267 &[InferredBound::Send]
268 };
269
270 let bounds = bounds.iter().filter(|bound| match context {
271 Context::Trait { supertraits, .. } => has_default && !has_bound(supertraits, bound),
272 Context::Impl { .. } => false,
273 });
274
275 where_clause_or_default(&mut sig.generics.where_clause)
276 .predicates
277 .push(parse_quote! {
278 Self: #(#bounds +)* 'async_trait
279 });
280 }
281
282 for (i, arg) in sig.inputs.iter_mut().enumerate() {
283 match arg {
284 FnArg::Receiver(receiver) => {
285 if receiver.reference.is_none() {
286 receiver.mutability = None;
287 }
288 }
289 FnArg::Typed(arg) => {
290 if match *arg.ty {
291 Type::Reference(_) => false,
292 _ => true,
293 } {
294 if let Pat::Ident(pat) = &mut *arg.pat {
295 pat.by_ref = None;
296 pat.mutability = None;
297 } else {
298 let positional = positional_arg(i, &arg.pat);
299 let m = mut_pat(&mut arg.pat);
300 arg.pat = parse_quote!(#m #positional);
301 }
302 }
303 AddLifetimeToImplTrait.visit_type_mut(&mut arg.ty);
304 }
305 }
306 }
307
308 let bounds = if is_local {
309 quote!('async_trait)
310 } else {
311 quote!(::core::marker::Send + 'async_trait)
312 };
313 sig.output = parse_quote! {
314 #ret_arrow ::core::pin::Pin<Box<
315 dyn ::core::future::Future<Output = #ret> + #bounds
316 >>
317 };
318}
319
320// Input:
321// async fn f<T>(&self, x: &T, (a, b): (A, B)) -> Ret {
322// self + x + a + b
323// }
324//
325// Output:
326// Box::pin(async move {
327// let ___ret: Ret = {
328// let __self = self;
329// let x = x;
330// let (a, b) = __arg1;
331//
332// __self + x + a + b
333// };
334//
335// ___ret
336// })
337fn transform_block(context: Context, sig: &mut Signature, block: &mut Block) {
338 let mut self_span = None;
339 let decls = sig
340 .inputs
341 .iter()
342 .enumerate()
343 .map(|(i, arg)| match arg {
344 FnArg::Receiver(Receiver {
345 self_token,
346 mutability,
347 ..
348 }) => {
349 let ident = Ident::new("__self", self_token.span);
350 self_span = Some(self_token.span);
351 quote!(let #mutability #ident = #self_token;)
352 }
353 FnArg::Typed(arg) => {
354 // If there is a #[cfg(...)] attribute that selectively enables
355 // the parameter, forward it to the variable.
356 //
357 // This is currently not applied to the `self` parameter.
358 let attrs = arg.attrs.iter().filter(|attr| attr.path().is_ident("cfg"));
359
360 if let Type::Reference(_) = *arg.ty {
361 quote!()
362 } else if let Pat::Ident(PatIdent {
363 ident, mutability, ..
364 }) = &*arg.pat
365 {
366 quote! {
367 #(#attrs)*
368 let #mutability #ident = #ident;
369 }
370 } else {
371 let pat = &arg.pat;
372 let ident = positional_arg(i, pat);
373 if let Pat::Wild(_) = **pat {
374 quote! {
375 #(#attrs)*
376 let #ident = #ident;
377 }
378 } else {
379 quote! {
380 #(#attrs)*
381 let #pat = {
382 let #ident = #ident;
383 #ident
384 };
385 }
386 }
387 }
388 }
389 })
390 .collect::<Vec<_>>();
391
392 if let Some(span) = self_span {
393 let mut replace_self = ReplaceSelf(span);
394 replace_self.visit_block_mut(block);
395 }
396
397 let stmts = &block.stmts;
398 let let_ret = match &mut sig.output {
399 ReturnType::Default => quote_spanned! {block.brace_token.span=>
400 #(#decls)*
401 let () = { #(#stmts)* };
402 },
403 ReturnType::Type(_, ret) => {
404 if contains_associated_type_impl_trait(context, ret) {
405 if decls.is_empty() {
406 quote!(#(#stmts)*)
407 } else {
408 quote!(#(#decls)* { #(#stmts)* })
409 }
410 } else {
411 quote! {
412 if let ::core::option::Option::Some(__ret) = ::core::option::Option::None::<#ret> {
413 return __ret;
414 }
415 #(#decls)*
416 let __ret: #ret = { #(#stmts)* };
417 #[allow(unreachable_code)]
418 __ret
419 }
420 }
421 }
422 };
423 let box_pin = quote_spanned!(block.brace_token.span=>
424 Box::pin(async move { #let_ret })
425 );
426 block.stmts = parse_quote!(#box_pin);
427}
428
429fn positional_arg(i: usize, pat: &Pat) -> Ident {
430 let span: Span = syn::spanned::Spanned::span(self:pat);
431 #[cfg(not(no_span_mixed_site))]
432 let span: Span = span.resolved_at(Span::mixed_site());
433 format_ident!("__arg{}", i, span = span)
434}
435
436fn contains_associated_type_impl_trait(context: Context, ret: &mut Type) -> bool {
437 struct AssociatedTypeImplTraits<'a> {
438 set: &'a Set<Ident>,
439 contains: bool,
440 }
441
442 impl<'a> VisitMut for AssociatedTypeImplTraits<'a> {
443 fn visit_type_path_mut(&mut self, ty: &mut TypePath) {
444 if ty.qself.is_none()
445 && ty.path.segments.len() == 2
446 && ty.path.segments[0].ident == "Self"
447 && self.set.contains(&ty.path.segments[1].ident)
448 {
449 self.contains = true;
450 }
451 visit_mut::visit_type_path_mut(self, ty);
452 }
453 }
454
455 match context {
456 Context::Trait { .. } => false,
457 Context::Impl {
458 associated_type_impl_traits,
459 ..
460 } => {
461 let mut visit = AssociatedTypeImplTraits {
462 set: associated_type_impl_traits,
463 contains: false,
464 };
465 visit.visit_type_mut(ret);
466 visit.contains
467 }
468 }
469}
470
471fn where_clause_or_default(clause: &mut Option<WhereClause>) -> &mut WhereClause {
472 clause.get_or_insert_with(|| WhereClause {
473 where_token: Default::default(),
474 predicates: Punctuated::new(),
475 })
476}
477