1use proc_macro2::{Span, TokenStream, TokenTree};
2use quote::{quote, quote_spanned, ToTokens};
3use syn::parse::{Parse, ParseStream, Parser};
4use syn::{braced, Attribute, Ident, Path, Signature, Visibility};
5
6// syn::AttributeArgs does not implement syn::Parse
7type AttributeArgs = syn::punctuated::Punctuated<syn::Meta, syn::Token![,]>;
8
/// The kind of Tokio runtime the generated code builds.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum RuntimeFlavor {
    /// `flavor = "current_thread"`: `Builder::new_current_thread()`.
    CurrentThread,
    /// `flavor = "multi_thread"`: `Builder::new_multi_thread()`.
    Threaded,
}

impl RuntimeFlavor {
    /// Maps the value of a `flavor = "..."` argument to a [`RuntimeFlavor`].
    ///
    /// # Errors
    /// Returns a human-readable message for unknown names. Some historical or
    /// commonly mistyped names (`single_thread`, `basic_scheduler`,
    /// `threaded_scheduler`) get a message pointing at the current spelling.
    fn from_str(s: &str) -> Result<RuntimeFlavor, String> {
        match s {
            "current_thread" => Ok(RuntimeFlavor::CurrentThread),
            "multi_thread" => Ok(RuntimeFlavor::Threaded),
            "single_thread" => Err("The single threaded runtime flavor is called `current_thread`.".to_string()),
            "basic_scheduler" => Err("The `basic_scheduler` runtime flavor has been renamed to `current_thread`.".to_string()),
            "threaded_scheduler" => Err("The `threaded_scheduler` runtime flavor has been renamed to `multi_thread`.".to_string()),
            _ => Err(format!("No such runtime flavor `{}`. The runtime flavors are `current_thread` and `multi_thread`.", s)),
        }
    }
}
27
28struct FinalConfig {
29 flavor: RuntimeFlavor,
30 worker_threads: Option<usize>,
31 start_paused: Option<bool>,
32 crate_name: Option<Path>,
33}
34
35/// Config used in case of the attribute not being able to build a valid config
36const DEFAULT_ERROR_CONFIG: FinalConfig = FinalConfig {
37 flavor: RuntimeFlavor::CurrentThread,
38 worker_threads: None,
39 start_paused: None,
40 crate_name: None,
41};
42
43struct Configuration {
44 rt_multi_thread_available: bool,
45 default_flavor: RuntimeFlavor,
46 flavor: Option<RuntimeFlavor>,
47 worker_threads: Option<(usize, Span)>,
48 start_paused: Option<(bool, Span)>,
49 is_test: bool,
50 crate_name: Option<Path>,
51}
52
53impl Configuration {
54 fn new(is_test: bool, rt_multi_thread: bool) -> Self {
55 Configuration {
56 rt_multi_thread_available: rt_multi_thread,
57 default_flavor: match is_test {
58 true => RuntimeFlavor::CurrentThread,
59 false => RuntimeFlavor::Threaded,
60 },
61 flavor: None,
62 worker_threads: None,
63 start_paused: None,
64 is_test,
65 crate_name: None,
66 }
67 }
68
69 fn set_flavor(&mut self, runtime: syn::Lit, span: Span) -> Result<(), syn::Error> {
70 if self.flavor.is_some() {
71 return Err(syn::Error::new(span, "`flavor` set multiple times."));
72 }
73
74 let runtime_str = parse_string(runtime, span, "flavor")?;
75 let runtime =
76 RuntimeFlavor::from_str(&runtime_str).map_err(|err| syn::Error::new(span, err))?;
77 self.flavor = Some(runtime);
78 Ok(())
79 }
80
81 fn set_worker_threads(
82 &mut self,
83 worker_threads: syn::Lit,
84 span: Span,
85 ) -> Result<(), syn::Error> {
86 if self.worker_threads.is_some() {
87 return Err(syn::Error::new(
88 span,
89 "`worker_threads` set multiple times.",
90 ));
91 }
92
93 let worker_threads = parse_int(worker_threads, span, "worker_threads")?;
94 if worker_threads == 0 {
95 return Err(syn::Error::new(span, "`worker_threads` may not be 0."));
96 }
97 self.worker_threads = Some((worker_threads, span));
98 Ok(())
99 }
100
101 fn set_start_paused(&mut self, start_paused: syn::Lit, span: Span) -> Result<(), syn::Error> {
102 if self.start_paused.is_some() {
103 return Err(syn::Error::new(span, "`start_paused` set multiple times."));
104 }
105
106 let start_paused = parse_bool(start_paused, span, "start_paused")?;
107 self.start_paused = Some((start_paused, span));
108 Ok(())
109 }
110
111 fn set_crate_name(&mut self, name: syn::Lit, span: Span) -> Result<(), syn::Error> {
112 if self.crate_name.is_some() {
113 return Err(syn::Error::new(span, "`crate` set multiple times."));
114 }
115 let name_path = parse_path(name, span, "crate")?;
116 self.crate_name = Some(name_path);
117 Ok(())
118 }
119
120 fn macro_name(&self) -> &'static str {
121 if self.is_test {
122 "tokio::test"
123 } else {
124 "tokio::main"
125 }
126 }
127
128 fn build(&self) -> Result<FinalConfig, syn::Error> {
129 use RuntimeFlavor as F;
130
131 let flavor = self.flavor.unwrap_or(self.default_flavor);
132 let worker_threads = match (flavor, self.worker_threads) {
133 (F::CurrentThread, Some((_, worker_threads_span))) => {
134 let msg = format!(
135 "The `worker_threads` option requires the `multi_thread` runtime flavor. Use `#[{}(flavor = \"multi_thread\")]`",
136 self.macro_name(),
137 );
138 return Err(syn::Error::new(worker_threads_span, msg));
139 }
140 (F::CurrentThread, None) => None,
141 (F::Threaded, worker_threads) if self.rt_multi_thread_available => {
142 worker_threads.map(|(val, _span)| val)
143 }
144 (F::Threaded, _) => {
145 let msg = if self.flavor.is_none() {
146 "The default runtime flavor is `multi_thread`, but the `rt-multi-thread` feature is disabled."
147 } else {
148 "The runtime flavor `multi_thread` requires the `rt-multi-thread` feature."
149 };
150 return Err(syn::Error::new(Span::call_site(), msg));
151 }
152 };
153
154 let start_paused = match (flavor, self.start_paused) {
155 (F::Threaded, Some((_, start_paused_span))) => {
156 let msg = format!(
157 "The `start_paused` option requires the `current_thread` runtime flavor. Use `#[{}(flavor = \"current_thread\")]`",
158 self.macro_name(),
159 );
160 return Err(syn::Error::new(start_paused_span, msg));
161 }
162 (F::CurrentThread, Some((start_paused, _))) => Some(start_paused),
163 (_, None) => None,
164 };
165
166 Ok(FinalConfig {
167 crate_name: self.crate_name.clone(),
168 flavor,
169 worker_threads,
170 start_paused,
171 })
172 }
173}
174
175fn parse_int(int: syn::Lit, span: Span, field: &str) -> Result<usize, syn::Error> {
176 match int {
177 syn::Lit::Int(lit) => match lit.base10_parse::<usize>() {
178 Ok(value) => Ok(value),
179 Err(e) => Err(syn::Error::new(
180 span,
181 format!("Failed to parse value of `{}` as integer: {}", field, e),
182 )),
183 },
184 _ => Err(syn::Error::new(
185 span,
186 format!("Failed to parse value of `{}` as integer.", field),
187 )),
188 }
189}
190
191fn parse_string(int: syn::Lit, span: Span, field: &str) -> Result<String, syn::Error> {
192 match int {
193 syn::Lit::Str(s) => Ok(s.value()),
194 syn::Lit::Verbatim(s) => Ok(s.to_string()),
195 _ => Err(syn::Error::new(
196 span,
197 format!("Failed to parse value of `{}` as string.", field),
198 )),
199 }
200}
201
202fn parse_path(lit: syn::Lit, span: Span, field: &str) -> Result<Path, syn::Error> {
203 match lit {
204 syn::Lit::Str(s) => {
205 let err = syn::Error::new(
206 span,
207 format!(
208 "Failed to parse value of `{}` as path: \"{}\"",
209 field,
210 s.value()
211 ),
212 );
213 s.parse::<syn::Path>().map_err(|_| err.clone())
214 }
215 _ => Err(syn::Error::new(
216 span,
217 format!("Failed to parse value of `{}` as path.", field),
218 )),
219 }
220}
221
222fn parse_bool(bool: syn::Lit, span: Span, field: &str) -> Result<bool, syn::Error> {
223 match bool {
224 syn::Lit::Bool(b) => Ok(b.value),
225 _ => Err(syn::Error::new(
226 span,
227 format!("Failed to parse value of `{}` as bool.", field),
228 )),
229 }
230}
231
232fn build_config(
233 input: &ItemFn,
234 args: AttributeArgs,
235 is_test: bool,
236 rt_multi_thread: bool,
237) -> Result<FinalConfig, syn::Error> {
238 if input.sig.asyncness.is_none() {
239 let msg = "the `async` keyword is missing from the function declaration";
240 return Err(syn::Error::new_spanned(input.sig.fn_token, msg));
241 }
242
243 let mut config = Configuration::new(is_test, rt_multi_thread);
244 let macro_name = config.macro_name();
245
246 for arg in args {
247 match arg {
248 syn::Meta::NameValue(namevalue) => {
249 let ident = namevalue
250 .path
251 .get_ident()
252 .ok_or_else(|| {
253 syn::Error::new_spanned(&namevalue, "Must have specified ident")
254 })?
255 .to_string()
256 .to_lowercase();
257 let lit = match &namevalue.value {
258 syn::Expr::Lit(syn::ExprLit { lit, .. }) => lit,
259 expr => return Err(syn::Error::new_spanned(expr, "Must be a literal")),
260 };
261 match ident.as_str() {
262 "worker_threads" => {
263 config.set_worker_threads(lit.clone(), syn::spanned::Spanned::span(lit))?;
264 }
265 "flavor" => {
266 config.set_flavor(lit.clone(), syn::spanned::Spanned::span(lit))?;
267 }
268 "start_paused" => {
269 config.set_start_paused(lit.clone(), syn::spanned::Spanned::span(lit))?;
270 }
271 "core_threads" => {
272 let msg = "Attribute `core_threads` is renamed to `worker_threads`";
273 return Err(syn::Error::new_spanned(namevalue, msg));
274 }
275 "crate" => {
276 config.set_crate_name(lit.clone(), syn::spanned::Spanned::span(lit))?;
277 }
278 name => {
279 let msg = format!(
280 "Unknown attribute {} is specified; expected one of: `flavor`, `worker_threads`, `start_paused`, `crate`",
281 name,
282 );
283 return Err(syn::Error::new_spanned(namevalue, msg));
284 }
285 }
286 }
287 syn::Meta::Path(path) => {
288 let name = path
289 .get_ident()
290 .ok_or_else(|| syn::Error::new_spanned(&path, "Must have specified ident"))?
291 .to_string()
292 .to_lowercase();
293 let msg = match name.as_str() {
294 "threaded_scheduler" | "multi_thread" => {
295 format!(
296 "Set the runtime flavor with #[{}(flavor = \"multi_thread\")].",
297 macro_name
298 )
299 }
300 "basic_scheduler" | "current_thread" | "single_threaded" => {
301 format!(
302 "Set the runtime flavor with #[{}(flavor = \"current_thread\")].",
303 macro_name
304 )
305 }
306 "flavor" | "worker_threads" | "start_paused" => {
307 format!("The `{}` attribute requires an argument.", name)
308 }
309 name => {
310 format!("Unknown attribute {} is specified; expected one of: `flavor`, `worker_threads`, `start_paused`, `crate`", name)
311 }
312 };
313 return Err(syn::Error::new_spanned(path, msg));
314 }
315 other => {
316 return Err(syn::Error::new_spanned(
317 other,
318 "Unknown attribute inside the macro",
319 ));
320 }
321 }
322 }
323
324 config.build()
325}
326
/// Expands the annotated `async fn` into a synchronous function that builds a Tokio
/// runtime according to `config` and blocks on the original body.
///
/// When `is_test` is set, a `#[::core::prelude::v1::test]` header is emitted and the
/// body future is pinned and type-erased (see the comment further down). Spans of
/// the generated runtime code are taken from the body's last statement.
fn parse_knobs(mut input: ItemFn, is_test: bool, config: FinalConfig) -> TokenStream {
    // The generated function is synchronous; the body is wrapped in an `async` block below.
    input.sig.asyncness = None;

    // If type mismatch occurs, the current rustc points to the last statement.
    let (last_stmt_start_span, last_stmt_end_span) = {
        let mut last_stmt = input.stmts.last().cloned().unwrap_or_default().into_iter();

        // `Span` on stable Rust has a limitation that only points to the first
        // token, not the whole tokens. We can work around this limitation by
        // using the first/last span of the tokens like
        // `syn::Error::new_spanned` does.
        let start = last_stmt.next().map_or_else(Span::call_site, |t| t.span());
        let end = last_stmt.last().map_or(start, |t| t.span());
        (start, end)
    };

    // Path used to reach the runtime builder: the `crate = "..."` override if one
    // was given, otherwise `tokio` spanned at the last statement.
    let crate_path = config
        .crate_name
        .map(ToTokens::into_token_stream)
        .unwrap_or_else(|| Ident::new("tokio", last_stmt_start_span).into_token_stream());

    // Builder chain; optional knobs are appended only when they were configured.
    let mut rt = match config.flavor {
        RuntimeFlavor::CurrentThread => quote_spanned! {last_stmt_start_span=>
            #crate_path::runtime::Builder::new_current_thread()
        },
        RuntimeFlavor::Threaded => quote_spanned! {last_stmt_start_span=>
            #crate_path::runtime::Builder::new_multi_thread()
        },
    };
    if let Some(v) = config.worker_threads {
        rt = quote_spanned! {last_stmt_start_span=> #rt.worker_threads(#v) };
    }
    if let Some(v) = config.start_paused {
        rt = quote_spanned! {last_stmt_start_span=> #rt.start_paused(#v) };
    }

    let header = if is_test {
        quote! {
            #[::core::prelude::v1::test]
        }
    } else {
        quote! {}
    };

    // Final block of the generated function: build the runtime and `return` the
    // result of `block_on(body)`, so the function returns the body's output.
    let body_ident = quote! { body };
    let last_block = quote_spanned! {last_stmt_end_span=>
        #[allow(clippy::expect_used, clippy::diverging_sub_expression)]
        {
            return #rt
                .enable_all()
                .build()
                .expect("Failed building the Runtime")
                .block_on(#body_ident);
        }
    };

    let body = input.body();

    // For test functions pin the body to the stack and use `Pin<&mut dyn
    // Future>` to reduce the amount of `Runtime::block_on` (and related
    // functions) copies we generate during compilation due to the generic
    // parameter `F` (the future to block on). This could have an impact on
    // performance, but because it's only for testing it's unlikely to be very
    // large.
    //
    // We don't do this for the main function as it should only be used once so
    // there will be no benefit.
    let body = if is_test {
        let output_type = match &input.sig.output {
            // For functions with no return value syn doesn't print anything,
            // but that doesn't work as `Output` for our boxed `Future`, so
            // default to `()` (the same type as the function output).
            syn::ReturnType::Default => quote! { () },
            syn::ReturnType::Type(_, ret_type) => quote! { #ret_type },
        };
        quote! {
            let body = async #body;
            #crate_path::pin!(body);
            let body: ::core::pin::Pin<&mut dyn ::core::future::Future<Output = #output_type>> = body;
        }
    } else {
        quote! {
            let body = async #body;
        }
    };

    input.into_tokens(header, body, last_block)
}
415
416fn token_stream_with_error(mut tokens: TokenStream, error: syn::Error) -> TokenStream {
417 tokens.extend(error.into_compile_error());
418 tokens
419}
420
421#[cfg(not(test))] // Work around for rust-lang/rust#62127
422pub(crate) fn main(args: TokenStream, item: TokenStream, rt_multi_thread: bool) -> TokenStream {
423 // If any of the steps for this macro fail, we still want to expand to an item that is as close
424 // to the expected output as possible. This helps out IDEs such that completions and other
425 // related features keep working.
426 let input: ItemFn = match syn::parse2(item.clone()) {
427 Ok(it) => it,
428 Err(e) => return token_stream_with_error(item, e),
429 };
430
431 let config = if input.sig.ident == "main" && !input.sig.inputs.is_empty() {
432 let msg = "the main function cannot accept arguments";
433 Err(syn::Error::new_spanned(&input.sig.ident, msg))
434 } else {
435 AttributeArgs::parse_terminated
436 .parse2(args)
437 .and_then(|args| build_config(&input, args, false, rt_multi_thread))
438 };
439
440 match config {
441 Ok(config) => parse_knobs(input, false, config),
442 Err(e) => token_stream_with_error(parse_knobs(input, false, DEFAULT_ERROR_CONFIG), e),
443 }
444}
445
446pub(crate) fn test(args: TokenStream, item: TokenStream, rt_multi_thread: bool) -> TokenStream {
447 // If any of the steps for this macro fail, we still want to expand to an item that is as close
448 // to the expected output as possible. This helps out IDEs such that completions and other
449 // related features keep working.
450 let input: ItemFn = match syn::parse2(item.clone()) {
451 Ok(it) => it,
452 Err(e) => return token_stream_with_error(item, e),
453 };
454 let config = if let Some(attr) = input.attrs().find(|attr| attr.meta.path().is_ident("test")) {
455 let msg = "second test attribute is supplied";
456 Err(syn::Error::new_spanned(attr, msg))
457 } else {
458 AttributeArgs::parse_terminated
459 .parse2(args)
460 .and_then(|args| build_config(&input, args, true, rt_multi_thread))
461 };
462
463 match config {
464 Ok(config) => parse_knobs(input, true, config),
465 Err(e) => token_stream_with_error(parse_knobs(input, true, DEFAULT_ERROR_CONFIG), e),
466 }
467}
468
469struct ItemFn {
470 outer_attrs: Vec<Attribute>,
471 vis: Visibility,
472 sig: Signature,
473 brace_token: syn::token::Brace,
474 inner_attrs: Vec<Attribute>,
475 stmts: Vec<proc_macro2::TokenStream>,
476}
477
478impl ItemFn {
479 /// Access all attributes of the function item.
480 fn attrs(&self) -> impl Iterator<Item = &Attribute> {
481 self.outer_attrs.iter().chain(self.inner_attrs.iter())
482 }
483
484 /// Get the body of the function item in a manner so that it can be
485 /// conveniently used with the `quote!` macro.
486 fn body(&self) -> Body<'_> {
487 Body {
488 brace_token: self.brace_token,
489 stmts: &self.stmts,
490 }
491 }
492
493 /// Convert our local function item into a token stream.
494 fn into_tokens(
495 self,
496 header: proc_macro2::TokenStream,
497 body: proc_macro2::TokenStream,
498 last_block: proc_macro2::TokenStream,
499 ) -> TokenStream {
500 let mut tokens = proc_macro2::TokenStream::new();
501 header.to_tokens(&mut tokens);
502
503 // Outer attributes are simply streamed as-is.
504 for attr in self.outer_attrs {
505 attr.to_tokens(&mut tokens);
506 }
507
508 // Inner attributes require extra care, since they're not supported on
509 // blocks (which is what we're expanded into) we instead lift them
510 // outside of the function. This matches the behaviour of `syn`.
511 for mut attr in self.inner_attrs {
512 attr.style = syn::AttrStyle::Outer;
513 attr.to_tokens(&mut tokens);
514 }
515
516 self.vis.to_tokens(&mut tokens);
517 self.sig.to_tokens(&mut tokens);
518
519 self.brace_token.surround(&mut tokens, |tokens| {
520 body.to_tokens(tokens);
521 last_block.to_tokens(tokens);
522 });
523
524 tokens
525 }
526}
527
528impl Parse for ItemFn {
529 #[inline]
530 fn parse(input: ParseStream<'_>) -> syn::Result<Self> {
531 // This parse implementation has been largely lifted from `syn`, with
532 // the exception of:
533 // * We don't have access to the plumbing necessary to parse inner
534 // attributes in-place.
535 // * We do our own statements parsing to avoid recursively parsing
536 // entire statements and only look for the parts we're interested in.
537
538 let outer_attrs = input.call(Attribute::parse_outer)?;
539 let vis: Visibility = input.parse()?;
540 let sig: Signature = input.parse()?;
541
542 let content;
543 let brace_token = braced!(content in input);
544 let inner_attrs = Attribute::parse_inner(&content)?;
545
546 let mut buf = proc_macro2::TokenStream::new();
547 let mut stmts = Vec::new();
548
549 while !content.is_empty() {
550 if let Some(semi) = content.parse::<Option<syn::Token![;]>>()? {
551 semi.to_tokens(&mut buf);
552 stmts.push(buf);
553 buf = proc_macro2::TokenStream::new();
554 continue;
555 }
556
557 // Parse a single token tree and extend our current buffer with it.
558 // This avoids parsing the entire content of the sub-tree.
559 buf.extend([content.parse::<TokenTree>()?]);
560 }
561
562 if !buf.is_empty() {
563 stmts.push(buf);
564 }
565
566 Ok(Self {
567 outer_attrs,
568 vis,
569 sig,
570 brace_token,
571 inner_attrs,
572 stmts,
573 })
574 }
575}
576
577struct Body<'a> {
578 brace_token: syn::token::Brace,
579 // Statements, with terminating `;`.
580 stmts: &'a [TokenStream],
581}
582
583impl ToTokens for Body<'_> {
584 fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) {
585 self.brace_token.surround(tokens, |tokens| {
586 for stmt in self.stmts {
587 stmt.to_tokens(tokens);
588 }
589 });
590 }
591}
592