1 | use crate::bound::{has_bound, InferredBound, Supertraits}; |
2 | use crate::lifetime::{AddLifetimeToImplTrait, CollectLifetimes}; |
3 | use crate::parse::Item; |
4 | use crate::receiver::{has_self_in_block, has_self_in_sig, mut_pat, ReplaceSelf}; |
5 | use crate::verbatim::VerbatimFn; |
6 | use proc_macro2::{Span, TokenStream}; |
7 | use quote::{format_ident, quote, quote_spanned, ToTokens}; |
8 | use std::collections::BTreeSet as Set; |
9 | use std::mem; |
10 | use syn::punctuated::Punctuated; |
11 | use syn::visit_mut::{self, VisitMut}; |
12 | use syn::{ |
13 | parse_quote, parse_quote_spanned, Attribute, Block, FnArg, GenericArgument, GenericParam, |
14 | Generics, Ident, ImplItem, Lifetime, LifetimeParam, Pat, PatIdent, PathArguments, Receiver, |
15 | ReturnType, Signature, Token, TraitItem, Type, TypePath, WhereClause, |
16 | }; |
17 | |
18 | impl ToTokens for Item { |
19 | fn to_tokens(&self, tokens: &mut TokenStream) { |
20 | match self { |
21 | Item::Trait(item: &ItemTrait) => item.to_tokens(tokens), |
22 | Item::Impl(item: &ItemImpl) => item.to_tokens(tokens), |
23 | } |
24 | } |
25 | } |
26 | |
27 | #[derive (Clone, Copy)] |
28 | enum Context<'a> { |
29 | Trait { |
30 | generics: &'a Generics, |
31 | supertraits: &'a Supertraits, |
32 | }, |
33 | Impl { |
34 | impl_generics: &'a Generics, |
35 | associated_type_impl_traits: &'a Set<Ident>, |
36 | }, |
37 | } |
38 | |
39 | impl Context<'_> { |
40 | fn lifetimes<'a>(&'a self, used: &'a [Lifetime]) -> impl Iterator<Item = &'a LifetimeParam> { |
41 | let generics: &&Generics = match self { |
42 | Context::Trait { generics: &&Generics, .. } => generics, |
43 | Context::Impl { impl_generics: &&Generics, .. } => impl_generics, |
44 | }; |
45 | generics.params.iter().filter_map(move |param: &GenericParam| { |
46 | if let GenericParam::Lifetime(param: &LifetimeParam) = param { |
47 | if used.contains(¶m.lifetime) { |
48 | return Some(param); |
49 | } |
50 | } |
51 | None |
52 | }) |
53 | } |
54 | } |
55 | |
56 | pub fn expand(input: &mut Item, is_local: bool) { |
57 | match input { |
58 | Item::Trait(input) => { |
59 | let context = Context::Trait { |
60 | generics: &input.generics, |
61 | supertraits: &input.supertraits, |
62 | }; |
63 | for inner in &mut input.items { |
64 | if let TraitItem::Fn(method) = inner { |
65 | let sig = &mut method.sig; |
66 | if sig.asyncness.is_some() { |
67 | let block = &mut method.default; |
68 | let mut has_self = has_self_in_sig(sig); |
69 | method.attrs.push(parse_quote!(#[must_use])); |
70 | if let Some(block) = block { |
71 | has_self |= has_self_in_block(block); |
72 | transform_block(context, sig, block); |
73 | method.attrs.push(lint_suppress_with_body()); |
74 | } else { |
75 | method.attrs.push(lint_suppress_without_body()); |
76 | } |
77 | let has_default = method.default.is_some(); |
78 | transform_sig(context, sig, has_self, has_default, is_local); |
79 | } |
80 | } |
81 | } |
82 | } |
83 | Item::Impl(input) => { |
84 | let mut associated_type_impl_traits = Set::new(); |
85 | for inner in &input.items { |
86 | if let ImplItem::Type(assoc) = inner { |
87 | if let Type::ImplTrait(_) = assoc.ty { |
88 | associated_type_impl_traits.insert(assoc.ident.clone()); |
89 | } |
90 | } |
91 | } |
92 | |
93 | let context = Context::Impl { |
94 | impl_generics: &input.generics, |
95 | associated_type_impl_traits: &associated_type_impl_traits, |
96 | }; |
97 | for inner in &mut input.items { |
98 | match inner { |
99 | ImplItem::Fn(method) if method.sig.asyncness.is_some() => { |
100 | let sig = &mut method.sig; |
101 | let block = &mut method.block; |
102 | let has_self = has_self_in_sig(sig) || has_self_in_block(block); |
103 | transform_block(context, sig, block); |
104 | transform_sig(context, sig, has_self, false, is_local); |
105 | method.attrs.push(lint_suppress_with_body()); |
106 | } |
107 | ImplItem::Verbatim(tokens) => { |
108 | let mut method = match syn::parse2::<VerbatimFn>(tokens.clone()) { |
109 | Ok(method) if method.sig.asyncness.is_some() => method, |
110 | _ => continue, |
111 | }; |
112 | let sig = &mut method.sig; |
113 | let has_self = has_self_in_sig(sig); |
114 | transform_sig(context, sig, has_self, false, is_local); |
115 | method.attrs.push(lint_suppress_with_body()); |
116 | *tokens = quote!(#method); |
117 | } |
118 | _ => {} |
119 | } |
120 | } |
121 | } |
122 | } |
123 | } |
124 | |
125 | fn lint_suppress_with_body() -> Attribute { |
126 | parse_quote! { |
127 | #[allow( |
128 | clippy::async_yields_async, |
129 | clippy::diverging_sub_expression, |
130 | clippy::let_unit_value, |
131 | clippy::no_effect_underscore_binding, |
132 | clippy::shadow_same, |
133 | clippy::type_complexity, |
134 | clippy::type_repetition_in_bounds, |
135 | clippy::used_underscore_binding |
136 | )] |
137 | } |
138 | } |
139 | |
140 | fn lint_suppress_without_body() -> Attribute { |
141 | parse_quote! { |
142 | #[allow( |
143 | clippy::type_complexity, |
144 | clippy::type_repetition_in_bounds |
145 | )] |
146 | } |
147 | } |
148 | |
149 | // Input: |
150 | // async fn f<T>(&self, x: &T) -> Ret; |
151 | // |
152 | // Output: |
153 | // fn f<'life0, 'life1, 'async_trait, T>( |
154 | // &'life0 self, |
155 | // x: &'life1 T, |
156 | // ) -> Pin<Box<dyn Future<Output = Ret> + Send + 'async_trait>> |
157 | // where |
158 | // 'life0: 'async_trait, |
159 | // 'life1: 'async_trait, |
160 | // T: 'async_trait, |
161 | // Self: Sync + 'async_trait; |
162 | fn transform_sig( |
163 | context: Context, |
164 | sig: &mut Signature, |
165 | has_self: bool, |
166 | has_default: bool, |
167 | is_local: bool, |
168 | ) { |
169 | sig.fn_token.span = sig.asyncness.take().unwrap().span; |
170 | |
171 | let (ret_arrow, ret) = match &sig.output { |
172 | ReturnType::Default => (Token![->](Span::call_site()), quote!(())), |
173 | ReturnType::Type(arrow, ret) => (*arrow, quote!(#ret)), |
174 | }; |
175 | |
176 | let mut lifetimes = CollectLifetimes::new(); |
177 | for arg in &mut sig.inputs { |
178 | match arg { |
179 | FnArg::Receiver(arg) => lifetimes.visit_receiver_mut(arg), |
180 | FnArg::Typed(arg) => lifetimes.visit_type_mut(&mut arg.ty), |
181 | } |
182 | } |
183 | |
184 | for param in &mut sig.generics.params { |
185 | match param { |
186 | GenericParam::Type(param) => { |
187 | let param_name = ¶m.ident; |
188 | let span = match param.colon_token.take() { |
189 | Some(colon_token) => colon_token.span, |
190 | None => param_name.span(), |
191 | }; |
192 | let bounds = mem::replace(&mut param.bounds, Punctuated::new()); |
193 | where_clause_or_default(&mut sig.generics.where_clause) |
194 | .predicates |
195 | .push(parse_quote_spanned!(span=> #param_name: 'async_trait + #bounds)); |
196 | } |
197 | GenericParam::Lifetime(param) => { |
198 | let param_name = ¶m.lifetime; |
199 | let span = match param.colon_token.take() { |
200 | Some(colon_token) => colon_token.span, |
201 | None => param_name.span(), |
202 | }; |
203 | let bounds = mem::replace(&mut param.bounds, Punctuated::new()); |
204 | where_clause_or_default(&mut sig.generics.where_clause) |
205 | .predicates |
206 | .push(parse_quote_spanned!(span=> #param: 'async_trait + #bounds)); |
207 | } |
208 | GenericParam::Const(_) => {} |
209 | } |
210 | } |
211 | |
212 | for param in context.lifetimes(&lifetimes.explicit) { |
213 | let param = ¶m.lifetime; |
214 | let span = param.span(); |
215 | where_clause_or_default(&mut sig.generics.where_clause) |
216 | .predicates |
217 | .push(parse_quote_spanned!(span=> #param: 'async_trait)); |
218 | } |
219 | |
220 | if sig.generics.lt_token.is_none() { |
221 | sig.generics.lt_token = Some(Token![<](sig.ident.span())); |
222 | } |
223 | if sig.generics.gt_token.is_none() { |
224 | sig.generics.gt_token = Some(Token![>](sig.paren_token.span.join())); |
225 | } |
226 | |
227 | for elided in lifetimes.elided { |
228 | sig.generics.params.push(parse_quote!(#elided)); |
229 | where_clause_or_default(&mut sig.generics.where_clause) |
230 | .predicates |
231 | .push(parse_quote_spanned!(elided.span()=> #elided: 'async_trait)); |
232 | } |
233 | |
234 | sig.generics.params.push(parse_quote!('async_trait)); |
235 | |
236 | if has_self { |
237 | let bounds: &[InferredBound] = if is_local { |
238 | &[] |
239 | } else if let Some(receiver) = sig.receiver() { |
240 | match receiver.ty.as_ref() { |
241 | // self: &Self |
242 | Type::Reference(ty) if ty.mutability.is_none() => &[InferredBound::Sync], |
243 | // self: Arc<Self> |
244 | Type::Path(ty) |
245 | if { |
246 | let segment = ty.path.segments.last().unwrap(); |
247 | segment.ident == "Arc" |
248 | && match &segment.arguments { |
249 | PathArguments::AngleBracketed(arguments) => { |
250 | arguments.args.len() == 1 |
251 | && match &arguments.args[0] { |
252 | GenericArgument::Type(Type::Path(arg)) => { |
253 | arg.path.is_ident("Self" ) |
254 | } |
255 | _ => false, |
256 | } |
257 | } |
258 | _ => false, |
259 | } |
260 | } => |
261 | { |
262 | &[InferredBound::Sync, InferredBound::Send] |
263 | } |
264 | _ => &[InferredBound::Send], |
265 | } |
266 | } else { |
267 | &[InferredBound::Send] |
268 | }; |
269 | |
270 | let bounds = bounds.iter().filter(|bound| match context { |
271 | Context::Trait { supertraits, .. } => has_default && !has_bound(supertraits, bound), |
272 | Context::Impl { .. } => false, |
273 | }); |
274 | |
275 | where_clause_or_default(&mut sig.generics.where_clause) |
276 | .predicates |
277 | .push(parse_quote! { |
278 | Self: #(#bounds +)* 'async_trait |
279 | }); |
280 | } |
281 | |
282 | for (i, arg) in sig.inputs.iter_mut().enumerate() { |
283 | match arg { |
284 | FnArg::Receiver(receiver) => { |
285 | if receiver.reference.is_none() { |
286 | receiver.mutability = None; |
287 | } |
288 | } |
289 | FnArg::Typed(arg) => { |
290 | if match *arg.ty { |
291 | Type::Reference(_) => false, |
292 | _ => true, |
293 | } { |
294 | if let Pat::Ident(pat) = &mut *arg.pat { |
295 | pat.by_ref = None; |
296 | pat.mutability = None; |
297 | } else { |
298 | let positional = positional_arg(i, &arg.pat); |
299 | let m = mut_pat(&mut arg.pat); |
300 | arg.pat = parse_quote!(#m #positional); |
301 | } |
302 | } |
303 | AddLifetimeToImplTrait.visit_type_mut(&mut arg.ty); |
304 | } |
305 | } |
306 | } |
307 | |
308 | let bounds = if is_local { |
309 | quote!('async_trait) |
310 | } else { |
311 | quote!(::core::marker::Send + 'async_trait) |
312 | }; |
313 | sig.output = parse_quote! { |
314 | #ret_arrow ::core::pin::Pin<Box< |
315 | dyn ::core::future::Future<Output = #ret> + #bounds |
316 | >> |
317 | }; |
318 | } |
319 | |
320 | // Input: |
321 | // async fn f<T>(&self, x: &T, (a, b): (A, B)) -> Ret { |
322 | // self + x + a + b |
323 | // } |
324 | // |
325 | // Output: |
326 | // Box::pin(async move { |
327 | // let ___ret: Ret = { |
328 | // let __self = self; |
329 | // let x = x; |
330 | // let (a, b) = __arg1; |
331 | // |
332 | // __self + x + a + b |
333 | // }; |
334 | // |
335 | // ___ret |
336 | // }) |
337 | fn transform_block(context: Context, sig: &mut Signature, block: &mut Block) { |
338 | let mut self_span = None; |
339 | let decls = sig |
340 | .inputs |
341 | .iter() |
342 | .enumerate() |
343 | .map(|(i, arg)| match arg { |
344 | FnArg::Receiver(Receiver { |
345 | self_token, |
346 | mutability, |
347 | .. |
348 | }) => { |
349 | let ident = Ident::new("__self" , self_token.span); |
350 | self_span = Some(self_token.span); |
351 | quote!(let #mutability #ident = #self_token;) |
352 | } |
353 | FnArg::Typed(arg) => { |
354 | // If there is a #[cfg(...)] attribute that selectively enables |
355 | // the parameter, forward it to the variable. |
356 | // |
357 | // This is currently not applied to the `self` parameter. |
358 | let attrs = arg.attrs.iter().filter(|attr| attr.path().is_ident("cfg" )); |
359 | |
360 | if let Type::Reference(_) = *arg.ty { |
361 | quote!() |
362 | } else if let Pat::Ident(PatIdent { |
363 | ident, mutability, .. |
364 | }) = &*arg.pat |
365 | { |
366 | quote! { |
367 | #(#attrs)* |
368 | let #mutability #ident = #ident; |
369 | } |
370 | } else { |
371 | let pat = &arg.pat; |
372 | let ident = positional_arg(i, pat); |
373 | if let Pat::Wild(_) = **pat { |
374 | quote! { |
375 | #(#attrs)* |
376 | let #ident = #ident; |
377 | } |
378 | } else { |
379 | quote! { |
380 | #(#attrs)* |
381 | let #pat = { |
382 | let #ident = #ident; |
383 | #ident |
384 | }; |
385 | } |
386 | } |
387 | } |
388 | } |
389 | }) |
390 | .collect::<Vec<_>>(); |
391 | |
392 | if let Some(span) = self_span { |
393 | let mut replace_self = ReplaceSelf(span); |
394 | replace_self.visit_block_mut(block); |
395 | } |
396 | |
397 | let stmts = &block.stmts; |
398 | let let_ret = match &mut sig.output { |
399 | ReturnType::Default => quote_spanned! {block.brace_token.span=> |
400 | #(#decls)* |
401 | let () = { #(#stmts)* }; |
402 | }, |
403 | ReturnType::Type(_, ret) => { |
404 | if contains_associated_type_impl_trait(context, ret) { |
405 | if decls.is_empty() { |
406 | quote!(#(#stmts)*) |
407 | } else { |
408 | quote!(#(#decls)* { #(#stmts)* }) |
409 | } |
410 | } else { |
411 | quote! { |
412 | if let ::core::option::Option::Some(__ret) = ::core::option::Option::None::<#ret> { |
413 | return __ret; |
414 | } |
415 | #(#decls)* |
416 | let __ret: #ret = { #(#stmts)* }; |
417 | #[allow(unreachable_code)] |
418 | __ret |
419 | } |
420 | } |
421 | } |
422 | }; |
423 | let box_pin = quote_spanned!(block.brace_token.span=> |
424 | Box::pin(async move { #let_ret }) |
425 | ); |
426 | block.stmts = parse_quote!(#box_pin); |
427 | } |
428 | |
429 | fn positional_arg(i: usize, pat: &Pat) -> Ident { |
430 | let span: Span = syn::spanned::Spanned::span(self:pat); |
431 | #[cfg (not(no_span_mixed_site))] |
432 | let span: Span = span.resolved_at(Span::mixed_site()); |
433 | format_ident!("__arg {}" , i, span = span) |
434 | } |
435 | |
436 | fn contains_associated_type_impl_trait(context: Context, ret: &mut Type) -> bool { |
437 | struct AssociatedTypeImplTraits<'a> { |
438 | set: &'a Set<Ident>, |
439 | contains: bool, |
440 | } |
441 | |
442 | impl<'a> VisitMut for AssociatedTypeImplTraits<'a> { |
443 | fn visit_type_path_mut(&mut self, ty: &mut TypePath) { |
444 | if ty.qself.is_none() |
445 | && ty.path.segments.len() == 2 |
446 | && ty.path.segments[0].ident == "Self" |
447 | && self.set.contains(&ty.path.segments[1].ident) |
448 | { |
449 | self.contains = true; |
450 | } |
451 | visit_mut::visit_type_path_mut(self, ty); |
452 | } |
453 | } |
454 | |
455 | match context { |
456 | Context::Trait { .. } => false, |
457 | Context::Impl { |
458 | associated_type_impl_traits, |
459 | .. |
460 | } => { |
461 | let mut visit = AssociatedTypeImplTraits { |
462 | set: associated_type_impl_traits, |
463 | contains: false, |
464 | }; |
465 | visit.visit_type_mut(ret); |
466 | visit.contains |
467 | } |
468 | } |
469 | } |
470 | |
471 | fn where_clause_or_default(clause: &mut Option<WhereClause>) -> &mut WhereClause { |
472 | clause.get_or_insert_with(|| WhereClause { |
473 | where_token: Default::default(), |
474 | predicates: Punctuated::new(), |
475 | }) |
476 | } |
477 | |