1 | use std::iter; |
2 | |
3 | use proc_macro2::TokenStream; |
4 | use quote::{quote, quote_spanned, ToTokens}; |
5 | use syn::visit_mut::VisitMut; |
6 | use syn::{ |
7 | punctuated::Punctuated, spanned::Spanned, Block, Expr, ExprAsync, ExprCall, FieldPat, FnArg, |
8 | Ident, Item, ItemFn, Pat, PatIdent, PatReference, PatStruct, PatTuple, PatTupleStruct, PatType, |
9 | Path, ReturnType, Signature, Stmt, Token, Type, TypePath, |
10 | }; |
11 | |
12 | use crate::{ |
13 | attr::{Field, Fields, FormatMode, InstrumentArgs, Level}, |
14 | MaybeItemFn, MaybeItemFnRef, |
15 | }; |
16 | |
17 | /// Given an existing function, generate an instrumented version of that function |
18 | pub(crate) fn gen_function<'a, B: ToTokens + 'a>( |
19 | input: MaybeItemFnRef<'a, B>, |
20 | args: InstrumentArgs, |
21 | instrumented_function_name: &str, |
22 | self_type: Option<&TypePath>, |
23 | ) -> proc_macro2::TokenStream { |
24 | // these are needed ahead of time, as ItemFn contains the function body _and_ |
25 | // isn't representable inside a quote!/quote_spanned! macro |
26 | // (Syn's ToTokens isn't implemented for ItemFn) |
27 | let MaybeItemFnRef { |
28 | outer_attrs, |
29 | inner_attrs, |
30 | vis, |
31 | sig, |
32 | block, |
33 | } = input; |
34 | |
35 | let Signature { |
36 | output, |
37 | inputs: params, |
38 | unsafety, |
39 | asyncness, |
40 | constness, |
41 | abi, |
42 | ident, |
43 | generics: |
44 | syn::Generics { |
45 | params: gen_params, |
46 | where_clause, |
47 | .. |
48 | }, |
49 | .. |
50 | } = sig; |
51 | |
52 | let warnings = args.warnings(); |
53 | |
54 | let (return_type, return_span) = if let ReturnType::Type(_, return_type) = &output { |
55 | (erase_impl_trait(return_type), return_type.span()) |
56 | } else { |
57 | // Point at function name if we don't have an explicit return type |
58 | (syn::parse_quote! { () }, ident.span()) |
59 | }; |
60 | // Install a fake return statement as the first thing in the function |
61 | // body, so that we eagerly infer that the return type is what we |
62 | // declared in the async fn signature. |
63 | // The `#[allow(..)]` is given because the return statement is |
64 | // unreachable, but does affect inference, so it needs to be written |
65 | // exactly that way for it to do its magic. |
66 | let fake_return_edge = quote_spanned! {return_span=> |
67 | #[allow( |
68 | unknown_lints, unreachable_code, clippy::diverging_sub_expression, |
69 | clippy::let_unit_value, clippy::unreachable, clippy::let_with_type_underscore, |
70 | clippy::empty_loop |
71 | )] |
72 | if false { |
73 | let __tracing_attr_fake_return: #return_type = loop {}; |
74 | return __tracing_attr_fake_return; |
75 | } |
76 | }; |
77 | let block = quote! { |
78 | { |
79 | #fake_return_edge |
80 | #block |
81 | } |
82 | }; |
83 | |
84 | let body = gen_block( |
85 | &block, |
86 | params, |
87 | asyncness.is_some(), |
88 | args, |
89 | instrumented_function_name, |
90 | self_type, |
91 | ); |
92 | |
93 | quote!( |
94 | #(#outer_attrs) * |
95 | #vis #constness #asyncness #unsafety #abi fn #ident<#gen_params>(#params) #output |
96 | #where_clause |
97 | { |
98 | #(#inner_attrs) * |
99 | #warnings |
100 | #body |
101 | } |
102 | ) |
103 | } |
104 | |
105 | /// Instrument a block |
106 | fn gen_block<B: ToTokens>( |
107 | block: &B, |
108 | params: &Punctuated<FnArg, Token![,]>, |
109 | async_context: bool, |
110 | mut args: InstrumentArgs, |
111 | instrumented_function_name: &str, |
112 | self_type: Option<&TypePath>, |
113 | ) -> proc_macro2::TokenStream { |
114 | // generate the span's name |
115 | let span_name = args |
116 | // did the user override the span's name? |
117 | .name |
118 | .as_ref() |
119 | .map(|name| quote!(#name)) |
120 | .unwrap_or_else(|| quote!(#instrumented_function_name)); |
121 | |
122 | let args_level = args.level(); |
123 | let level = args_level.clone(); |
124 | |
125 | let follows_from = args.follows_from.iter(); |
126 | let follows_from = quote! { |
127 | #(for cause in #follows_from { |
128 | __tracing_attr_span.follows_from(cause); |
129 | })* |
130 | }; |
131 | |
132 | // generate this inside a closure, so we can return early on errors. |
133 | let span = (|| { |
134 | // Pull out the arguments-to-be-skipped first, so we can filter results |
135 | // below. |
136 | let param_names: Vec<(Ident, (Ident, RecordType))> = params |
137 | .clone() |
138 | .into_iter() |
139 | .flat_map(|param| match param { |
140 | FnArg::Typed(PatType { pat, ty, .. }) => { |
141 | param_names(*pat, RecordType::parse_from_ty(&ty)) |
142 | } |
143 | FnArg::Receiver(_) => Box::new(iter::once(( |
144 | Ident::new("self" , param.span()), |
145 | RecordType::Debug, |
146 | ))), |
147 | }) |
148 | // Little dance with new (user-exposed) names and old (internal) |
149 | // names of identifiers. That way, we could do the following |
150 | // even though async_trait (<=0.1.43) rewrites "self" as "_self": |
151 | // ``` |
152 | // #[async_trait] |
153 | // impl Foo for FooImpl { |
154 | // #[instrument(skip(self))] |
155 | // async fn foo(&self, v: usize) {} |
156 | // } |
157 | // ``` |
158 | .map(|(x, record_type)| { |
159 | // if we are inside a function generated by async-trait <=0.1.43, we need to |
160 | // take care to rewrite "_self" as "self" for 'user convenience' |
161 | if self_type.is_some() && x == "_self" { |
162 | (Ident::new("self" , x.span()), (x, record_type)) |
163 | } else { |
164 | (x.clone(), (x, record_type)) |
165 | } |
166 | }) |
167 | .collect(); |
168 | |
169 | for skip in &args.skips { |
170 | if !param_names.iter().map(|(user, _)| user).any(|y| y == skip) { |
171 | return quote_spanned! {skip.span()=> |
172 | compile_error!("attempting to skip non-existent parameter" ) |
173 | }; |
174 | } |
175 | } |
176 | |
177 | let target = args.target(); |
178 | |
179 | let parent = args.parent.iter(); |
180 | |
181 | // filter out skipped fields |
182 | let quoted_fields: Vec<_> = param_names |
183 | .iter() |
184 | .filter(|(param, _)| { |
185 | if args.skip_all || args.skips.contains(param) { |
186 | return false; |
187 | } |
188 | |
189 | // If any parameters have the same name as a custom field, skip |
190 | // and allow them to be formatted by the custom field. |
191 | if let Some(ref fields) = args.fields { |
192 | fields.0.iter().all(|Field { ref name, .. }| { |
193 | let first = name.first(); |
194 | first != name.last() || !first.iter().any(|name| name == ¶m) |
195 | }) |
196 | } else { |
197 | true |
198 | } |
199 | }) |
200 | .map(|(user_name, (real_name, record_type))| match record_type { |
201 | RecordType::Value => quote!(#user_name = #real_name), |
202 | RecordType::Debug => quote!(#user_name = tracing::field::debug(&#real_name)), |
203 | }) |
204 | .collect(); |
205 | |
206 | // replace every use of a variable with its original name |
207 | if let Some(Fields(ref mut fields)) = args.fields { |
208 | let mut replacer = IdentAndTypesRenamer { |
209 | idents: param_names.into_iter().map(|(a, (b, _))| (a, b)).collect(), |
210 | types: Vec::new(), |
211 | }; |
212 | |
213 | // when async-trait <=0.1.43 is in use, replace instances |
214 | // of the "Self" type inside the fields values |
215 | if let Some(self_type) = self_type { |
216 | replacer.types.push(("Self" , self_type.clone())); |
217 | } |
218 | |
219 | for e in fields.iter_mut().filter_map(|f| f.value.as_mut()) { |
220 | syn::visit_mut::visit_expr_mut(&mut replacer, e); |
221 | } |
222 | } |
223 | |
224 | let custom_fields = &args.fields; |
225 | |
226 | quote!(tracing::span!( |
227 | target: #target, |
228 | #(parent: #parent,)* |
229 | #level, |
230 | #span_name, |
231 | #(#quoted_fields,)* |
232 | #custom_fields |
233 | |
234 | )) |
235 | })(); |
236 | |
237 | let target = args.target(); |
238 | |
239 | let err_event = match args.err_args { |
240 | Some(event_args) => { |
241 | let level_tokens = event_args.level(Level::Error); |
242 | match event_args.mode { |
243 | FormatMode::Default | FormatMode::Display => Some(quote!( |
244 | tracing::event!(target: #target, #level_tokens, error = %e) |
245 | )), |
246 | FormatMode::Debug => Some(quote!( |
247 | tracing::event!(target: #target, #level_tokens, error = ?e) |
248 | )), |
249 | } |
250 | } |
251 | _ => None, |
252 | }; |
253 | |
254 | let ret_event = match args.ret_args { |
255 | Some(event_args) => { |
256 | let level_tokens = event_args.level(args_level); |
257 | match event_args.mode { |
258 | FormatMode::Display => Some(quote!( |
259 | tracing::event!(target: #target, #level_tokens, return = %x) |
260 | )), |
261 | FormatMode::Default | FormatMode::Debug => Some(quote!( |
262 | tracing::event!(target: #target, #level_tokens, return = ?x) |
263 | )), |
264 | } |
265 | } |
266 | _ => None, |
267 | }; |
268 | |
269 | // Generate the instrumented function body. |
270 | // If the function is an `async fn`, this will wrap it in an async block, |
271 | // which is `instrument`ed using `tracing-futures`. Otherwise, this will |
272 | // enter the span and then perform the rest of the body. |
273 | // If `err` is in args, instrument any resulting `Err`s. |
274 | // If `ret` is in args, instrument any resulting `Ok`s when the function |
275 | // returns `Result`s, otherwise instrument any resulting values. |
276 | if async_context { |
277 | let mk_fut = match (err_event, ret_event) { |
278 | (Some(err_event), Some(ret_event)) => quote_spanned!(block.span()=> |
279 | async move { |
280 | let __match_scrutinee = async move #block.await; |
281 | match __match_scrutinee { |
282 | #[allow(clippy::unit_arg)] |
283 | Ok(x) => { |
284 | #ret_event; |
285 | Ok(x) |
286 | }, |
287 | Err(e) => { |
288 | #err_event; |
289 | Err(e) |
290 | } |
291 | } |
292 | } |
293 | ), |
294 | (Some(err_event), None) => quote_spanned!(block.span()=> |
295 | async move { |
296 | match async move #block.await { |
297 | #[allow(clippy::unit_arg)] |
298 | Ok(x) => Ok(x), |
299 | Err(e) => { |
300 | #err_event; |
301 | Err(e) |
302 | } |
303 | } |
304 | } |
305 | ), |
306 | (None, Some(ret_event)) => quote_spanned!(block.span()=> |
307 | async move { |
308 | let x = async move #block.await; |
309 | #ret_event; |
310 | x |
311 | } |
312 | ), |
313 | (None, None) => quote_spanned!(block.span()=> |
314 | async move #block |
315 | ), |
316 | }; |
317 | |
318 | return quote!( |
319 | let __tracing_attr_span = #span; |
320 | let __tracing_instrument_future = #mk_fut; |
321 | if !__tracing_attr_span.is_disabled() { |
322 | #follows_from |
323 | tracing::Instrument::instrument( |
324 | __tracing_instrument_future, |
325 | __tracing_attr_span |
326 | ) |
327 | .await |
328 | } else { |
329 | __tracing_instrument_future.await |
330 | } |
331 | ); |
332 | } |
333 | |
334 | let span = quote!( |
335 | // These variables are left uninitialized and initialized only |
336 | // if the tracing level is statically enabled at this point. |
337 | // While the tracing level is also checked at span creation |
338 | // time, that will still create a dummy span, and a dummy guard |
339 | // and drop the dummy guard later. By lazily initializing these |
340 | // variables, Rust will generate a drop flag for them and thus |
341 | // only drop the guard if it was created. This creates code that |
342 | // is very straightforward for LLVM to optimize out if the tracing |
343 | // level is statically disabled, while not causing any performance |
344 | // regression in case the level is enabled. |
345 | let __tracing_attr_span; |
346 | let __tracing_attr_guard; |
347 | if tracing::level_enabled!(#level) || tracing::if_log_enabled!(#level, {true} else {false}) { |
348 | __tracing_attr_span = #span; |
349 | #follows_from |
350 | __tracing_attr_guard = __tracing_attr_span.enter(); |
351 | } |
352 | ); |
353 | |
354 | match (err_event, ret_event) { |
355 | (Some(err_event), Some(ret_event)) => quote_spanned! {block.span()=> |
356 | #span |
357 | #[allow(clippy::redundant_closure_call)] |
358 | match (move || #block)() { |
359 | #[allow(clippy::unit_arg)] |
360 | Ok(x) => { |
361 | #ret_event; |
362 | Ok(x) |
363 | }, |
364 | Err(e) => { |
365 | #err_event; |
366 | Err(e) |
367 | } |
368 | } |
369 | }, |
370 | (Some(err_event), None) => quote_spanned!(block.span()=> |
371 | #span |
372 | #[allow(clippy::redundant_closure_call)] |
373 | match (move || #block)() { |
374 | #[allow(clippy::unit_arg)] |
375 | Ok(x) => Ok(x), |
376 | Err(e) => { |
377 | #err_event; |
378 | Err(e) |
379 | } |
380 | } |
381 | ), |
382 | (None, Some(ret_event)) => quote_spanned!(block.span()=> |
383 | #span |
384 | #[allow(clippy::redundant_closure_call)] |
385 | let x = (move || #block)(); |
386 | #ret_event; |
387 | x |
388 | ), |
389 | (None, None) => quote_spanned!(block.span() => |
390 | // Because `quote` produces a stream of tokens _without_ whitespace, the |
391 | // `if` and the block will appear directly next to each other. This |
392 | // generates a clippy lint about suspicious `if/else` formatting. |
393 | // Therefore, suppress the lint inside the generated code... |
394 | #[allow(clippy::suspicious_else_formatting)] |
395 | { |
396 | #span |
397 | // ...but turn the lint back on inside the function body. |
398 | #[warn(clippy::suspicious_else_formatting)] |
399 | #block |
400 | } |
401 | ), |
402 | } |
403 | } |
404 | |
/// How a function parameter is recorded as a span field.
enum RecordType {
    /// Record the parameter directly, through its `Value` implementation.
    Value,
    /// Record the parameter wrapped in `tracing::field::debug()`.
    Debug,
}
412 | |
413 | impl RecordType { |
414 | /// Array of primitive types which should be recorded as [RecordType::Value]. |
415 | const TYPES_FOR_VALUE: &'static [&'static str] = &[ |
416 | "bool" , |
417 | "str" , |
418 | "u8" , |
419 | "i8" , |
420 | "u16" , |
421 | "i16" , |
422 | "u32" , |
423 | "i32" , |
424 | "u64" , |
425 | "i64" , |
426 | "u128" , |
427 | "i128" , |
428 | "f32" , |
429 | "f64" , |
430 | "usize" , |
431 | "isize" , |
432 | "String" , |
433 | "NonZeroU8" , |
434 | "NonZeroI8" , |
435 | "NonZeroU16" , |
436 | "NonZeroI16" , |
437 | "NonZeroU32" , |
438 | "NonZeroI32" , |
439 | "NonZeroU64" , |
440 | "NonZeroI64" , |
441 | "NonZeroU128" , |
442 | "NonZeroI128" , |
443 | "NonZeroUsize" , |
444 | "NonZeroIsize" , |
445 | "Wrapping" , |
446 | ]; |
447 | |
448 | /// Parse `RecordType` from [Type] by looking up |
449 | /// the [RecordType::TYPES_FOR_VALUE] array. |
450 | fn parse_from_ty(ty: &Type) -> Self { |
451 | match ty { |
452 | Type::Path(TypePath { path, .. }) |
453 | if path |
454 | .segments |
455 | .iter() |
456 | .last() |
457 | .map(|path_segment| { |
458 | let ident = path_segment.ident.to_string(); |
459 | Self::TYPES_FOR_VALUE.iter().any(|&t| t == ident) |
460 | }) |
461 | .unwrap_or(false) => |
462 | { |
463 | RecordType::Value |
464 | } |
465 | Type::Reference(syn::TypeReference { elem, .. }) => RecordType::parse_from_ty(elem), |
466 | _ => RecordType::Debug, |
467 | } |
468 | } |
469 | } |
470 | |
471 | fn param_names(pat: Pat, record_type: RecordType) -> Box<dyn Iterator<Item = (Ident, RecordType)>> { |
472 | match pat { |
473 | Pat::Ident(PatIdent { ident, .. }) => Box::new(iter::once((ident, record_type))), |
474 | Pat::Reference(PatReference { pat, .. }) => param_names(*pat, record_type), |
475 | // We can't get the concrete type of fields in the struct/tuple |
476 | // patterns by using `syn`. e.g. `fn foo(Foo { x, y }: Foo) {}`. |
477 | // Therefore, the struct/tuple patterns in the arguments will just |
478 | // always be recorded as `RecordType::Debug`. |
479 | Pat::Struct(PatStruct { fields, .. }) => Box::new( |
480 | fields |
481 | .into_iter() |
482 | .flat_map(|FieldPat { pat, .. }| param_names(*pat, RecordType::Debug)), |
483 | ), |
484 | Pat::Tuple(PatTuple { elems, .. }) => Box::new( |
485 | elems |
486 | .into_iter() |
487 | .flat_map(|p| param_names(p, RecordType::Debug)), |
488 | ), |
489 | Pat::TupleStruct(PatTupleStruct { elems, .. }) => Box::new( |
490 | elems |
491 | .into_iter() |
492 | .flat_map(|p| param_names(p, RecordType::Debug)), |
493 | ), |
494 | |
495 | // The above *should* cover all cases of irrefutable patterns, |
496 | // but we purposefully don't do any funny business here |
497 | // (such as panicking) because that would obscure rustc's |
498 | // much more informative error message. |
499 | _ => Box::new(iter::empty()), |
500 | } |
501 | } |
502 | |
503 | /// The specific async code pattern that was detected |
504 | enum AsyncKind<'a> { |
505 | /// Immediately-invoked async fn, as generated by `async-trait <= 0.1.43`: |
506 | /// `async fn foo<...>(...) {...}; Box::pin(foo<...>(...))` |
507 | Function(&'a ItemFn), |
508 | /// A function returning an async (move) block, optionally `Box::pin`-ed, |
509 | /// as generated by `async-trait >= 0.1.44`: |
510 | /// `Box::pin(async move { ... })` |
511 | Async { |
512 | async_expr: &'a ExprAsync, |
513 | pinned_box: bool, |
514 | }, |
515 | } |
516 | |
517 | pub(crate) struct AsyncInfo<'block> { |
518 | // statement that must be patched |
519 | source_stmt: &'block Stmt, |
520 | kind: AsyncKind<'block>, |
521 | self_type: Option<TypePath>, |
522 | input: &'block ItemFn, |
523 | } |
524 | |
525 | impl<'block> AsyncInfo<'block> { |
526 | /// Get the AST of the inner function we need to hook, if it looks like a |
527 | /// manual future implementation. |
528 | /// |
529 | /// When we are given a function that returns a (pinned) future containing the |
530 | /// user logic, it is that (pinned) future that needs to be instrumented. |
531 | /// Were we to instrument its parent, we would only collect information |
532 | /// regarding the allocation of that future, and not its own span of execution. |
533 | /// |
534 | /// We inspect the block of the function to find if it matches any of the |
535 | /// following patterns: |
536 | /// |
537 | /// - Immediately-invoked async fn, as generated by `async-trait <= 0.1.43`: |
538 | /// `async fn foo<...>(...) {...}; Box::pin(foo<...>(...))` |
539 | /// |
540 | /// - A function returning an async (move) block, optionally `Box::pin`-ed, |
541 | /// as generated by `async-trait >= 0.1.44`: |
542 | /// `Box::pin(async move { ... })` |
543 | /// |
544 | /// We the return the statement that must be instrumented, along with some |
545 | /// other information. |
546 | /// 'gen_body' will then be able to use that information to instrument the |
547 | /// proper function/future. |
548 | /// |
549 | /// (this follows the approach suggested in |
550 | /// https://github.com/dtolnay/async-trait/issues/45#issuecomment-571245673) |
551 | pub(crate) fn from_fn(input: &'block ItemFn) -> Option<Self> { |
552 | // are we in an async context? If yes, this isn't a manual async-like pattern |
553 | if input.sig.asyncness.is_some() { |
554 | return None; |
555 | } |
556 | |
557 | let block = &input.block; |
558 | |
559 | // list of async functions declared inside the block |
560 | let inside_funs = block.stmts.iter().filter_map(|stmt| { |
561 | if let Stmt::Item(Item::Fn(fun)) = &stmt { |
562 | // If the function is async, this is a candidate |
563 | if fun.sig.asyncness.is_some() { |
564 | return Some((stmt, fun)); |
565 | } |
566 | } |
567 | None |
568 | }); |
569 | |
570 | // last expression of the block: it determines the return value of the |
571 | // block, this is quite likely a `Box::pin` statement or an async block |
572 | let (last_expr_stmt, last_expr) = block.stmts.iter().rev().find_map(|stmt| { |
573 | if let Stmt::Expr(expr, _semi) = stmt { |
574 | Some((stmt, expr)) |
575 | } else { |
576 | None |
577 | } |
578 | })?; |
579 | |
580 | // is the last expression an async block? |
581 | if let Expr::Async(async_expr) = last_expr { |
582 | return Some(AsyncInfo { |
583 | source_stmt: last_expr_stmt, |
584 | kind: AsyncKind::Async { |
585 | async_expr, |
586 | pinned_box: false, |
587 | }, |
588 | self_type: None, |
589 | input, |
590 | }); |
591 | } |
592 | |
593 | // is the last expression a function call? |
594 | let (outside_func, outside_args) = match last_expr { |
595 | Expr::Call(ExprCall { func, args, .. }) => (func, args), |
596 | _ => return None, |
597 | }; |
598 | |
599 | // is it a call to `Box::pin()`? |
600 | let path = match outside_func.as_ref() { |
601 | Expr::Path(path) => &path.path, |
602 | _ => return None, |
603 | }; |
604 | if !path_to_string(path).ends_with("Box::pin" ) { |
605 | return None; |
606 | } |
607 | |
608 | // Does the call take an argument? If it doesn't, |
609 | // it's not gonna compile anyway, but that's no reason |
610 | // to (try to) perform an out of bounds access |
611 | if outside_args.is_empty() { |
612 | return None; |
613 | } |
614 | |
615 | // Is the argument to Box::pin an async block that |
616 | // captures its arguments? |
617 | if let Expr::Async(async_expr) = &outside_args[0] { |
618 | return Some(AsyncInfo { |
619 | source_stmt: last_expr_stmt, |
620 | kind: AsyncKind::Async { |
621 | async_expr, |
622 | pinned_box: true, |
623 | }, |
624 | self_type: None, |
625 | input, |
626 | }); |
627 | } |
628 | |
629 | // Is the argument to Box::pin a function call itself? |
630 | let func = match &outside_args[0] { |
631 | Expr::Call(ExprCall { func, .. }) => func, |
632 | _ => return None, |
633 | }; |
634 | |
635 | // "stringify" the path of the function called |
636 | let func_name = match **func { |
637 | Expr::Path(ref func_path) => path_to_string(&func_path.path), |
638 | _ => return None, |
639 | }; |
640 | |
641 | // Was that function defined inside of the current block? |
642 | // If so, retrieve the statement where it was declared and the function itself |
643 | let (stmt_func_declaration, func) = inside_funs |
644 | .into_iter() |
645 | .find(|(_, fun)| fun.sig.ident == func_name)?; |
646 | |
647 | // If "_self" is present as an argument, we store its type to be able to rewrite "Self" (the |
648 | // parameter type) with the type of "_self" |
649 | let mut self_type = None; |
650 | for arg in &func.sig.inputs { |
651 | if let FnArg::Typed(ty) = arg { |
652 | if let Pat::Ident(PatIdent { ref ident, .. }) = *ty.pat { |
653 | if ident == "_self" { |
654 | let mut ty = *ty.ty.clone(); |
655 | // extract the inner type if the argument is "&self" or "&mut self" |
656 | if let Type::Reference(syn::TypeReference { elem, .. }) = ty { |
657 | ty = *elem; |
658 | } |
659 | |
660 | if let Type::Path(tp) = ty { |
661 | self_type = Some(tp); |
662 | break; |
663 | } |
664 | } |
665 | } |
666 | } |
667 | } |
668 | |
669 | Some(AsyncInfo { |
670 | source_stmt: stmt_func_declaration, |
671 | kind: AsyncKind::Function(func), |
672 | self_type, |
673 | input, |
674 | }) |
675 | } |
676 | |
677 | pub(crate) fn gen_async( |
678 | self, |
679 | args: InstrumentArgs, |
680 | instrumented_function_name: &str, |
681 | ) -> Result<proc_macro::TokenStream, syn::Error> { |
682 | // let's rewrite some statements! |
683 | let mut out_stmts: Vec<TokenStream> = self |
684 | .input |
685 | .block |
686 | .stmts |
687 | .iter() |
688 | .map(|stmt| stmt.to_token_stream()) |
689 | .collect(); |
690 | |
691 | if let Some((iter, _stmt)) = self |
692 | .input |
693 | .block |
694 | .stmts |
695 | .iter() |
696 | .enumerate() |
697 | .find(|(_iter, stmt)| *stmt == self.source_stmt) |
698 | { |
699 | // instrument the future by rewriting the corresponding statement |
700 | out_stmts[iter] = match self.kind { |
701 | // `Box::pin(immediately_invoked_async_fn())` |
702 | AsyncKind::Function(fun) => { |
703 | let fun = MaybeItemFn::from(fun.clone()); |
704 | gen_function( |
705 | fun.as_ref(), |
706 | args, |
707 | instrumented_function_name, |
708 | self.self_type.as_ref(), |
709 | ) |
710 | } |
711 | // `async move { ... }`, optionally pinned |
712 | AsyncKind::Async { |
713 | async_expr, |
714 | pinned_box, |
715 | } => { |
716 | let instrumented_block = gen_block( |
717 | &async_expr.block, |
718 | &self.input.sig.inputs, |
719 | true, |
720 | args, |
721 | instrumented_function_name, |
722 | None, |
723 | ); |
724 | let async_attrs = &async_expr.attrs; |
725 | if pinned_box { |
726 | quote! { |
727 | Box::pin(#(#async_attrs) * async move { #instrumented_block }) |
728 | } |
729 | } else { |
730 | quote! { |
731 | #(#async_attrs) * async move { #instrumented_block } |
732 | } |
733 | } |
734 | } |
735 | }; |
736 | } |
737 | |
738 | let vis = &self.input.vis; |
739 | let sig = &self.input.sig; |
740 | let attrs = &self.input.attrs; |
741 | Ok(quote!( |
742 | #(#attrs) * |
743 | #vis #sig { |
744 | #(#out_stmts) * |
745 | } |
746 | ) |
747 | .into()) |
748 | } |
749 | } |
750 | |
751 | // Return a path as a String |
752 | fn path_to_string(path: &Path) -> String { |
753 | use std::fmt::Write; |
754 | // some heuristic to prevent too many allocations |
755 | let mut res: String = String::with_capacity(path.segments.len() * 5); |
756 | for i: usize in 0..path.segments.len() { |
757 | write!(&mut res, " {}" , path.segments[i].ident) |
758 | .expect(msg:"writing to a String should never fail" ); |
759 | if i < path.segments.len() - 1 { |
760 | res.push_str(string:"::" ); |
761 | } |
762 | } |
763 | res |
764 | } |
765 | |
766 | /// A visitor struct to replace idents and types in some piece |
767 | /// of code (e.g. the "self" and "Self" tokens in user-supplied |
768 | /// fields expressions when the function is generated by an old |
769 | /// version of async-trait). |
770 | struct IdentAndTypesRenamer<'a> { |
771 | types: Vec<(&'a str, TypePath)>, |
772 | idents: Vec<(Ident, Ident)>, |
773 | } |
774 | |
775 | impl<'a> VisitMut for IdentAndTypesRenamer<'a> { |
776 | // we deliberately compare strings because we want to ignore the spans |
777 | // If we apply clippy's lint, the behavior changes |
778 | #[allow (clippy::cmp_owned)] |
779 | fn visit_ident_mut(&mut self, id: &mut Ident) { |
780 | for (old_ident: &Ident, new_ident: &Ident) in &self.idents { |
781 | if id.to_string() == old_ident.to_string() { |
782 | *id = new_ident.clone(); |
783 | } |
784 | } |
785 | } |
786 | |
787 | fn visit_type_mut(&mut self, ty: &mut Type) { |
788 | for (type_name: &&str, new_type: &TypePath) in &self.types { |
789 | if let Type::Path(TypePath { path: &mut Path, .. }) = ty { |
790 | if path_to_string(path) == *type_name { |
791 | *ty = Type::Path(new_type.clone()); |
792 | } |
793 | } |
794 | } |
795 | } |
796 | } |
797 | |
798 | // A visitor struct that replace an async block by its patched version |
799 | struct AsyncTraitBlockReplacer<'a> { |
800 | block: &'a Block, |
801 | patched_block: Block, |
802 | } |
803 | |
804 | impl<'a> VisitMut for AsyncTraitBlockReplacer<'a> { |
805 | fn visit_block_mut(&mut self, i: &mut Block) { |
806 | if i == self.block { |
807 | *i = self.patched_block.clone(); |
808 | } |
809 | } |
810 | } |
811 | |
812 | // Replaces any `impl Trait` with `_` so it can be used as the type in |
813 | // a `let` statement's LHS. |
814 | struct ImplTraitEraser; |
815 | |
816 | impl VisitMut for ImplTraitEraser { |
817 | fn visit_type_mut(&mut self, t: &mut Type) { |
818 | if let Type::ImplTrait(..) = t { |
819 | *t = syn::TypeInfer { |
820 | underscore_token: Token), |
821 | } |
822 | .into(); |
823 | } else { |
824 | syn::visit_mut::visit_type_mut(self, node:t); |
825 | } |
826 | } |
827 | } |
828 | |
829 | fn erase_impl_trait(ty: &Type) -> Type { |
830 | let mut ty: Type = ty.clone(); |
831 | ImplTraitEraser.visit_type_mut(&mut ty); |
832 | ty |
833 | } |
834 | |