1 | use std::iter; |
2 | |
3 | use proc_macro2::TokenStream; |
4 | use quote::{quote, quote_spanned, ToTokens}; |
5 | use syn::visit_mut::VisitMut; |
6 | use syn::{ |
7 | punctuated::Punctuated, spanned::Spanned, Block, Expr, ExprAsync, ExprCall, FieldPat, FnArg, |
8 | Ident, Item, ItemFn, Pat, PatIdent, PatReference, PatStruct, PatTuple, PatTupleStruct, PatType, |
9 | Path, ReturnType, Signature, Stmt, Token, Type, TypePath, |
10 | }; |
11 | |
12 | use crate::{ |
13 | attr::{Field, Fields, FormatMode, InstrumentArgs, Level}, |
14 | MaybeItemFn, MaybeItemFnRef, |
15 | }; |
16 | |
17 | /// Given an existing function, generate an instrumented version of that function |
18 | pub(crate) fn gen_function<'a, B: ToTokens + 'a>( |
19 | input: MaybeItemFnRef<'a, B>, |
20 | args: InstrumentArgs, |
21 | instrumented_function_name: &str, |
22 | self_type: Option<&TypePath>, |
23 | ) -> proc_macro2::TokenStream { |
24 | // these are needed ahead of time, as ItemFn contains the function body _and_ |
25 | // isn't representable inside a quote!/quote_spanned! macro |
26 | // (Syn's ToTokens isn't implemented for ItemFn) |
27 | let MaybeItemFnRef { |
28 | outer_attrs, |
29 | inner_attrs, |
30 | vis, |
31 | sig, |
32 | block, |
33 | } = input; |
34 | |
35 | let Signature { |
36 | output, |
37 | inputs: params, |
38 | unsafety, |
39 | asyncness, |
40 | constness, |
41 | abi, |
42 | ident, |
43 | generics: |
44 | syn::Generics { |
45 | params: gen_params, |
46 | where_clause, |
47 | .. |
48 | }, |
49 | .. |
50 | } = sig; |
51 | |
52 | let warnings = args.warnings(); |
53 | |
54 | let (return_type, return_span) = if let ReturnType::Type(_, return_type) = &output { |
55 | (erase_impl_trait(return_type), return_type.span()) |
56 | } else { |
57 | // Point at function name if we don't have an explicit return type |
58 | (syn::parse_quote! { () }, ident.span()) |
59 | }; |
60 | // Install a fake return statement as the first thing in the function |
61 | // body, so that we eagerly infer that the return type is what we |
62 | // declared in the async fn signature. |
63 | // The `#[allow(..)]` is given because the return statement is |
64 | // unreachable, but does affect inference, so it needs to be written |
65 | // exactly that way for it to do its magic. |
66 | let fake_return_edge = quote_spanned! {return_span=> |
67 | #[allow( |
68 | unknown_lints, unreachable_code, clippy::diverging_sub_expression, |
69 | clippy::let_unit_value, clippy::unreachable, clippy::let_with_type_underscore, |
70 | clippy::empty_loop |
71 | )] |
72 | if false { |
73 | let __tracing_attr_fake_return: #return_type = loop {}; |
74 | return __tracing_attr_fake_return; |
75 | } |
76 | }; |
77 | let block = quote! { |
78 | { |
79 | #fake_return_edge |
80 | #block |
81 | } |
82 | }; |
83 | |
84 | let body = gen_block( |
85 | &block, |
86 | params, |
87 | asyncness.is_some(), |
88 | args, |
89 | instrumented_function_name, |
90 | self_type, |
91 | ); |
92 | |
93 | quote!( |
94 | #(#outer_attrs) * |
95 | #vis #constness #unsafety #asyncness #abi fn #ident<#gen_params>(#params) #output |
96 | #where_clause |
97 | { |
98 | #(#inner_attrs) * |
99 | #warnings |
100 | #body |
101 | } |
102 | ) |
103 | } |
104 | |
105 | /// Instrument a block |
106 | fn gen_block<B: ToTokens>( |
107 | block: &B, |
108 | params: &Punctuated<FnArg, Token![,]>, |
109 | async_context: bool, |
110 | mut args: InstrumentArgs, |
111 | instrumented_function_name: &str, |
112 | self_type: Option<&TypePath>, |
113 | ) -> proc_macro2::TokenStream { |
114 | // generate the span's name |
115 | let span_name = args |
116 | // did the user override the span's name? |
117 | .name |
118 | .as_ref() |
119 | .map(|name| quote!(#name)) |
120 | .unwrap_or_else(|| quote!(#instrumented_function_name)); |
121 | |
122 | let args_level = args.level(); |
123 | let level = args_level.clone(); |
124 | |
125 | let follows_from = args.follows_from.iter(); |
126 | let follows_from = quote! { |
127 | #(for cause in #follows_from { |
128 | __tracing_attr_span.follows_from(cause); |
129 | })* |
130 | }; |
131 | |
132 | // generate this inside a closure, so we can return early on errors. |
133 | let span = (|| { |
134 | // Pull out the arguments-to-be-skipped first, so we can filter results |
135 | // below. |
136 | let param_names: Vec<(Ident, (Ident, RecordType))> = params |
137 | .clone() |
138 | .into_iter() |
139 | .flat_map(|param| match param { |
140 | FnArg::Typed(PatType { pat, ty, .. }) => { |
141 | param_names(*pat, RecordType::parse_from_ty(&ty)) |
142 | } |
143 | FnArg::Receiver(_) => Box::new(iter::once(( |
144 | Ident::new("self" , param.span()), |
145 | RecordType::Debug, |
146 | ))), |
147 | }) |
148 | // Little dance with new (user-exposed) names and old (internal) |
149 | // names of identifiers. That way, we could do the following |
150 | // even though async_trait (<=0.1.43) rewrites "self" as "_self": |
151 | // ``` |
152 | // #[async_trait] |
153 | // impl Foo for FooImpl { |
154 | // #[instrument(skip(self))] |
155 | // async fn foo(&self, v: usize) {} |
156 | // } |
157 | // ``` |
158 | .map(|(x, record_type)| { |
159 | // if we are inside a function generated by async-trait <=0.1.43, we need to |
160 | // take care to rewrite "_self" as "self" for 'user convenience' |
161 | if self_type.is_some() && x == "_self" { |
162 | (Ident::new("self" , x.span()), (x, record_type)) |
163 | } else { |
164 | (x.clone(), (x, record_type)) |
165 | } |
166 | }) |
167 | .collect(); |
168 | |
169 | for skip in &args.skips { |
170 | if !param_names.iter().map(|(user, _)| user).any(|y| y == skip) { |
171 | return quote_spanned! {skip.span()=> |
172 | compile_error!("attempting to skip non-existent parameter" ) |
173 | }; |
174 | } |
175 | } |
176 | |
177 | let target = args.target(); |
178 | |
179 | let parent = args.parent.iter(); |
180 | |
181 | // filter out skipped fields |
182 | let quoted_fields: Vec<_> = param_names |
183 | .iter() |
184 | .filter(|(param, _)| { |
185 | if args.skip_all || args.skips.contains(param) { |
186 | return false; |
187 | } |
188 | |
189 | // If any parameters have the same name as a custom field, skip |
190 | // and allow them to be formatted by the custom field. |
191 | if let Some(ref fields) = args.fields { |
192 | fields.0.iter().all(|Field { ref name, .. }| { |
193 | let first = name.first(); |
194 | first != name.last() || !first.iter().any(|name| name == ¶m) |
195 | }) |
196 | } else { |
197 | true |
198 | } |
199 | }) |
200 | .map(|(user_name, (real_name, record_type))| match record_type { |
201 | RecordType::Value => quote!(#user_name = #real_name), |
202 | RecordType::Debug => quote!(#user_name = tracing::field::debug(real_name)), |
203 | }) |
204 | .collect(); |
205 | |
206 | // replace every use of a variable with its original name |
207 | if let Some(Fields(ref mut fields)) = args.fields { |
208 | let mut replacer = IdentAndTypesRenamer { |
209 | idents: param_names.into_iter().map(|(a, (b, _))| (a, b)).collect(), |
210 | types: Vec::new(), |
211 | }; |
212 | |
213 | // when async-trait <=0.1.43 is in use, replace instances |
214 | // of the "Self" type inside the fields values |
215 | if let Some(self_type) = self_type { |
216 | replacer.types.push(("Self" , self_type.clone())); |
217 | } |
218 | |
219 | for e in fields.iter_mut().filter_map(|f| f.value.as_mut()) { |
220 | syn::visit_mut::visit_expr_mut(&mut replacer, e); |
221 | } |
222 | } |
223 | |
224 | let custom_fields = &args.fields; |
225 | |
226 | quote!(tracing::span!( |
227 | target: #target, |
228 | #(parent: #parent,)* |
229 | #level, |
230 | #span_name, |
231 | #(#quoted_fields,)* |
232 | #custom_fields |
233 | |
234 | )) |
235 | })(); |
236 | |
237 | let target = args.target(); |
238 | |
239 | let err_event = match args.err_args { |
240 | Some(event_args) => { |
241 | let level_tokens = event_args.level(Level::Error); |
242 | match event_args.mode { |
243 | FormatMode::Default | FormatMode::Display => Some(quote!( |
244 | tracing::event!(target: #target, #level_tokens, error = %e) |
245 | )), |
246 | FormatMode::Debug => Some(quote!( |
247 | tracing::event!(target: #target, #level_tokens, error = ?e) |
248 | )), |
249 | } |
250 | } |
251 | _ => None, |
252 | }; |
253 | |
254 | let ret_event = match args.ret_args { |
255 | Some(event_args) => { |
256 | let level_tokens = event_args.level(args_level); |
257 | match event_args.mode { |
258 | FormatMode::Display => Some(quote!( |
259 | tracing::event!(target: #target, #level_tokens, return = %x) |
260 | )), |
261 | FormatMode::Default | FormatMode::Debug => Some(quote!( |
262 | tracing::event!(target: #target, #level_tokens, return = ?x) |
263 | )), |
264 | } |
265 | } |
266 | _ => None, |
267 | }; |
268 | |
269 | // Generate the instrumented function body. |
270 | // If the function is an `async fn`, this will wrap it in an async block, |
271 | // which is `instrument`ed using `tracing-futures`. Otherwise, this will |
272 | // enter the span and then perform the rest of the body. |
273 | // If `err` is in args, instrument any resulting `Err`s. |
274 | // If `ret` is in args, instrument any resulting `Ok`s when the function |
275 | // returns `Result`s, otherwise instrument any resulting values. |
276 | if async_context { |
277 | let mk_fut = match (err_event, ret_event) { |
278 | (Some(err_event), Some(ret_event)) => quote_spanned!(block.span()=> |
279 | async move { |
280 | match async move #block.await { |
281 | #[allow(clippy::unit_arg)] |
282 | Ok(x) => { |
283 | #ret_event; |
284 | Ok(x) |
285 | }, |
286 | Err(e) => { |
287 | #err_event; |
288 | Err(e) |
289 | } |
290 | } |
291 | } |
292 | ), |
293 | (Some(err_event), None) => quote_spanned!(block.span()=> |
294 | async move { |
295 | match async move #block.await { |
296 | #[allow(clippy::unit_arg)] |
297 | Ok(x) => Ok(x), |
298 | Err(e) => { |
299 | #err_event; |
300 | Err(e) |
301 | } |
302 | } |
303 | } |
304 | ), |
305 | (None, Some(ret_event)) => quote_spanned!(block.span()=> |
306 | async move { |
307 | let x = async move #block.await; |
308 | #ret_event; |
309 | x |
310 | } |
311 | ), |
312 | (None, None) => quote_spanned!(block.span()=> |
313 | async move #block |
314 | ), |
315 | }; |
316 | |
317 | return quote!( |
318 | let __tracing_attr_span = #span; |
319 | let __tracing_instrument_future = #mk_fut; |
320 | if !__tracing_attr_span.is_disabled() { |
321 | #follows_from |
322 | tracing::Instrument::instrument( |
323 | __tracing_instrument_future, |
324 | __tracing_attr_span |
325 | ) |
326 | .await |
327 | } else { |
328 | __tracing_instrument_future.await |
329 | } |
330 | ); |
331 | } |
332 | |
333 | let span = quote!( |
334 | // These variables are left uninitialized and initialized only |
335 | // if the tracing level is statically enabled at this point. |
336 | // While the tracing level is also checked at span creation |
337 | // time, that will still create a dummy span, and a dummy guard |
338 | // and drop the dummy guard later. By lazily initializing these |
339 | // variables, Rust will generate a drop flag for them and thus |
340 | // only drop the guard if it was created. This creates code that |
341 | // is very straightforward for LLVM to optimize out if the tracing |
342 | // level is statically disabled, while not causing any performance |
343 | // regression in case the level is enabled. |
344 | let __tracing_attr_span; |
345 | let __tracing_attr_guard; |
346 | if tracing::level_enabled!(#level) || tracing::if_log_enabled!(#level, {true} else {false}) { |
347 | __tracing_attr_span = #span; |
348 | #follows_from |
349 | __tracing_attr_guard = __tracing_attr_span.enter(); |
350 | } |
351 | ); |
352 | |
353 | match (err_event, ret_event) { |
354 | (Some(err_event), Some(ret_event)) => quote_spanned! {block.span()=> |
355 | #span |
356 | #[allow(clippy::redundant_closure_call)] |
357 | match (move || #block)() { |
358 | #[allow(clippy::unit_arg)] |
359 | Ok(x) => { |
360 | #ret_event; |
361 | Ok(x) |
362 | }, |
363 | Err(e) => { |
364 | #err_event; |
365 | Err(e) |
366 | } |
367 | } |
368 | }, |
369 | (Some(err_event), None) => quote_spanned!(block.span()=> |
370 | #span |
371 | #[allow(clippy::redundant_closure_call)] |
372 | match (move || #block)() { |
373 | #[allow(clippy::unit_arg)] |
374 | Ok(x) => Ok(x), |
375 | Err(e) => { |
376 | #err_event; |
377 | Err(e) |
378 | } |
379 | } |
380 | ), |
381 | (None, Some(ret_event)) => quote_spanned!(block.span()=> |
382 | #span |
383 | #[allow(clippy::redundant_closure_call)] |
384 | let x = (move || #block)(); |
385 | #ret_event; |
386 | x |
387 | ), |
388 | (None, None) => quote_spanned!(block.span() => |
389 | // Because `quote` produces a stream of tokens _without_ whitespace, the |
390 | // `if` and the block will appear directly next to each other. This |
391 | // generates a clippy lint about suspicious `if/else` formatting. |
392 | // Therefore, suppress the lint inside the generated code... |
393 | #[allow(clippy::suspicious_else_formatting)] |
394 | { |
395 | #span |
396 | // ...but turn the lint back on inside the function body. |
397 | #[warn(clippy::suspicious_else_formatting)] |
398 | #block |
399 | } |
400 | ), |
401 | } |
402 | } |
403 | |
404 | /// Indicates whether a field should be recorded as `Value` or `Debug`. |
405 | enum RecordType { |
406 | /// The field should be recorded using its `Value` implementation. |
407 | Value, |
408 | /// The field should be recorded using `tracing::field::debug()`. |
409 | Debug, |
410 | } |
411 | |
412 | impl RecordType { |
413 | /// Array of primitive types which should be recorded as [RecordType::Value]. |
414 | const TYPES_FOR_VALUE: &'static [&'static str] = &[ |
415 | "bool" , |
416 | "str" , |
417 | "u8" , |
418 | "i8" , |
419 | "u16" , |
420 | "i16" , |
421 | "u32" , |
422 | "i32" , |
423 | "u64" , |
424 | "i64" , |
425 | "f32" , |
426 | "f64" , |
427 | "usize" , |
428 | "isize" , |
429 | "NonZeroU8" , |
430 | "NonZeroI8" , |
431 | "NonZeroU16" , |
432 | "NonZeroI16" , |
433 | "NonZeroU32" , |
434 | "NonZeroI32" , |
435 | "NonZeroU64" , |
436 | "NonZeroI64" , |
437 | "NonZeroUsize" , |
438 | "NonZeroIsize" , |
439 | "Wrapping" , |
440 | ]; |
441 | |
442 | /// Parse `RecordType` from [Type] by looking up |
443 | /// the [RecordType::TYPES_FOR_VALUE] array. |
444 | fn parse_from_ty(ty: &Type) -> Self { |
445 | match ty { |
446 | Type::Path(TypePath { path, .. }) |
447 | if path |
448 | .segments |
449 | .iter() |
450 | .last() |
451 | .map(|path_segment| { |
452 | let ident = path_segment.ident.to_string(); |
453 | Self::TYPES_FOR_VALUE.iter().any(|&t| t == ident) |
454 | }) |
455 | .unwrap_or(false) => |
456 | { |
457 | RecordType::Value |
458 | } |
459 | Type::Reference(syn::TypeReference { elem, .. }) => RecordType::parse_from_ty(elem), |
460 | _ => RecordType::Debug, |
461 | } |
462 | } |
463 | } |
464 | |
465 | fn param_names(pat: Pat, record_type: RecordType) -> Box<dyn Iterator<Item = (Ident, RecordType)>> { |
466 | match pat { |
467 | Pat::Ident(PatIdent { ident, .. }) => Box::new(iter::once((ident, record_type))), |
468 | Pat::Reference(PatReference { pat, .. }) => param_names(*pat, record_type), |
469 | // We can't get the concrete type of fields in the struct/tuple |
470 | // patterns by using `syn`. e.g. `fn foo(Foo { x, y }: Foo) {}`. |
471 | // Therefore, the struct/tuple patterns in the arguments will just |
472 | // always be recorded as `RecordType::Debug`. |
473 | Pat::Struct(PatStruct { fields, .. }) => Box::new( |
474 | fields |
475 | .into_iter() |
476 | .flat_map(|FieldPat { pat, .. }| param_names(*pat, RecordType::Debug)), |
477 | ), |
478 | Pat::Tuple(PatTuple { elems, .. }) => Box::new( |
479 | elems |
480 | .into_iter() |
481 | .flat_map(|p| param_names(p, RecordType::Debug)), |
482 | ), |
483 | Pat::TupleStruct(PatTupleStruct { elems, .. }) => Box::new( |
484 | elems |
485 | .into_iter() |
486 | .flat_map(|p| param_names(p, RecordType::Debug)), |
487 | ), |
488 | |
489 | // The above *should* cover all cases of irrefutable patterns, |
490 | // but we purposefully don't do any funny business here |
491 | // (such as panicking) because that would obscure rustc's |
492 | // much more informative error message. |
493 | _ => Box::new(iter::empty()), |
494 | } |
495 | } |
496 | |
/// The specific async code pattern that was detected
enum AsyncKind<'a> {
    /// Immediately-invoked async fn, as generated by `async-trait <= 0.1.43`:
    /// `async fn foo<...>(...) {...}; Box::pin(foo<...>(...))`
    ///
    /// Holds the async fn declared inside the block.
    Function(&'a ItemFn),
    /// A function returning an async (move) block, optionally `Box::pin`-ed,
    /// as generated by `async-trait >= 0.1.44`:
    /// `Box::pin(async move { ... })`
    Async {
        /// The async block that was found.
        async_expr: &'a ExprAsync,
        /// `true` when the async block was the argument of a `Box::pin` call,
        /// `false` when it was the last expression of the function itself.
        pinned_box: bool,
    },
}
510 | |
/// A manual async-like pattern found in the body of a non-`async` function,
/// along with what is needed to instrument it.
pub(crate) struct AsyncInfo<'block> {
    // statement that must be patched
    source_stmt: &'block Stmt,
    // which async pattern was detected
    kind: AsyncKind<'block>,
    // type of the "_self" argument of the inner async fn, when there is one
    // and it is a path type; only set for `AsyncKind::Function`
    self_type: Option<TypePath>,
    // the function the pattern was found in
    input: &'block ItemFn,
}
518 | |
519 | impl<'block> AsyncInfo<'block> { |
520 | /// Get the AST of the inner function we need to hook, if it looks like a |
521 | /// manual future implementation. |
522 | /// |
523 | /// When we are given a function that returns a (pinned) future containing the |
524 | /// user logic, it is that (pinned) future that needs to be instrumented. |
525 | /// Were we to instrument its parent, we would only collect information |
526 | /// regarding the allocation of that future, and not its own span of execution. |
527 | /// |
528 | /// We inspect the block of the function to find if it matches any of the |
529 | /// following patterns: |
530 | /// |
531 | /// - Immediately-invoked async fn, as generated by `async-trait <= 0.1.43`: |
532 | /// `async fn foo<...>(...) {...}; Box::pin(foo<...>(...))` |
533 | /// |
534 | /// - A function returning an async (move) block, optionally `Box::pin`-ed, |
535 | /// as generated by `async-trait >= 0.1.44`: |
536 | /// `Box::pin(async move { ... })` |
537 | /// |
538 | /// We the return the statement that must be instrumented, along with some |
539 | /// other information. |
540 | /// 'gen_body' will then be able to use that information to instrument the |
541 | /// proper function/future. |
542 | /// |
543 | /// (this follows the approach suggested in |
544 | /// https://github.com/dtolnay/async-trait/issues/45#issuecomment-571245673) |
545 | pub(crate) fn from_fn(input: &'block ItemFn) -> Option<Self> { |
546 | // are we in an async context? If yes, this isn't a manual async-like pattern |
547 | if input.sig.asyncness.is_some() { |
548 | return None; |
549 | } |
550 | |
551 | let block = &input.block; |
552 | |
553 | // list of async functions declared inside the block |
554 | let inside_funs = block.stmts.iter().filter_map(|stmt| { |
555 | if let Stmt::Item(Item::Fn(fun)) = &stmt { |
556 | // If the function is async, this is a candidate |
557 | if fun.sig.asyncness.is_some() { |
558 | return Some((stmt, fun)); |
559 | } |
560 | } |
561 | None |
562 | }); |
563 | |
564 | // last expression of the block: it determines the return value of the |
565 | // block, this is quite likely a `Box::pin` statement or an async block |
566 | let (last_expr_stmt, last_expr) = block.stmts.iter().rev().find_map(|stmt| { |
567 | if let Stmt::Expr(expr, _semi) = stmt { |
568 | Some((stmt, expr)) |
569 | } else { |
570 | None |
571 | } |
572 | })?; |
573 | |
574 | // is the last expression an async block? |
575 | if let Expr::Async(async_expr) = last_expr { |
576 | return Some(AsyncInfo { |
577 | source_stmt: last_expr_stmt, |
578 | kind: AsyncKind::Async { |
579 | async_expr, |
580 | pinned_box: false, |
581 | }, |
582 | self_type: None, |
583 | input, |
584 | }); |
585 | } |
586 | |
587 | // is the last expression a function call? |
588 | let (outside_func, outside_args) = match last_expr { |
589 | Expr::Call(ExprCall { func, args, .. }) => (func, args), |
590 | _ => return None, |
591 | }; |
592 | |
593 | // is it a call to `Box::pin()`? |
594 | let path = match outside_func.as_ref() { |
595 | Expr::Path(path) => &path.path, |
596 | _ => return None, |
597 | }; |
598 | if !path_to_string(path).ends_with("Box::pin" ) { |
599 | return None; |
600 | } |
601 | |
602 | // Does the call take an argument? If it doesn't, |
603 | // it's not gonna compile anyway, but that's no reason |
604 | // to (try to) perform an out of bounds access |
605 | if outside_args.is_empty() { |
606 | return None; |
607 | } |
608 | |
609 | // Is the argument to Box::pin an async block that |
610 | // captures its arguments? |
611 | if let Expr::Async(async_expr) = &outside_args[0] { |
612 | return Some(AsyncInfo { |
613 | source_stmt: last_expr_stmt, |
614 | kind: AsyncKind::Async { |
615 | async_expr, |
616 | pinned_box: true, |
617 | }, |
618 | self_type: None, |
619 | input, |
620 | }); |
621 | } |
622 | |
623 | // Is the argument to Box::pin a function call itself? |
624 | let func = match &outside_args[0] { |
625 | Expr::Call(ExprCall { func, .. }) => func, |
626 | _ => return None, |
627 | }; |
628 | |
629 | // "stringify" the path of the function called |
630 | let func_name = match **func { |
631 | Expr::Path(ref func_path) => path_to_string(&func_path.path), |
632 | _ => return None, |
633 | }; |
634 | |
635 | // Was that function defined inside of the current block? |
636 | // If so, retrieve the statement where it was declared and the function itself |
637 | let (stmt_func_declaration, func) = inside_funs |
638 | .into_iter() |
639 | .find(|(_, fun)| fun.sig.ident == func_name)?; |
640 | |
641 | // If "_self" is present as an argument, we store its type to be able to rewrite "Self" (the |
642 | // parameter type) with the type of "_self" |
643 | let mut self_type = None; |
644 | for arg in &func.sig.inputs { |
645 | if let FnArg::Typed(ty) = arg { |
646 | if let Pat::Ident(PatIdent { ref ident, .. }) = *ty.pat { |
647 | if ident == "_self" { |
648 | let mut ty = *ty.ty.clone(); |
649 | // extract the inner type if the argument is "&self" or "&mut self" |
650 | if let Type::Reference(syn::TypeReference { elem, .. }) = ty { |
651 | ty = *elem; |
652 | } |
653 | |
654 | if let Type::Path(tp) = ty { |
655 | self_type = Some(tp); |
656 | break; |
657 | } |
658 | } |
659 | } |
660 | } |
661 | } |
662 | |
663 | Some(AsyncInfo { |
664 | source_stmt: stmt_func_declaration, |
665 | kind: AsyncKind::Function(func), |
666 | self_type, |
667 | input, |
668 | }) |
669 | } |
670 | |
671 | pub(crate) fn gen_async( |
672 | self, |
673 | args: InstrumentArgs, |
674 | instrumented_function_name: &str, |
675 | ) -> Result<proc_macro::TokenStream, syn::Error> { |
676 | // let's rewrite some statements! |
677 | let mut out_stmts: Vec<TokenStream> = self |
678 | .input |
679 | .block |
680 | .stmts |
681 | .iter() |
682 | .map(|stmt| stmt.to_token_stream()) |
683 | .collect(); |
684 | |
685 | if let Some((iter, _stmt)) = self |
686 | .input |
687 | .block |
688 | .stmts |
689 | .iter() |
690 | .enumerate() |
691 | .find(|(_iter, stmt)| *stmt == self.source_stmt) |
692 | { |
693 | // instrument the future by rewriting the corresponding statement |
694 | out_stmts[iter] = match self.kind { |
695 | // `Box::pin(immediately_invoked_async_fn())` |
696 | AsyncKind::Function(fun) => { |
697 | let fun = MaybeItemFn::from(fun.clone()); |
698 | gen_function( |
699 | fun.as_ref(), |
700 | args, |
701 | instrumented_function_name, |
702 | self.self_type.as_ref(), |
703 | ) |
704 | } |
705 | // `async move { ... }`, optionally pinned |
706 | AsyncKind::Async { |
707 | async_expr, |
708 | pinned_box, |
709 | } => { |
710 | let instrumented_block = gen_block( |
711 | &async_expr.block, |
712 | &self.input.sig.inputs, |
713 | true, |
714 | args, |
715 | instrumented_function_name, |
716 | None, |
717 | ); |
718 | let async_attrs = &async_expr.attrs; |
719 | if pinned_box { |
720 | quote! { |
721 | Box::pin(#(#async_attrs) * async move { #instrumented_block }) |
722 | } |
723 | } else { |
724 | quote! { |
725 | #(#async_attrs) * async move { #instrumented_block } |
726 | } |
727 | } |
728 | } |
729 | }; |
730 | } |
731 | |
732 | let vis = &self.input.vis; |
733 | let sig = &self.input.sig; |
734 | let attrs = &self.input.attrs; |
735 | Ok(quote!( |
736 | #(#attrs) * |
737 | #vis #sig { |
738 | #(#out_stmts) * |
739 | } |
740 | ) |
741 | .into()) |
742 | } |
743 | } |
744 | |
745 | // Return a path as a String |
746 | fn path_to_string(path: &Path) -> String { |
747 | use std::fmt::Write; |
748 | // some heuristic to prevent too many allocations |
749 | let mut res = String::with_capacity(path.segments.len() * 5); |
750 | for i in 0..path.segments.len() { |
751 | write!(&mut res, "{}" , path.segments[i].ident) |
752 | .expect("writing to a String should never fail" ); |
753 | if i < path.segments.len() - 1 { |
754 | res.push_str("::" ); |
755 | } |
756 | } |
757 | res |
758 | } |
759 | |
760 | /// A visitor struct to replace idents and types in some piece |
761 | /// of code (e.g. the "self" and "Self" tokens in user-supplied |
762 | /// fields expressions when the function is generated by an old |
763 | /// version of async-trait). |
764 | struct IdentAndTypesRenamer<'a> { |
765 | types: Vec<(&'a str, TypePath)>, |
766 | idents: Vec<(Ident, Ident)>, |
767 | } |
768 | |
769 | impl<'a> VisitMut for IdentAndTypesRenamer<'a> { |
770 | // we deliberately compare strings because we want to ignore the spans |
771 | // If we apply clippy's lint, the behavior changes |
772 | #[allow (clippy::cmp_owned)] |
773 | fn visit_ident_mut(&mut self, id: &mut Ident) { |
774 | for (old_ident, new_ident) in &self.idents { |
775 | if id.to_string() == old_ident.to_string() { |
776 | *id = new_ident.clone(); |
777 | } |
778 | } |
779 | } |
780 | |
781 | fn visit_type_mut(&mut self, ty: &mut Type) { |
782 | for (type_name, new_type) in &self.types { |
783 | if let Type::Path(TypePath { path, .. }) = ty { |
784 | if path_to_string(path) == *type_name { |
785 | *ty = Type::Path(new_type.clone()); |
786 | } |
787 | } |
788 | } |
789 | } |
790 | } |
791 | |
792 | // A visitor struct that replace an async block by its patched version |
793 | struct AsyncTraitBlockReplacer<'a> { |
794 | block: &'a Block, |
795 | patched_block: Block, |
796 | } |
797 | |
798 | impl<'a> VisitMut for AsyncTraitBlockReplacer<'a> { |
799 | fn visit_block_mut(&mut self, i: &mut Block) { |
800 | if i == self.block { |
801 | *i = self.patched_block.clone(); |
802 | } |
803 | } |
804 | } |
805 | |
806 | // Replaces any `impl Trait` with `_` so it can be used as the type in |
807 | // a `let` statement's LHS. |
808 | struct ImplTraitEraser; |
809 | |
810 | impl VisitMut for ImplTraitEraser { |
811 | fn visit_type_mut(&mut self, t: &mut Type) { |
812 | if let Type::ImplTrait(..) = t { |
813 | *t = syn::TypeInfer { |
814 | underscore_token: Token![_](t.span()), |
815 | } |
816 | .into(); |
817 | } else { |
818 | syn::visit_mut::visit_type_mut(self, t); |
819 | } |
820 | } |
821 | } |
822 | |
823 | fn erase_impl_trait(ty: &Type) -> Type { |
824 | let mut ty = ty.clone(); |
825 | ImplTraitEraser.visit_type_mut(&mut ty); |
826 | ty |
827 | } |
828 | |