| 1 | use proc_macro2::{Span, TokenStream, TokenTree}; |
| 2 | use quote::{quote, quote_spanned, ToTokens}; |
| 3 | use syn::parse::{Parse, ParseStream, Parser}; |
| 4 | use syn::{braced, Attribute, Ident, Path, Signature, Visibility}; |
| 5 | |
| 6 | // syn::AttributeArgs does not implement syn::Parse |
| 7 | type AttributeArgs = syn::punctuated::Punctuated<syn::Meta, syn::Token![,]>; |
| 8 | |
/// Which Tokio runtime builder the generated code starts from.
// `Debug`/`Eq` are derived so the value can be printed and compared in
// assertions; equality is total for this fieldless enum.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum RuntimeFlavor {
    /// Built with `runtime::Builder::new_current_thread()`.
    CurrentThread,
    /// Built with `runtime::Builder::new_multi_thread()`.
    Threaded,
}
| 14 | |
| 15 | impl RuntimeFlavor { |
| 16 | fn from_str(s: &str) -> Result<RuntimeFlavor, String> { |
| 17 | match s { |
| 18 | "current_thread" => Ok(RuntimeFlavor::CurrentThread), |
| 19 | "multi_thread" => Ok(RuntimeFlavor::Threaded), |
| 20 | "single_thread" => Err("The single threaded runtime flavor is called `current_thread`." .to_string()), |
| 21 | "basic_scheduler" => Err("The `basic_scheduler` runtime flavor has been renamed to `current_thread`." .to_string()), |
| 22 | "threaded_scheduler" => Err("The `threaded_scheduler` runtime flavor has been renamed to `multi_thread`." .to_string()), |
| 23 | _ => Err(format!("No such runtime flavor ` {s}`. The runtime flavors are `current_thread` and `multi_thread`." )), |
| 24 | } |
| 25 | } |
| 26 | } |
| 27 | |
/// Value of the `unhandled_panic = "..."` attribute argument.
///
/// Each variant is emitted as the same-named variant of the target crate's
/// `runtime::UnhandledPanic`.
// `Debug`/`Eq` are derived so the value can be printed and compared in
// assertions; equality is total for this fieldless enum.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum UnhandledPanic {
    /// `"ignore"`
    Ignore,
    /// `"shutdown_runtime"`
    ShutdownRuntime,
}
| 33 | |
| 34 | impl UnhandledPanic { |
| 35 | fn from_str(s: &str) -> Result<UnhandledPanic, String> { |
| 36 | match s { |
| 37 | "ignore" => Ok(UnhandledPanic::Ignore), |
| 38 | "shutdown_runtime" => Ok(UnhandledPanic::ShutdownRuntime), |
| 39 | _ => Err(format!("No such unhandled panic behavior ` {s}`. The unhandled panic behaviors are `ignore` and `shutdown_runtime`." )), |
| 40 | } |
| 41 | } |
| 42 | |
| 43 | fn into_tokens(self, crate_path: &TokenStream) -> TokenStream { |
| 44 | match self { |
| 45 | UnhandledPanic::Ignore => quote! { #crate_path::runtime::UnhandledPanic::Ignore }, |
| 46 | UnhandledPanic::ShutdownRuntime => { |
| 47 | quote! { #crate_path::runtime::UnhandledPanic::ShutdownRuntime } |
| 48 | } |
| 49 | } |
| 50 | } |
| 51 | } |
| 52 | |
| 53 | struct FinalConfig { |
| 54 | flavor: RuntimeFlavor, |
| 55 | worker_threads: Option<usize>, |
| 56 | start_paused: Option<bool>, |
| 57 | crate_name: Option<Path>, |
| 58 | unhandled_panic: Option<UnhandledPanic>, |
| 59 | } |
| 60 | |
| 61 | /// Config used in case of the attribute not being able to build a valid config |
| 62 | const DEFAULT_ERROR_CONFIG: FinalConfig = FinalConfig { |
| 63 | flavor: RuntimeFlavor::CurrentThread, |
| 64 | worker_threads: None, |
| 65 | start_paused: None, |
| 66 | crate_name: None, |
| 67 | unhandled_panic: None, |
| 68 | }; |
| 69 | |
| 70 | struct Configuration { |
| 71 | rt_multi_thread_available: bool, |
| 72 | default_flavor: RuntimeFlavor, |
| 73 | flavor: Option<RuntimeFlavor>, |
| 74 | worker_threads: Option<(usize, Span)>, |
| 75 | start_paused: Option<(bool, Span)>, |
| 76 | is_test: bool, |
| 77 | crate_name: Option<Path>, |
| 78 | unhandled_panic: Option<(UnhandledPanic, Span)>, |
| 79 | } |
| 80 | |
| 81 | impl Configuration { |
| 82 | fn new(is_test: bool, rt_multi_thread: bool) -> Self { |
| 83 | Configuration { |
| 84 | rt_multi_thread_available: rt_multi_thread, |
| 85 | default_flavor: match is_test { |
| 86 | true => RuntimeFlavor::CurrentThread, |
| 87 | false => RuntimeFlavor::Threaded, |
| 88 | }, |
| 89 | flavor: None, |
| 90 | worker_threads: None, |
| 91 | start_paused: None, |
| 92 | is_test, |
| 93 | crate_name: None, |
| 94 | unhandled_panic: None, |
| 95 | } |
| 96 | } |
| 97 | |
| 98 | fn set_flavor(&mut self, runtime: syn::Lit, span: Span) -> Result<(), syn::Error> { |
| 99 | if self.flavor.is_some() { |
| 100 | return Err(syn::Error::new(span, "`flavor` set multiple times." )); |
| 101 | } |
| 102 | |
| 103 | let runtime_str = parse_string(runtime, span, "flavor" )?; |
| 104 | let runtime = |
| 105 | RuntimeFlavor::from_str(&runtime_str).map_err(|err| syn::Error::new(span, err))?; |
| 106 | self.flavor = Some(runtime); |
| 107 | Ok(()) |
| 108 | } |
| 109 | |
| 110 | fn set_worker_threads( |
| 111 | &mut self, |
| 112 | worker_threads: syn::Lit, |
| 113 | span: Span, |
| 114 | ) -> Result<(), syn::Error> { |
| 115 | if self.worker_threads.is_some() { |
| 116 | return Err(syn::Error::new( |
| 117 | span, |
| 118 | "`worker_threads` set multiple times." , |
| 119 | )); |
| 120 | } |
| 121 | |
| 122 | let worker_threads = parse_int(worker_threads, span, "worker_threads" )?; |
| 123 | if worker_threads == 0 { |
| 124 | return Err(syn::Error::new(span, "`worker_threads` may not be 0." )); |
| 125 | } |
| 126 | self.worker_threads = Some((worker_threads, span)); |
| 127 | Ok(()) |
| 128 | } |
| 129 | |
| 130 | fn set_start_paused(&mut self, start_paused: syn::Lit, span: Span) -> Result<(), syn::Error> { |
| 131 | if self.start_paused.is_some() { |
| 132 | return Err(syn::Error::new(span, "`start_paused` set multiple times." )); |
| 133 | } |
| 134 | |
| 135 | let start_paused = parse_bool(start_paused, span, "start_paused" )?; |
| 136 | self.start_paused = Some((start_paused, span)); |
| 137 | Ok(()) |
| 138 | } |
| 139 | |
| 140 | fn set_crate_name(&mut self, name: syn::Lit, span: Span) -> Result<(), syn::Error> { |
| 141 | if self.crate_name.is_some() { |
| 142 | return Err(syn::Error::new(span, "`crate` set multiple times." )); |
| 143 | } |
| 144 | let name_path = parse_path(name, span, "crate" )?; |
| 145 | self.crate_name = Some(name_path); |
| 146 | Ok(()) |
| 147 | } |
| 148 | |
| 149 | fn set_unhandled_panic( |
| 150 | &mut self, |
| 151 | unhandled_panic: syn::Lit, |
| 152 | span: Span, |
| 153 | ) -> Result<(), syn::Error> { |
| 154 | if self.unhandled_panic.is_some() { |
| 155 | return Err(syn::Error::new( |
| 156 | span, |
| 157 | "`unhandled_panic` set multiple times." , |
| 158 | )); |
| 159 | } |
| 160 | |
| 161 | let unhandled_panic = parse_string(unhandled_panic, span, "unhandled_panic" )?; |
| 162 | let unhandled_panic = |
| 163 | UnhandledPanic::from_str(&unhandled_panic).map_err(|err| syn::Error::new(span, err))?; |
| 164 | self.unhandled_panic = Some((unhandled_panic, span)); |
| 165 | Ok(()) |
| 166 | } |
| 167 | |
| 168 | fn macro_name(&self) -> &'static str { |
| 169 | if self.is_test { |
| 170 | "tokio::test" |
| 171 | } else { |
| 172 | "tokio::main" |
| 173 | } |
| 174 | } |
| 175 | |
| 176 | fn build(&self) -> Result<FinalConfig, syn::Error> { |
| 177 | use RuntimeFlavor as F; |
| 178 | |
| 179 | let flavor = self.flavor.unwrap_or(self.default_flavor); |
| 180 | let worker_threads = match (flavor, self.worker_threads) { |
| 181 | (F::CurrentThread, Some((_, worker_threads_span))) => { |
| 182 | let msg = format!( |
| 183 | "The `worker_threads` option requires the `multi_thread` runtime flavor. Use `#[ {}(flavor = \"multi_thread \")]`" , |
| 184 | self.macro_name(), |
| 185 | ); |
| 186 | return Err(syn::Error::new(worker_threads_span, msg)); |
| 187 | } |
| 188 | (F::CurrentThread, None) => None, |
| 189 | (F::Threaded, worker_threads) if self.rt_multi_thread_available => { |
| 190 | worker_threads.map(|(val, _span)| val) |
| 191 | } |
| 192 | (F::Threaded, _) => { |
| 193 | let msg = if self.flavor.is_none() { |
| 194 | "The default runtime flavor is `multi_thread`, but the `rt-multi-thread` feature is disabled." |
| 195 | } else { |
| 196 | "The runtime flavor `multi_thread` requires the `rt-multi-thread` feature." |
| 197 | }; |
| 198 | return Err(syn::Error::new(Span::call_site(), msg)); |
| 199 | } |
| 200 | }; |
| 201 | |
| 202 | let start_paused = match (flavor, self.start_paused) { |
| 203 | (F::Threaded, Some((_, start_paused_span))) => { |
| 204 | let msg = format!( |
| 205 | "The `start_paused` option requires the `current_thread` runtime flavor. Use `#[ {}(flavor = \"current_thread \")]`" , |
| 206 | self.macro_name(), |
| 207 | ); |
| 208 | return Err(syn::Error::new(start_paused_span, msg)); |
| 209 | } |
| 210 | (F::CurrentThread, Some((start_paused, _))) => Some(start_paused), |
| 211 | (_, None) => None, |
| 212 | }; |
| 213 | |
| 214 | let unhandled_panic = match (flavor, self.unhandled_panic) { |
| 215 | (F::Threaded, Some((_, unhandled_panic_span))) => { |
| 216 | let msg = format!( |
| 217 | "The `unhandled_panic` option requires the `current_thread` runtime flavor. Use `#[ {}(flavor = \"current_thread \")]`" , |
| 218 | self.macro_name(), |
| 219 | ); |
| 220 | return Err(syn::Error::new(unhandled_panic_span, msg)); |
| 221 | } |
| 222 | (F::CurrentThread, Some((unhandled_panic, _))) => Some(unhandled_panic), |
| 223 | (_, None) => None, |
| 224 | }; |
| 225 | |
| 226 | Ok(FinalConfig { |
| 227 | crate_name: self.crate_name.clone(), |
| 228 | flavor, |
| 229 | worker_threads, |
| 230 | start_paused, |
| 231 | unhandled_panic, |
| 232 | }) |
| 233 | } |
| 234 | } |
| 235 | |
| 236 | fn parse_int(int: syn::Lit, span: Span, field: &str) -> Result<usize, syn::Error> { |
| 237 | match int { |
| 238 | syn::Lit::Int(lit: LitInt) => match lit.base10_parse::<usize>() { |
| 239 | Ok(value: usize) => Ok(value), |
| 240 | Err(e: Error) => Err(syn::Error::new( |
| 241 | span, |
| 242 | message:format!("Failed to parse value of ` {field}` as integer: {e}" ), |
| 243 | )), |
| 244 | }, |
| 245 | _ => Err(syn::Error::new( |
| 246 | span, |
| 247 | message:format!("Failed to parse value of ` {field}` as integer." ), |
| 248 | )), |
| 249 | } |
| 250 | } |
| 251 | |
| 252 | fn parse_string(int: syn::Lit, span: Span, field: &str) -> Result<String, syn::Error> { |
| 253 | match int { |
| 254 | syn::Lit::Str(s: LitStr) => Ok(s.value()), |
| 255 | syn::Lit::Verbatim(s: Literal) => Ok(s.to_string()), |
| 256 | _ => Err(syn::Error::new( |
| 257 | span, |
| 258 | message:format!("Failed to parse value of ` {field}` as string." ), |
| 259 | )), |
| 260 | } |
| 261 | } |
| 262 | |
| 263 | fn parse_path(lit: syn::Lit, span: Span, field: &str) -> Result<Path, syn::Error> { |
| 264 | match lit { |
| 265 | syn::Lit::Str(s: LitStr) => { |
| 266 | let err: Error = syn::Error::new( |
| 267 | span, |
| 268 | message:format!( |
| 269 | "Failed to parse value of ` {}` as path: \"{}\"" , |
| 270 | field, |
| 271 | s.value() |
| 272 | ), |
| 273 | ); |
| 274 | s.parse::<syn::Path>().map_err(|_| err.clone()) |
| 275 | } |
| 276 | _ => Err(syn::Error::new( |
| 277 | span, |
| 278 | message:format!("Failed to parse value of ` {field}` as path." ), |
| 279 | )), |
| 280 | } |
| 281 | } |
| 282 | |
| 283 | fn parse_bool(bool: syn::Lit, span: Span, field: &str) -> Result<bool, syn::Error> { |
| 284 | match bool { |
| 285 | syn::Lit::Bool(b: LitBool) => Ok(b.value), |
| 286 | _ => Err(syn::Error::new( |
| 287 | span, |
| 288 | message:format!("Failed to parse value of ` {field}` as bool." ), |
| 289 | )), |
| 290 | } |
| 291 | } |
| 292 | |
| 293 | fn build_config( |
| 294 | input: &ItemFn, |
| 295 | args: AttributeArgs, |
| 296 | is_test: bool, |
| 297 | rt_multi_thread: bool, |
| 298 | ) -> Result<FinalConfig, syn::Error> { |
| 299 | if input.sig.asyncness.is_none() { |
| 300 | let msg = "the `async` keyword is missing from the function declaration" ; |
| 301 | return Err(syn::Error::new_spanned(input.sig.fn_token, msg)); |
| 302 | } |
| 303 | |
| 304 | let mut config = Configuration::new(is_test, rt_multi_thread); |
| 305 | let macro_name = config.macro_name(); |
| 306 | |
| 307 | for arg in args { |
| 308 | match arg { |
| 309 | syn::Meta::NameValue(namevalue) => { |
| 310 | let ident = namevalue |
| 311 | .path |
| 312 | .get_ident() |
| 313 | .ok_or_else(|| { |
| 314 | syn::Error::new_spanned(&namevalue, "Must have specified ident" ) |
| 315 | })? |
| 316 | .to_string() |
| 317 | .to_lowercase(); |
| 318 | let lit = match &namevalue.value { |
| 319 | syn::Expr::Lit(syn::ExprLit { lit, .. }) => lit, |
| 320 | expr => return Err(syn::Error::new_spanned(expr, "Must be a literal" )), |
| 321 | }; |
| 322 | match ident.as_str() { |
| 323 | "worker_threads" => { |
| 324 | config.set_worker_threads(lit.clone(), syn::spanned::Spanned::span(lit))?; |
| 325 | } |
| 326 | "flavor" => { |
| 327 | config.set_flavor(lit.clone(), syn::spanned::Spanned::span(lit))?; |
| 328 | } |
| 329 | "start_paused" => { |
| 330 | config.set_start_paused(lit.clone(), syn::spanned::Spanned::span(lit))?; |
| 331 | } |
| 332 | "core_threads" => { |
| 333 | let msg = "Attribute `core_threads` is renamed to `worker_threads`" ; |
| 334 | return Err(syn::Error::new_spanned(namevalue, msg)); |
| 335 | } |
| 336 | "crate" => { |
| 337 | config.set_crate_name(lit.clone(), syn::spanned::Spanned::span(lit))?; |
| 338 | } |
| 339 | "unhandled_panic" => { |
| 340 | config |
| 341 | .set_unhandled_panic(lit.clone(), syn::spanned::Spanned::span(lit))?; |
| 342 | } |
| 343 | name => { |
| 344 | let msg = format!( |
| 345 | "Unknown attribute {name} is specified; expected one of: `flavor`, `worker_threads`, `start_paused`, `crate`, `unhandled_panic`" , |
| 346 | ); |
| 347 | return Err(syn::Error::new_spanned(namevalue, msg)); |
| 348 | } |
| 349 | } |
| 350 | } |
| 351 | syn::Meta::Path(path) => { |
| 352 | let name = path |
| 353 | .get_ident() |
| 354 | .ok_or_else(|| syn::Error::new_spanned(&path, "Must have specified ident" ))? |
| 355 | .to_string() |
| 356 | .to_lowercase(); |
| 357 | let msg = match name.as_str() { |
| 358 | "threaded_scheduler" | "multi_thread" => { |
| 359 | format!( |
| 360 | "Set the runtime flavor with #[ {macro_name}(flavor = \"multi_thread \")]." |
| 361 | ) |
| 362 | } |
| 363 | "basic_scheduler" | "current_thread" | "single_threaded" => { |
| 364 | format!( |
| 365 | "Set the runtime flavor with #[ {macro_name}(flavor = \"current_thread \")]." |
| 366 | ) |
| 367 | } |
| 368 | "flavor" | "worker_threads" | "start_paused" | "crate" | "unhandled_panic" => { |
| 369 | format!("The ` {name}` attribute requires an argument." ) |
| 370 | } |
| 371 | name => { |
| 372 | format!("Unknown attribute {name} is specified; expected one of: `flavor`, `worker_threads`, `start_paused`, `crate`, `unhandled_panic`." ) |
| 373 | } |
| 374 | }; |
| 375 | return Err(syn::Error::new_spanned(path, msg)); |
| 376 | } |
| 377 | other => { |
| 378 | return Err(syn::Error::new_spanned( |
| 379 | other, |
| 380 | "Unknown attribute inside the macro" , |
| 381 | )); |
| 382 | } |
| 383 | } |
| 384 | } |
| 385 | |
| 386 | config.build() |
| 387 | } |
| 388 | |
| 389 | fn parse_knobs(mut input: ItemFn, is_test: bool, config: FinalConfig) -> TokenStream { |
| 390 | input.sig.asyncness = None; |
| 391 | |
| 392 | // If type mismatch occurs, the current rustc points to the last statement. |
| 393 | let (last_stmt_start_span, last_stmt_end_span) = { |
| 394 | let mut last_stmt = input.stmts.last().cloned().unwrap_or_default().into_iter(); |
| 395 | |
| 396 | // `Span` on stable Rust has a limitation that only points to the first |
| 397 | // token, not the whole tokens. We can work around this limitation by |
| 398 | // using the first/last span of the tokens like |
| 399 | // `syn::Error::new_spanned` does. |
| 400 | let start = last_stmt.next().map_or_else(Span::call_site, |t| t.span()); |
| 401 | let end = last_stmt.last().map_or(start, |t| t.span()); |
| 402 | (start, end) |
| 403 | }; |
| 404 | |
| 405 | let crate_path = config |
| 406 | .crate_name |
| 407 | .map(ToTokens::into_token_stream) |
| 408 | .unwrap_or_else(|| Ident::new("tokio" , last_stmt_start_span).into_token_stream()); |
| 409 | |
| 410 | let mut rt = match config.flavor { |
| 411 | RuntimeFlavor::CurrentThread => quote_spanned! {last_stmt_start_span=> |
| 412 | #crate_path::runtime::Builder::new_current_thread() |
| 413 | }, |
| 414 | RuntimeFlavor::Threaded => quote_spanned! {last_stmt_start_span=> |
| 415 | #crate_path::runtime::Builder::new_multi_thread() |
| 416 | }, |
| 417 | }; |
| 418 | if let Some(v) = config.worker_threads { |
| 419 | rt = quote_spanned! {last_stmt_start_span=> #rt.worker_threads(#v) }; |
| 420 | } |
| 421 | if let Some(v) = config.start_paused { |
| 422 | rt = quote_spanned! {last_stmt_start_span=> #rt.start_paused(#v) }; |
| 423 | } |
| 424 | if let Some(v) = config.unhandled_panic { |
| 425 | let unhandled_panic = v.into_tokens(&crate_path); |
| 426 | rt = quote_spanned! {last_stmt_start_span=> #rt.unhandled_panic(#unhandled_panic) }; |
| 427 | } |
| 428 | |
| 429 | let generated_attrs = if is_test { |
| 430 | quote! { |
| 431 | #[::core::prelude::v1::test] |
| 432 | } |
| 433 | } else { |
| 434 | quote! {} |
| 435 | }; |
| 436 | |
| 437 | let body_ident = quote! { body }; |
| 438 | // This explicit `return` is intentional. See tokio-rs/tokio#4636 |
| 439 | let last_block = quote_spanned! {last_stmt_end_span=> |
| 440 | #[allow(clippy::expect_used, clippy::diverging_sub_expression, clippy::needless_return)] |
| 441 | { |
| 442 | return #rt |
| 443 | .enable_all() |
| 444 | .build() |
| 445 | .expect("Failed building the Runtime" ) |
| 446 | .block_on(#body_ident); |
| 447 | } |
| 448 | }; |
| 449 | |
| 450 | let body = input.body(); |
| 451 | |
| 452 | // For test functions pin the body to the stack and use `Pin<&mut dyn |
| 453 | // Future>` to reduce the amount of `Runtime::block_on` (and related |
| 454 | // functions) copies we generate during compilation due to the generic |
| 455 | // parameter `F` (the future to block on). This could have an impact on |
| 456 | // performance, but because it's only for testing it's unlikely to be very |
| 457 | // large. |
| 458 | // |
| 459 | // We don't do this for the main function as it should only be used once so |
| 460 | // there will be no benefit. |
| 461 | let body = if is_test { |
| 462 | let output_type = match &input.sig.output { |
| 463 | // For functions with no return value syn doesn't print anything, |
| 464 | // but that doesn't work as `Output` for our boxed `Future`, so |
| 465 | // default to `()` (the same type as the function output). |
| 466 | syn::ReturnType::Default => quote! { () }, |
| 467 | syn::ReturnType::Type(_, ret_type) => quote! { #ret_type }, |
| 468 | }; |
| 469 | quote! { |
| 470 | let body = async #body; |
| 471 | #crate_path::pin!(body); |
| 472 | let body: ::core::pin::Pin<&mut dyn ::core::future::Future<Output = #output_type>> = body; |
| 473 | } |
| 474 | } else { |
| 475 | quote! { |
| 476 | let body = async #body; |
| 477 | } |
| 478 | }; |
| 479 | |
| 480 | input.into_tokens(generated_attrs, body, last_block) |
| 481 | } |
| 482 | |
| 483 | fn token_stream_with_error(mut tokens: TokenStream, error: syn::Error) -> TokenStream { |
| 484 | tokens.extend(iter:error.into_compile_error()); |
| 485 | tokens |
| 486 | } |
| 487 | |
| 488 | pub(crate) fn main(args: TokenStream, item: TokenStream, rt_multi_thread: bool) -> TokenStream { |
| 489 | // If any of the steps for this macro fail, we still want to expand to an item that is as close |
| 490 | // to the expected output as possible. This helps out IDEs such that completions and other |
| 491 | // related features keep working. |
| 492 | let input: ItemFn = match syn::parse2(tokens:item.clone()) { |
| 493 | Ok(it: ItemFn) => it, |
| 494 | Err(e: Error) => return token_stream_with_error(tokens:item, error:e), |
| 495 | }; |
| 496 | |
| 497 | let config: Result = if input.sig.ident == "main" && !input.sig.inputs.is_empty() { |
| 498 | let msg: &'static str = "the main function cannot accept arguments" ; |
| 499 | Err(syn::Error::new_spanned(&input.sig.ident, message:msg)) |
| 500 | } else { |
| 501 | AttributeArgs::parse_terminated |
| 502 | .parse2(args) |
| 503 | .and_then(|args: Punctuated| build_config(&input, args, is_test:false, rt_multi_thread)) |
| 504 | }; |
| 505 | |
| 506 | match config { |
| 507 | Ok(config: FinalConfig) => parse_knobs(input, is_test:false, config), |
| 508 | Err(e: Error) => token_stream_with_error(tokens:parse_knobs(input, false, DEFAULT_ERROR_CONFIG), error:e), |
| 509 | } |
| 510 | } |
| 511 | |
| 512 | // Check whether given attribute is a test attribute of forms: |
| 513 | // * `#[test]` |
| 514 | // * `#[core::prelude::*::test]` or `#[::core::prelude::*::test]` |
| 515 | // * `#[std::prelude::*::test]` or `#[::std::prelude::*::test]` |
| 516 | fn is_test_attribute(attr: &Attribute) -> bool { |
| 517 | let path: &Path = match &attr.meta { |
| 518 | syn::Meta::Path(path: &Path) => path, |
| 519 | _ => return false, |
| 520 | }; |
| 521 | let candidates: [[&'static str; 4]; 2] = [ |
| 522 | ["core" , "prelude" , "*" , "test" ], |
| 523 | ["std" , "prelude" , "*" , "test" ], |
| 524 | ]; |
| 525 | if path.leading_colon.is_none() |
| 526 | && path.segments.len() == 1 |
| 527 | && path.segments[0].arguments.is_none() |
| 528 | && path.segments[0].ident == "test" |
| 529 | { |
| 530 | return true; |
| 531 | } else if path.segments.len() != candidates[0].len() { |
| 532 | return false; |
| 533 | } |
| 534 | candidates.into_iter().any(|segments: [&'static str; 4]| { |
| 535 | path.segments.iter().zip(segments).all(|(segment: &PathSegment, path: &'static str)| { |
| 536 | segment.arguments.is_none() && (path == "*" || segment.ident == path) |
| 537 | }) |
| 538 | }) |
| 539 | } |
| 540 | |
| 541 | pub(crate) fn test(args: TokenStream, item: TokenStream, rt_multi_thread: bool) -> TokenStream { |
| 542 | // If any of the steps for this macro fail, we still want to expand to an item that is as close |
| 543 | // to the expected output as possible. This helps out IDEs such that completions and other |
| 544 | // related features keep working. |
| 545 | let input: ItemFn = match syn::parse2(tokens:item.clone()) { |
| 546 | Ok(it: ItemFn) => it, |
| 547 | Err(e: Error) => return token_stream_with_error(tokens:item, error:e), |
| 548 | }; |
| 549 | let config: Result = if let Some(attr: &Attribute) = input.attrs().find(|attr: &&Attribute| is_test_attribute(attr)) { |
| 550 | let msg: &'static str = "second test attribute is supplied, consider removing or changing the order of your test attributes" ; |
| 551 | Err(syn::Error::new_spanned(tokens:attr, message:msg)) |
| 552 | } else { |
| 553 | AttributeArgs::parse_terminated |
| 554 | .parse2(args) |
| 555 | .and_then(|args: Punctuated| build_config(&input, args, is_test:true, rt_multi_thread)) |
| 556 | }; |
| 557 | |
| 558 | match config { |
| 559 | Ok(config: FinalConfig) => parse_knobs(input, is_test:true, config), |
| 560 | Err(e: Error) => token_stream_with_error(tokens:parse_knobs(input, true, DEFAULT_ERROR_CONFIG), error:e), |
| 561 | } |
| 562 | } |
| 563 | |
| 564 | struct ItemFn { |
| 565 | outer_attrs: Vec<Attribute>, |
| 566 | vis: Visibility, |
| 567 | sig: Signature, |
| 568 | brace_token: syn::token::Brace, |
| 569 | inner_attrs: Vec<Attribute>, |
| 570 | stmts: Vec<proc_macro2::TokenStream>, |
| 571 | } |
| 572 | |
| 573 | impl ItemFn { |
| 574 | /// Access all attributes of the function item. |
| 575 | fn attrs(&self) -> impl Iterator<Item = &Attribute> { |
| 576 | self.outer_attrs.iter().chain(self.inner_attrs.iter()) |
| 577 | } |
| 578 | |
| 579 | /// Get the body of the function item in a manner so that it can be |
| 580 | /// conveniently used with the `quote!` macro. |
| 581 | fn body(&self) -> Body<'_> { |
| 582 | Body { |
| 583 | brace_token: self.brace_token, |
| 584 | stmts: &self.stmts, |
| 585 | } |
| 586 | } |
| 587 | |
| 588 | /// Convert our local function item into a token stream. |
| 589 | fn into_tokens( |
| 590 | self, |
| 591 | generated_attrs: proc_macro2::TokenStream, |
| 592 | body: proc_macro2::TokenStream, |
| 593 | last_block: proc_macro2::TokenStream, |
| 594 | ) -> TokenStream { |
| 595 | let mut tokens = proc_macro2::TokenStream::new(); |
| 596 | // Outer attributes are simply streamed as-is. |
| 597 | for attr in self.outer_attrs { |
| 598 | attr.to_tokens(&mut tokens); |
| 599 | } |
| 600 | |
| 601 | // Inner attributes require extra care, since they're not supported on |
| 602 | // blocks (which is what we're expanded into) we instead lift them |
| 603 | // outside of the function. This matches the behavior of `syn`. |
| 604 | for mut attr in self.inner_attrs { |
| 605 | attr.style = syn::AttrStyle::Outer; |
| 606 | attr.to_tokens(&mut tokens); |
| 607 | } |
| 608 | |
| 609 | // Add generated macros at the end, so macros processed later are aware of them. |
| 610 | generated_attrs.to_tokens(&mut tokens); |
| 611 | |
| 612 | self.vis.to_tokens(&mut tokens); |
| 613 | self.sig.to_tokens(&mut tokens); |
| 614 | |
| 615 | self.brace_token.surround(&mut tokens, |tokens| { |
| 616 | body.to_tokens(tokens); |
| 617 | last_block.to_tokens(tokens); |
| 618 | }); |
| 619 | |
| 620 | tokens |
| 621 | } |
| 622 | } |
| 623 | |
| 624 | impl Parse for ItemFn { |
| 625 | #[inline ] |
| 626 | fn parse(input: ParseStream<'_>) -> syn::Result<Self> { |
| 627 | // This parse implementation has been largely lifted from `syn`, with |
| 628 | // the exception of: |
| 629 | // * We don't have access to the plumbing necessary to parse inner |
| 630 | // attributes in-place. |
| 631 | // * We do our own statements parsing to avoid recursively parsing |
| 632 | // entire statements and only look for the parts we're interested in. |
| 633 | |
| 634 | let outer_attrs = input.call(Attribute::parse_outer)?; |
| 635 | let vis: Visibility = input.parse()?; |
| 636 | let sig: Signature = input.parse()?; |
| 637 | |
| 638 | let content; |
| 639 | let brace_token = braced!(content in input); |
| 640 | let inner_attrs = Attribute::parse_inner(&content)?; |
| 641 | |
| 642 | let mut buf = proc_macro2::TokenStream::new(); |
| 643 | let mut stmts = Vec::new(); |
| 644 | |
| 645 | while !content.is_empty() { |
| 646 | if let Some(semi) = content.parse::<Option<syn::Token![;]>>()? { |
| 647 | semi.to_tokens(&mut buf); |
| 648 | stmts.push(buf); |
| 649 | buf = proc_macro2::TokenStream::new(); |
| 650 | continue; |
| 651 | } |
| 652 | |
| 653 | // Parse a single token tree and extend our current buffer with it. |
| 654 | // This avoids parsing the entire content of the sub-tree. |
| 655 | buf.extend([content.parse::<TokenTree>()?]); |
| 656 | } |
| 657 | |
| 658 | if !buf.is_empty() { |
| 659 | stmts.push(buf); |
| 660 | } |
| 661 | |
| 662 | Ok(Self { |
| 663 | outer_attrs, |
| 664 | vis, |
| 665 | sig, |
| 666 | brace_token, |
| 667 | inner_attrs, |
| 668 | stmts, |
| 669 | }) |
| 670 | } |
| 671 | } |
| 672 | |
| 673 | struct Body<'a> { |
| 674 | brace_token: syn::token::Brace, |
| 675 | // Statements, with terminating `;`. |
| 676 | stmts: &'a [TokenStream], |
| 677 | } |
| 678 | |
| 679 | impl ToTokens for Body<'_> { |
| 680 | fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) { |
| 681 | self.brace_token.surround(tokens, |tokens: &mut TokenStream| { |
| 682 | for stmt: &TokenStream in self.stmts { |
| 683 | stmt.to_tokens(tokens); |
| 684 | } |
| 685 | }); |
| 686 | } |
| 687 | } |
| 688 | |