| 1 | use std::iter; |
| 2 | |
| 3 | use proc_macro2::TokenStream; |
| 4 | use quote::{quote, quote_spanned, ToTokens}; |
| 5 | use syn::visit_mut::VisitMut; |
| 6 | use syn::{ |
| 7 | punctuated::Punctuated, spanned::Spanned, Block, Expr, ExprAsync, ExprCall, FieldPat, FnArg, |
| 8 | Ident, Item, ItemFn, Pat, PatIdent, PatReference, PatStruct, PatTuple, PatTupleStruct, PatType, |
| 9 | Path, ReturnType, Signature, Stmt, Token, Type, TypePath, |
| 10 | }; |
| 11 | |
| 12 | use crate::{ |
| 13 | attr::{Field, Fields, FormatMode, InstrumentArgs, Level}, |
| 14 | MaybeItemFn, MaybeItemFnRef, |
| 15 | }; |
| 16 | |
| 17 | /// Given an existing function, generate an instrumented version of that function |
| 18 | pub(crate) fn gen_function<'a, B: ToTokens + 'a>( |
| 19 | input: MaybeItemFnRef<'a, B>, |
| 20 | args: InstrumentArgs, |
| 21 | instrumented_function_name: &str, |
| 22 | self_type: Option<&TypePath>, |
| 23 | ) -> proc_macro2::TokenStream { |
| 24 | // these are needed ahead of time, as ItemFn contains the function body _and_ |
| 25 | // isn't representable inside a quote!/quote_spanned! macro |
| 26 | // (Syn's ToTokens isn't implemented for ItemFn) |
| 27 | let MaybeItemFnRef { |
| 28 | outer_attrs, |
| 29 | inner_attrs, |
| 30 | vis, |
| 31 | sig, |
| 32 | block, |
| 33 | } = input; |
| 34 | |
| 35 | let Signature { |
| 36 | output, |
| 37 | inputs: params, |
| 38 | unsafety, |
| 39 | asyncness, |
| 40 | constness, |
| 41 | abi, |
| 42 | ident, |
| 43 | generics: |
| 44 | syn::Generics { |
| 45 | params: gen_params, |
| 46 | where_clause, |
| 47 | .. |
| 48 | }, |
| 49 | .. |
| 50 | } = sig; |
| 51 | |
| 52 | let warnings = args.warnings(); |
| 53 | |
| 54 | let (return_type, return_span) = if let ReturnType::Type(_, return_type) = &output { |
| 55 | (erase_impl_trait(return_type), return_type.span()) |
| 56 | } else { |
| 57 | // Point at function name if we don't have an explicit return type |
| 58 | (syn::parse_quote! { () }, ident.span()) |
| 59 | }; |
| 60 | // Install a fake return statement as the first thing in the function |
| 61 | // body, so that we eagerly infer that the return type is what we |
| 62 | // declared in the async fn signature. |
| 63 | // The `#[allow(..)]` is given because the return statement is |
| 64 | // unreachable, but does affect inference, so it needs to be written |
| 65 | // exactly that way for it to do its magic. |
| 66 | let fake_return_edge = quote_spanned! {return_span=> |
| 67 | #[allow( |
| 68 | unknown_lints, unreachable_code, clippy::diverging_sub_expression, |
| 69 | clippy::let_unit_value, clippy::unreachable, clippy::let_with_type_underscore, |
| 70 | clippy::empty_loop |
| 71 | )] |
| 72 | if false { |
| 73 | let __tracing_attr_fake_return: #return_type = loop {}; |
| 74 | return __tracing_attr_fake_return; |
| 75 | } |
| 76 | }; |
| 77 | let block = quote! { |
| 78 | { |
| 79 | #fake_return_edge |
| 80 | #block |
| 81 | } |
| 82 | }; |
| 83 | |
| 84 | let body = gen_block( |
| 85 | &block, |
| 86 | params, |
| 87 | asyncness.is_some(), |
| 88 | args, |
| 89 | instrumented_function_name, |
| 90 | self_type, |
| 91 | ); |
| 92 | |
| 93 | quote!( |
| 94 | #(#outer_attrs) * |
| 95 | #vis #constness #asyncness #unsafety #abi fn #ident<#gen_params>(#params) #output |
| 96 | #where_clause |
| 97 | { |
| 98 | #(#inner_attrs) * |
| 99 | #warnings |
| 100 | #body |
| 101 | } |
| 102 | ) |
| 103 | } |
| 104 | |
| 105 | /// Instrument a block |
| 106 | fn gen_block<B: ToTokens>( |
| 107 | block: &B, |
| 108 | params: &Punctuated<FnArg, Token![,]>, |
| 109 | async_context: bool, |
| 110 | mut args: InstrumentArgs, |
| 111 | instrumented_function_name: &str, |
| 112 | self_type: Option<&TypePath>, |
| 113 | ) -> proc_macro2::TokenStream { |
| 114 | // generate the span's name |
| 115 | let span_name = args |
| 116 | // did the user override the span's name? |
| 117 | .name |
| 118 | .as_ref() |
| 119 | .map(|name| quote!(#name)) |
| 120 | .unwrap_or_else(|| quote!(#instrumented_function_name)); |
| 121 | |
| 122 | let args_level = args.level(); |
| 123 | let level = args_level.clone(); |
| 124 | |
| 125 | let follows_from = args.follows_from.iter(); |
| 126 | let follows_from = quote! { |
| 127 | #(for cause in #follows_from { |
| 128 | __tracing_attr_span.follows_from(cause); |
| 129 | })* |
| 130 | }; |
| 131 | |
| 132 | // generate this inside a closure, so we can return early on errors. |
| 133 | let span = (|| { |
| 134 | // Pull out the arguments-to-be-skipped first, so we can filter results |
| 135 | // below. |
| 136 | let param_names: Vec<(Ident, (Ident, RecordType))> = params |
| 137 | .clone() |
| 138 | .into_iter() |
| 139 | .flat_map(|param| match param { |
| 140 | FnArg::Typed(PatType { pat, ty, .. }) => { |
| 141 | param_names(*pat, RecordType::parse_from_ty(&ty)) |
| 142 | } |
| 143 | FnArg::Receiver(_) => Box::new(iter::once(( |
| 144 | Ident::new("self" , param.span()), |
| 145 | RecordType::Debug, |
| 146 | ))), |
| 147 | }) |
| 148 | // Little dance with new (user-exposed) names and old (internal) |
| 149 | // names of identifiers. That way, we could do the following |
| 150 | // even though async_trait (<=0.1.43) rewrites "self" as "_self": |
| 151 | // ``` |
| 152 | // #[async_trait] |
| 153 | // impl Foo for FooImpl { |
| 154 | // #[instrument(skip(self))] |
| 155 | // async fn foo(&self, v: usize) {} |
| 156 | // } |
| 157 | // ``` |
| 158 | .map(|(x, record_type)| { |
| 159 | // if we are inside a function generated by async-trait <=0.1.43, we need to |
| 160 | // take care to rewrite "_self" as "self" for 'user convenience' |
| 161 | if self_type.is_some() && x == "_self" { |
| 162 | (Ident::new("self" , x.span()), (x, record_type)) |
| 163 | } else { |
| 164 | (x.clone(), (x, record_type)) |
| 165 | } |
| 166 | }) |
| 167 | .collect(); |
| 168 | |
| 169 | for skip in &args.skips { |
| 170 | if !param_names.iter().map(|(user, _)| user).any(|y| y == skip) { |
| 171 | return quote_spanned! {skip.span()=> |
| 172 | compile_error!("attempting to skip non-existent parameter" ) |
| 173 | }; |
| 174 | } |
| 175 | } |
| 176 | |
| 177 | let target = args.target(); |
| 178 | |
| 179 | let parent = args.parent.iter(); |
| 180 | |
| 181 | // filter out skipped fields |
| 182 | let quoted_fields: Vec<_> = param_names |
| 183 | .iter() |
| 184 | .filter(|(param, _)| { |
| 185 | if args.skip_all || args.skips.contains(param) { |
| 186 | return false; |
| 187 | } |
| 188 | |
| 189 | // If any parameters have the same name as a custom field, skip |
| 190 | // and allow them to be formatted by the custom field. |
| 191 | if let Some(ref fields) = args.fields { |
| 192 | fields.0.iter().all(|Field { ref name, .. }| { |
| 193 | let first = name.first(); |
| 194 | first != name.last() || !first.iter().any(|name| name == ¶m) |
| 195 | }) |
| 196 | } else { |
| 197 | true |
| 198 | } |
| 199 | }) |
| 200 | .map(|(user_name, (real_name, record_type))| match record_type { |
| 201 | RecordType::Value => quote!(#user_name = #real_name), |
| 202 | RecordType::Debug => quote!(#user_name = tracing::field::debug(&#real_name)), |
| 203 | }) |
| 204 | .collect(); |
| 205 | |
| 206 | // replace every use of a variable with its original name |
| 207 | if let Some(Fields(ref mut fields)) = args.fields { |
| 208 | let mut replacer = IdentAndTypesRenamer { |
| 209 | idents: param_names.into_iter().map(|(a, (b, _))| (a, b)).collect(), |
| 210 | types: Vec::new(), |
| 211 | }; |
| 212 | |
| 213 | // when async-trait <=0.1.43 is in use, replace instances |
| 214 | // of the "Self" type inside the fields values |
| 215 | if let Some(self_type) = self_type { |
| 216 | replacer.types.push(("Self" , self_type.clone())); |
| 217 | } |
| 218 | |
| 219 | for e in fields.iter_mut().filter_map(|f| f.value.as_mut()) { |
| 220 | syn::visit_mut::visit_expr_mut(&mut replacer, e); |
| 221 | } |
| 222 | } |
| 223 | |
| 224 | let custom_fields = &args.fields; |
| 225 | |
| 226 | quote!(tracing::span!( |
| 227 | target: #target, |
| 228 | #(parent: #parent,)* |
| 229 | #level, |
| 230 | #span_name, |
| 231 | #(#quoted_fields,)* |
| 232 | #custom_fields |
| 233 | |
| 234 | )) |
| 235 | })(); |
| 236 | |
| 237 | let target = args.target(); |
| 238 | |
| 239 | let err_event = match args.err_args { |
| 240 | Some(event_args) => { |
| 241 | let level_tokens = event_args.level(Level::Error); |
| 242 | match event_args.mode { |
| 243 | FormatMode::Default | FormatMode::Display => Some(quote!( |
| 244 | tracing::event!(target: #target, #level_tokens, error = %e) |
| 245 | )), |
| 246 | FormatMode::Debug => Some(quote!( |
| 247 | tracing::event!(target: #target, #level_tokens, error = ?e) |
| 248 | )), |
| 249 | } |
| 250 | } |
| 251 | _ => None, |
| 252 | }; |
| 253 | |
| 254 | let ret_event = match args.ret_args { |
| 255 | Some(event_args) => { |
| 256 | let level_tokens = event_args.level(args_level); |
| 257 | match event_args.mode { |
| 258 | FormatMode::Display => Some(quote!( |
| 259 | tracing::event!(target: #target, #level_tokens, return = %x) |
| 260 | )), |
| 261 | FormatMode::Default | FormatMode::Debug => Some(quote!( |
| 262 | tracing::event!(target: #target, #level_tokens, return = ?x) |
| 263 | )), |
| 264 | } |
| 265 | } |
| 266 | _ => None, |
| 267 | }; |
| 268 | |
| 269 | // Generate the instrumented function body. |
| 270 | // If the function is an `async fn`, this will wrap it in an async block, |
| 271 | // which is `instrument`ed using `tracing-futures`. Otherwise, this will |
| 272 | // enter the span and then perform the rest of the body. |
| 273 | // If `err` is in args, instrument any resulting `Err`s. |
| 274 | // If `ret` is in args, instrument any resulting `Ok`s when the function |
| 275 | // returns `Result`s, otherwise instrument any resulting values. |
| 276 | if async_context { |
| 277 | let mk_fut = match (err_event, ret_event) { |
| 278 | (Some(err_event), Some(ret_event)) => quote_spanned!(block.span()=> |
| 279 | async move { |
| 280 | let __match_scrutinee = async move #block.await; |
| 281 | match __match_scrutinee { |
| 282 | #[allow(clippy::unit_arg)] |
| 283 | Ok(x) => { |
| 284 | #ret_event; |
| 285 | Ok(x) |
| 286 | }, |
| 287 | Err(e) => { |
| 288 | #err_event; |
| 289 | Err(e) |
| 290 | } |
| 291 | } |
| 292 | } |
| 293 | ), |
| 294 | (Some(err_event), None) => quote_spanned!(block.span()=> |
| 295 | async move { |
| 296 | match async move #block.await { |
| 297 | #[allow(clippy::unit_arg)] |
| 298 | Ok(x) => Ok(x), |
| 299 | Err(e) => { |
| 300 | #err_event; |
| 301 | Err(e) |
| 302 | } |
| 303 | } |
| 304 | } |
| 305 | ), |
| 306 | (None, Some(ret_event)) => quote_spanned!(block.span()=> |
| 307 | async move { |
| 308 | let x = async move #block.await; |
| 309 | #ret_event; |
| 310 | x |
| 311 | } |
| 312 | ), |
| 313 | (None, None) => quote_spanned!(block.span()=> |
| 314 | async move #block |
| 315 | ), |
| 316 | }; |
| 317 | |
| 318 | return quote!( |
| 319 | let __tracing_attr_span = #span; |
| 320 | let __tracing_instrument_future = #mk_fut; |
| 321 | if !__tracing_attr_span.is_disabled() { |
| 322 | #follows_from |
| 323 | tracing::Instrument::instrument( |
| 324 | __tracing_instrument_future, |
| 325 | __tracing_attr_span |
| 326 | ) |
| 327 | .await |
| 328 | } else { |
| 329 | __tracing_instrument_future.await |
| 330 | } |
| 331 | ); |
| 332 | } |
| 333 | |
| 334 | let span = quote!( |
| 335 | // These variables are left uninitialized and initialized only |
| 336 | // if the tracing level is statically enabled at this point. |
| 337 | // While the tracing level is also checked at span creation |
| 338 | // time, that will still create a dummy span, and a dummy guard |
| 339 | // and drop the dummy guard later. By lazily initializing these |
| 340 | // variables, Rust will generate a drop flag for them and thus |
| 341 | // only drop the guard if it was created. This creates code that |
| 342 | // is very straightforward for LLVM to optimize out if the tracing |
| 343 | // level is statically disabled, while not causing any performance |
| 344 | // regression in case the level is enabled. |
| 345 | let __tracing_attr_span; |
| 346 | let __tracing_attr_guard; |
| 347 | if tracing::level_enabled!(#level) || tracing::if_log_enabled!(#level, {true} else {false}) { |
| 348 | __tracing_attr_span = #span; |
| 349 | #follows_from |
| 350 | __tracing_attr_guard = __tracing_attr_span.enter(); |
| 351 | } |
| 352 | ); |
| 353 | |
| 354 | match (err_event, ret_event) { |
| 355 | (Some(err_event), Some(ret_event)) => quote_spanned! {block.span()=> |
| 356 | #span |
| 357 | #[allow(clippy::redundant_closure_call)] |
| 358 | match (move || #block)() { |
| 359 | #[allow(clippy::unit_arg)] |
| 360 | Ok(x) => { |
| 361 | #ret_event; |
| 362 | Ok(x) |
| 363 | }, |
| 364 | Err(e) => { |
| 365 | #err_event; |
| 366 | Err(e) |
| 367 | } |
| 368 | } |
| 369 | }, |
| 370 | (Some(err_event), None) => quote_spanned!(block.span()=> |
| 371 | #span |
| 372 | #[allow(clippy::redundant_closure_call)] |
| 373 | match (move || #block)() { |
| 374 | #[allow(clippy::unit_arg)] |
| 375 | Ok(x) => Ok(x), |
| 376 | Err(e) => { |
| 377 | #err_event; |
| 378 | Err(e) |
| 379 | } |
| 380 | } |
| 381 | ), |
| 382 | (None, Some(ret_event)) => quote_spanned!(block.span()=> |
| 383 | #span |
| 384 | #[allow(clippy::redundant_closure_call)] |
| 385 | let x = (move || #block)(); |
| 386 | #ret_event; |
| 387 | x |
| 388 | ), |
| 389 | (None, None) => quote_spanned!(block.span() => |
| 390 | // Because `quote` produces a stream of tokens _without_ whitespace, the |
| 391 | // `if` and the block will appear directly next to each other. This |
| 392 | // generates a clippy lint about suspicious `if/else` formatting. |
| 393 | // Therefore, suppress the lint inside the generated code... |
| 394 | #[allow(clippy::suspicious_else_formatting)] |
| 395 | { |
| 396 | #span |
| 397 | // ...but turn the lint back on inside the function body. |
| 398 | #[warn(clippy::suspicious_else_formatting)] |
| 399 | #block |
| 400 | } |
| 401 | ), |
| 402 | } |
| 403 | } |
| 404 | |
/// How a parameter's value is attached to the span as a field.
enum RecordType {
    /// Record the value directly through its `Value` implementation.
    Value,
    /// Record the value by wrapping it in `tracing::field::debug()`.
    Debug,
}
| 412 | |
| 413 | impl RecordType { |
| 414 | /// Array of primitive types which should be recorded as [RecordType::Value]. |
| 415 | const TYPES_FOR_VALUE: &'static [&'static str] = &[ |
| 416 | "bool" , |
| 417 | "str" , |
| 418 | "u8" , |
| 419 | "i8" , |
| 420 | "u16" , |
| 421 | "i16" , |
| 422 | "u32" , |
| 423 | "i32" , |
| 424 | "u64" , |
| 425 | "i64" , |
| 426 | "u128" , |
| 427 | "i128" , |
| 428 | "f32" , |
| 429 | "f64" , |
| 430 | "usize" , |
| 431 | "isize" , |
| 432 | "String" , |
| 433 | "NonZeroU8" , |
| 434 | "NonZeroI8" , |
| 435 | "NonZeroU16" , |
| 436 | "NonZeroI16" , |
| 437 | "NonZeroU32" , |
| 438 | "NonZeroI32" , |
| 439 | "NonZeroU64" , |
| 440 | "NonZeroI64" , |
| 441 | "NonZeroU128" , |
| 442 | "NonZeroI128" , |
| 443 | "NonZeroUsize" , |
| 444 | "NonZeroIsize" , |
| 445 | "Wrapping" , |
| 446 | ]; |
| 447 | |
| 448 | /// Parse `RecordType` from [Type] by looking up |
| 449 | /// the [RecordType::TYPES_FOR_VALUE] array. |
| 450 | fn parse_from_ty(ty: &Type) -> Self { |
| 451 | match ty { |
| 452 | Type::Path(TypePath { path, .. }) |
| 453 | if path |
| 454 | .segments |
| 455 | .iter() |
| 456 | .last() |
| 457 | .map(|path_segment| { |
| 458 | let ident = path_segment.ident.to_string(); |
| 459 | Self::TYPES_FOR_VALUE.iter().any(|&t| t == ident) |
| 460 | }) |
| 461 | .unwrap_or(false) => |
| 462 | { |
| 463 | RecordType::Value |
| 464 | } |
| 465 | Type::Reference(syn::TypeReference { elem, .. }) => RecordType::parse_from_ty(elem), |
| 466 | _ => RecordType::Debug, |
| 467 | } |
| 468 | } |
| 469 | } |
| 470 | |
| 471 | fn param_names(pat: Pat, record_type: RecordType) -> Box<dyn Iterator<Item = (Ident, RecordType)>> { |
| 472 | match pat { |
| 473 | Pat::Ident(PatIdent { ident, .. }) => Box::new(iter::once((ident, record_type))), |
| 474 | Pat::Reference(PatReference { pat, .. }) => param_names(*pat, record_type), |
| 475 | // We can't get the concrete type of fields in the struct/tuple |
| 476 | // patterns by using `syn`. e.g. `fn foo(Foo { x, y }: Foo) {}`. |
| 477 | // Therefore, the struct/tuple patterns in the arguments will just |
| 478 | // always be recorded as `RecordType::Debug`. |
| 479 | Pat::Struct(PatStruct { fields, .. }) => Box::new( |
| 480 | fields |
| 481 | .into_iter() |
| 482 | .flat_map(|FieldPat { pat, .. }| param_names(*pat, RecordType::Debug)), |
| 483 | ), |
| 484 | Pat::Tuple(PatTuple { elems, .. }) => Box::new( |
| 485 | elems |
| 486 | .into_iter() |
| 487 | .flat_map(|p| param_names(p, RecordType::Debug)), |
| 488 | ), |
| 489 | Pat::TupleStruct(PatTupleStruct { elems, .. }) => Box::new( |
| 490 | elems |
| 491 | .into_iter() |
| 492 | .flat_map(|p| param_names(p, RecordType::Debug)), |
| 493 | ), |
| 494 | |
| 495 | // The above *should* cover all cases of irrefutable patterns, |
| 496 | // but we purposefully don't do any funny business here |
| 497 | // (such as panicking) because that would obscure rustc's |
| 498 | // much more informative error message. |
| 499 | _ => Box::new(iter::empty()), |
| 500 | } |
| 501 | } |
| 502 | |
| 503 | /// The specific async code pattern that was detected |
| 504 | enum AsyncKind<'a> { |
| 505 | /// Immediately-invoked async fn, as generated by `async-trait <= 0.1.43`: |
| 506 | /// `async fn foo<...>(...) {...}; Box::pin(foo<...>(...))` |
| 507 | Function(&'a ItemFn), |
| 508 | /// A function returning an async (move) block, optionally `Box::pin`-ed, |
| 509 | /// as generated by `async-trait >= 0.1.44`: |
| 510 | /// `Box::pin(async move { ... })` |
| 511 | Async { |
| 512 | async_expr: &'a ExprAsync, |
| 513 | pinned_box: bool, |
| 514 | }, |
| 515 | } |
| 516 | |
| 517 | pub(crate) struct AsyncInfo<'block> { |
| 518 | // statement that must be patched |
| 519 | source_stmt: &'block Stmt, |
| 520 | kind: AsyncKind<'block>, |
| 521 | self_type: Option<TypePath>, |
| 522 | input: &'block ItemFn, |
| 523 | } |
| 524 | |
| 525 | impl<'block> AsyncInfo<'block> { |
| 526 | /// Get the AST of the inner function we need to hook, if it looks like a |
| 527 | /// manual future implementation. |
| 528 | /// |
| 529 | /// When we are given a function that returns a (pinned) future containing the |
| 530 | /// user logic, it is that (pinned) future that needs to be instrumented. |
| 531 | /// Were we to instrument its parent, we would only collect information |
| 532 | /// regarding the allocation of that future, and not its own span of execution. |
| 533 | /// |
| 534 | /// We inspect the block of the function to find if it matches any of the |
| 535 | /// following patterns: |
| 536 | /// |
| 537 | /// - Immediately-invoked async fn, as generated by `async-trait <= 0.1.43`: |
| 538 | /// `async fn foo<...>(...) {...}; Box::pin(foo<...>(...))` |
| 539 | /// |
| 540 | /// - A function returning an async (move) block, optionally `Box::pin`-ed, |
| 541 | /// as generated by `async-trait >= 0.1.44`: |
| 542 | /// `Box::pin(async move { ... })` |
| 543 | /// |
| 544 | /// We the return the statement that must be instrumented, along with some |
| 545 | /// other information. |
| 546 | /// 'gen_body' will then be able to use that information to instrument the |
| 547 | /// proper function/future. |
| 548 | /// |
| 549 | /// (this follows the approach suggested in |
| 550 | /// https://github.com/dtolnay/async-trait/issues/45#issuecomment-571245673) |
| 551 | pub(crate) fn from_fn(input: &'block ItemFn) -> Option<Self> { |
| 552 | // are we in an async context? If yes, this isn't a manual async-like pattern |
| 553 | if input.sig.asyncness.is_some() { |
| 554 | return None; |
| 555 | } |
| 556 | |
| 557 | let block = &input.block; |
| 558 | |
| 559 | // list of async functions declared inside the block |
| 560 | let inside_funs = block.stmts.iter().filter_map(|stmt| { |
| 561 | if let Stmt::Item(Item::Fn(fun)) = &stmt { |
| 562 | // If the function is async, this is a candidate |
| 563 | if fun.sig.asyncness.is_some() { |
| 564 | return Some((stmt, fun)); |
| 565 | } |
| 566 | } |
| 567 | None |
| 568 | }); |
| 569 | |
| 570 | // last expression of the block: it determines the return value of the |
| 571 | // block, this is quite likely a `Box::pin` statement or an async block |
| 572 | let (last_expr_stmt, last_expr) = block.stmts.iter().rev().find_map(|stmt| { |
| 573 | if let Stmt::Expr(expr, _semi) = stmt { |
| 574 | Some((stmt, expr)) |
| 575 | } else { |
| 576 | None |
| 577 | } |
| 578 | })?; |
| 579 | |
| 580 | // is the last expression an async block? |
| 581 | if let Expr::Async(async_expr) = last_expr { |
| 582 | return Some(AsyncInfo { |
| 583 | source_stmt: last_expr_stmt, |
| 584 | kind: AsyncKind::Async { |
| 585 | async_expr, |
| 586 | pinned_box: false, |
| 587 | }, |
| 588 | self_type: None, |
| 589 | input, |
| 590 | }); |
| 591 | } |
| 592 | |
| 593 | // is the last expression a function call? |
| 594 | let (outside_func, outside_args) = match last_expr { |
| 595 | Expr::Call(ExprCall { func, args, .. }) => (func, args), |
| 596 | _ => return None, |
| 597 | }; |
| 598 | |
| 599 | // is it a call to `Box::pin()`? |
| 600 | let path = match outside_func.as_ref() { |
| 601 | Expr::Path(path) => &path.path, |
| 602 | _ => return None, |
| 603 | }; |
| 604 | if !path_to_string(path).ends_with("Box::pin" ) { |
| 605 | return None; |
| 606 | } |
| 607 | |
| 608 | // Does the call take an argument? If it doesn't, |
| 609 | // it's not gonna compile anyway, but that's no reason |
| 610 | // to (try to) perform an out of bounds access |
| 611 | if outside_args.is_empty() { |
| 612 | return None; |
| 613 | } |
| 614 | |
| 615 | // Is the argument to Box::pin an async block that |
| 616 | // captures its arguments? |
| 617 | if let Expr::Async(async_expr) = &outside_args[0] { |
| 618 | return Some(AsyncInfo { |
| 619 | source_stmt: last_expr_stmt, |
| 620 | kind: AsyncKind::Async { |
| 621 | async_expr, |
| 622 | pinned_box: true, |
| 623 | }, |
| 624 | self_type: None, |
| 625 | input, |
| 626 | }); |
| 627 | } |
| 628 | |
| 629 | // Is the argument to Box::pin a function call itself? |
| 630 | let func = match &outside_args[0] { |
| 631 | Expr::Call(ExprCall { func, .. }) => func, |
| 632 | _ => return None, |
| 633 | }; |
| 634 | |
| 635 | // "stringify" the path of the function called |
| 636 | let func_name = match **func { |
| 637 | Expr::Path(ref func_path) => path_to_string(&func_path.path), |
| 638 | _ => return None, |
| 639 | }; |
| 640 | |
| 641 | // Was that function defined inside of the current block? |
| 642 | // If so, retrieve the statement where it was declared and the function itself |
| 643 | let (stmt_func_declaration, func) = inside_funs |
| 644 | .into_iter() |
| 645 | .find(|(_, fun)| fun.sig.ident == func_name)?; |
| 646 | |
| 647 | // If "_self" is present as an argument, we store its type to be able to rewrite "Self" (the |
| 648 | // parameter type) with the type of "_self" |
| 649 | let mut self_type = None; |
| 650 | for arg in &func.sig.inputs { |
| 651 | if let FnArg::Typed(ty) = arg { |
| 652 | if let Pat::Ident(PatIdent { ref ident, .. }) = *ty.pat { |
| 653 | if ident == "_self" { |
| 654 | let mut ty = *ty.ty.clone(); |
| 655 | // extract the inner type if the argument is "&self" or "&mut self" |
| 656 | if let Type::Reference(syn::TypeReference { elem, .. }) = ty { |
| 657 | ty = *elem; |
| 658 | } |
| 659 | |
| 660 | if let Type::Path(tp) = ty { |
| 661 | self_type = Some(tp); |
| 662 | break; |
| 663 | } |
| 664 | } |
| 665 | } |
| 666 | } |
| 667 | } |
| 668 | |
| 669 | Some(AsyncInfo { |
| 670 | source_stmt: stmt_func_declaration, |
| 671 | kind: AsyncKind::Function(func), |
| 672 | self_type, |
| 673 | input, |
| 674 | }) |
| 675 | } |
| 676 | |
| 677 | pub(crate) fn gen_async( |
| 678 | self, |
| 679 | args: InstrumentArgs, |
| 680 | instrumented_function_name: &str, |
| 681 | ) -> Result<proc_macro::TokenStream, syn::Error> { |
| 682 | // let's rewrite some statements! |
| 683 | let mut out_stmts: Vec<TokenStream> = self |
| 684 | .input |
| 685 | .block |
| 686 | .stmts |
| 687 | .iter() |
| 688 | .map(|stmt| stmt.to_token_stream()) |
| 689 | .collect(); |
| 690 | |
| 691 | if let Some((iter, _stmt)) = self |
| 692 | .input |
| 693 | .block |
| 694 | .stmts |
| 695 | .iter() |
| 696 | .enumerate() |
| 697 | .find(|(_iter, stmt)| *stmt == self.source_stmt) |
| 698 | { |
| 699 | // instrument the future by rewriting the corresponding statement |
| 700 | out_stmts[iter] = match self.kind { |
| 701 | // `Box::pin(immediately_invoked_async_fn())` |
| 702 | AsyncKind::Function(fun) => { |
| 703 | let fun = MaybeItemFn::from(fun.clone()); |
| 704 | gen_function( |
| 705 | fun.as_ref(), |
| 706 | args, |
| 707 | instrumented_function_name, |
| 708 | self.self_type.as_ref(), |
| 709 | ) |
| 710 | } |
| 711 | // `async move { ... }`, optionally pinned |
| 712 | AsyncKind::Async { |
| 713 | async_expr, |
| 714 | pinned_box, |
| 715 | } => { |
| 716 | let instrumented_block = gen_block( |
| 717 | &async_expr.block, |
| 718 | &self.input.sig.inputs, |
| 719 | true, |
| 720 | args, |
| 721 | instrumented_function_name, |
| 722 | None, |
| 723 | ); |
| 724 | let async_attrs = &async_expr.attrs; |
| 725 | if pinned_box { |
| 726 | quote! { |
| 727 | Box::pin(#(#async_attrs) * async move { #instrumented_block }) |
| 728 | } |
| 729 | } else { |
| 730 | quote! { |
| 731 | #(#async_attrs) * async move { #instrumented_block } |
| 732 | } |
| 733 | } |
| 734 | } |
| 735 | }; |
| 736 | } |
| 737 | |
| 738 | let vis = &self.input.vis; |
| 739 | let sig = &self.input.sig; |
| 740 | let attrs = &self.input.attrs; |
| 741 | Ok(quote!( |
| 742 | #(#attrs) * |
| 743 | #vis #sig { |
| 744 | #(#out_stmts) * |
| 745 | } |
| 746 | ) |
| 747 | .into()) |
| 748 | } |
| 749 | } |
| 750 | |
| 751 | // Return a path as a String |
| 752 | fn path_to_string(path: &Path) -> String { |
| 753 | use std::fmt::Write; |
| 754 | // some heuristic to prevent too many allocations |
| 755 | let mut res: String = String::with_capacity(path.segments.len() * 5); |
| 756 | for i: usize in 0..path.segments.len() { |
| 757 | write!(&mut res, " {}" , path.segments[i].ident) |
| 758 | .expect(msg:"writing to a String should never fail" ); |
| 759 | if i < path.segments.len() - 1 { |
| 760 | res.push_str(string:"::" ); |
| 761 | } |
| 762 | } |
| 763 | res |
| 764 | } |
| 765 | |
| 766 | /// A visitor struct to replace idents and types in some piece |
| 767 | /// of code (e.g. the "self" and "Self" tokens in user-supplied |
| 768 | /// fields expressions when the function is generated by an old |
| 769 | /// version of async-trait). |
| 770 | struct IdentAndTypesRenamer<'a> { |
| 771 | types: Vec<(&'a str, TypePath)>, |
| 772 | idents: Vec<(Ident, Ident)>, |
| 773 | } |
| 774 | |
| 775 | impl<'a> VisitMut for IdentAndTypesRenamer<'a> { |
| 776 | // we deliberately compare strings because we want to ignore the spans |
| 777 | // If we apply clippy's lint, the behavior changes |
| 778 | #[allow (clippy::cmp_owned)] |
| 779 | fn visit_ident_mut(&mut self, id: &mut Ident) { |
| 780 | for (old_ident: &Ident, new_ident: &Ident) in &self.idents { |
| 781 | if id.to_string() == old_ident.to_string() { |
| 782 | *id = new_ident.clone(); |
| 783 | } |
| 784 | } |
| 785 | } |
| 786 | |
| 787 | fn visit_type_mut(&mut self, ty: &mut Type) { |
| 788 | for (type_name: &&str, new_type: &TypePath) in &self.types { |
| 789 | if let Type::Path(TypePath { path: &mut Path, .. }) = ty { |
| 790 | if path_to_string(path) == *type_name { |
| 791 | *ty = Type::Path(new_type.clone()); |
| 792 | } |
| 793 | } |
| 794 | } |
| 795 | } |
| 796 | } |
| 797 | |
| 798 | // A visitor struct that replace an async block by its patched version |
| 799 | struct AsyncTraitBlockReplacer<'a> { |
| 800 | block: &'a Block, |
| 801 | patched_block: Block, |
| 802 | } |
| 803 | |
| 804 | impl<'a> VisitMut for AsyncTraitBlockReplacer<'a> { |
| 805 | fn visit_block_mut(&mut self, i: &mut Block) { |
| 806 | if i == self.block { |
| 807 | *i = self.patched_block.clone(); |
| 808 | } |
| 809 | } |
| 810 | } |
| 811 | |
/// A visitor that turns every `impl Trait` type into the inferred type `_`,
/// so the result can be written as the annotated type of a `let` binding.
struct ImplTraitEraser;
| 815 | |
| 816 | impl VisitMut for ImplTraitEraser { |
| 817 | fn visit_type_mut(&mut self, t: &mut Type) { |
| 818 | if let Type::ImplTrait(..) = t { |
| 819 | *t = syn::TypeInfer { |
| 820 | underscore_token: Token), |
| 821 | } |
| 822 | .into(); |
| 823 | } else { |
| 824 | syn::visit_mut::visit_type_mut(self, node:t); |
| 825 | } |
| 826 | } |
| 827 | } |
| 828 | |
| 829 | fn erase_impl_trait(ty: &Type) -> Type { |
| 830 | let mut ty: Type = ty.clone(); |
| 831 | ImplTraitEraser.visit_type_mut(&mut ty); |
| 832 | ty |
| 833 | } |
| 834 | |