1 | use proc_macro2::{Span, TokenStream, TokenTree}; |
2 | use quote::{quote, quote_spanned, ToTokens}; |
3 | use syn::parse::{Parse, ParseStream, Parser}; |
4 | use syn::{braced, Attribute, Ident, Path, Signature, Visibility}; |
5 | |
6 | // syn::AttributeArgs does not implement syn::Parse |
7 | type AttributeArgs = syn::punctuated::Punctuated<syn::Meta, syn::Token![,]>; |
8 | |
/// Which Tokio runtime the generated code constructs.
#[derive(Clone, Copy, PartialEq)]
enum RuntimeFlavor {
    /// `flavor = "current_thread"`; expands to `Builder::new_current_thread()`.
    CurrentThread,
    /// `flavor = "multi_thread"`; expands to `Builder::new_multi_thread()`.
    Threaded,
}

impl RuntimeFlavor {
    /// Maps the string passed to `flavor = "..."` onto a [`RuntimeFlavor`].
    ///
    /// Legacy or mistaken spellings (`single_thread`, `basic_scheduler`,
    /// `threaded_scheduler`) produce a dedicated hint. Any other unknown
    /// value produces a generic "no such flavor" message. The error is a
    /// plain `String` so the caller can attach its own span.
    fn from_str(s: &str) -> Result<RuntimeFlavor, String> {
        let hint = match s {
            "current_thread" => return Ok(Self::CurrentThread),
            "multi_thread" => return Ok(Self::Threaded),
            "single_thread" => "The single threaded runtime flavor is called `current_thread`.",
            "basic_scheduler" => {
                "The `basic_scheduler` runtime flavor has been renamed to `current_thread`."
            }
            "threaded_scheduler" => {
                "The `threaded_scheduler` runtime flavor has been renamed to `multi_thread`."
            }
            unknown => {
                return Err(format!(
                    "No such runtime flavor `{}`. The runtime flavors are `current_thread` and `multi_thread`.",
                    unknown
                ));
            }
        };
        Err(hint.to_owned())
    }
}
27 | |
28 | struct FinalConfig { |
29 | flavor: RuntimeFlavor, |
30 | worker_threads: Option<usize>, |
31 | start_paused: Option<bool>, |
32 | crate_name: Option<Path>, |
33 | } |
34 | |
35 | /// Config used in case of the attribute not being able to build a valid config |
36 | const DEFAULT_ERROR_CONFIG: FinalConfig = FinalConfig { |
37 | flavor: RuntimeFlavor::CurrentThread, |
38 | worker_threads: None, |
39 | start_paused: None, |
40 | crate_name: None, |
41 | }; |
42 | |
43 | struct Configuration { |
44 | rt_multi_thread_available: bool, |
45 | default_flavor: RuntimeFlavor, |
46 | flavor: Option<RuntimeFlavor>, |
47 | worker_threads: Option<(usize, Span)>, |
48 | start_paused: Option<(bool, Span)>, |
49 | is_test: bool, |
50 | crate_name: Option<Path>, |
51 | } |
52 | |
53 | impl Configuration { |
54 | fn new(is_test: bool, rt_multi_thread: bool) -> Self { |
55 | Configuration { |
56 | rt_multi_thread_available: rt_multi_thread, |
57 | default_flavor: match is_test { |
58 | true => RuntimeFlavor::CurrentThread, |
59 | false => RuntimeFlavor::Threaded, |
60 | }, |
61 | flavor: None, |
62 | worker_threads: None, |
63 | start_paused: None, |
64 | is_test, |
65 | crate_name: None, |
66 | } |
67 | } |
68 | |
69 | fn set_flavor(&mut self, runtime: syn::Lit, span: Span) -> Result<(), syn::Error> { |
70 | if self.flavor.is_some() { |
71 | return Err(syn::Error::new(span, "`flavor` set multiple times." )); |
72 | } |
73 | |
74 | let runtime_str = parse_string(runtime, span, "flavor" )?; |
75 | let runtime = |
76 | RuntimeFlavor::from_str(&runtime_str).map_err(|err| syn::Error::new(span, err))?; |
77 | self.flavor = Some(runtime); |
78 | Ok(()) |
79 | } |
80 | |
81 | fn set_worker_threads( |
82 | &mut self, |
83 | worker_threads: syn::Lit, |
84 | span: Span, |
85 | ) -> Result<(), syn::Error> { |
86 | if self.worker_threads.is_some() { |
87 | return Err(syn::Error::new( |
88 | span, |
89 | "`worker_threads` set multiple times." , |
90 | )); |
91 | } |
92 | |
93 | let worker_threads = parse_int(worker_threads, span, "worker_threads" )?; |
94 | if worker_threads == 0 { |
95 | return Err(syn::Error::new(span, "`worker_threads` may not be 0." )); |
96 | } |
97 | self.worker_threads = Some((worker_threads, span)); |
98 | Ok(()) |
99 | } |
100 | |
101 | fn set_start_paused(&mut self, start_paused: syn::Lit, span: Span) -> Result<(), syn::Error> { |
102 | if self.start_paused.is_some() { |
103 | return Err(syn::Error::new(span, "`start_paused` set multiple times." )); |
104 | } |
105 | |
106 | let start_paused = parse_bool(start_paused, span, "start_paused" )?; |
107 | self.start_paused = Some((start_paused, span)); |
108 | Ok(()) |
109 | } |
110 | |
111 | fn set_crate_name(&mut self, name: syn::Lit, span: Span) -> Result<(), syn::Error> { |
112 | if self.crate_name.is_some() { |
113 | return Err(syn::Error::new(span, "`crate` set multiple times." )); |
114 | } |
115 | let name_path = parse_path(name, span, "crate" )?; |
116 | self.crate_name = Some(name_path); |
117 | Ok(()) |
118 | } |
119 | |
120 | fn macro_name(&self) -> &'static str { |
121 | if self.is_test { |
122 | "tokio::test" |
123 | } else { |
124 | "tokio::main" |
125 | } |
126 | } |
127 | |
128 | fn build(&self) -> Result<FinalConfig, syn::Error> { |
129 | use RuntimeFlavor as F; |
130 | |
131 | let flavor = self.flavor.unwrap_or(self.default_flavor); |
132 | let worker_threads = match (flavor, self.worker_threads) { |
133 | (F::CurrentThread, Some((_, worker_threads_span))) => { |
134 | let msg = format!( |
135 | "The `worker_threads` option requires the `multi_thread` runtime flavor. Use `#[{}(flavor = \"multi_thread \")]`" , |
136 | self.macro_name(), |
137 | ); |
138 | return Err(syn::Error::new(worker_threads_span, msg)); |
139 | } |
140 | (F::CurrentThread, None) => None, |
141 | (F::Threaded, worker_threads) if self.rt_multi_thread_available => { |
142 | worker_threads.map(|(val, _span)| val) |
143 | } |
144 | (F::Threaded, _) => { |
145 | let msg = if self.flavor.is_none() { |
146 | "The default runtime flavor is `multi_thread`, but the `rt-multi-thread` feature is disabled." |
147 | } else { |
148 | "The runtime flavor `multi_thread` requires the `rt-multi-thread` feature." |
149 | }; |
150 | return Err(syn::Error::new(Span::call_site(), msg)); |
151 | } |
152 | }; |
153 | |
154 | let start_paused = match (flavor, self.start_paused) { |
155 | (F::Threaded, Some((_, start_paused_span))) => { |
156 | let msg = format!( |
157 | "The `start_paused` option requires the `current_thread` runtime flavor. Use `#[{}(flavor = \"current_thread \")]`" , |
158 | self.macro_name(), |
159 | ); |
160 | return Err(syn::Error::new(start_paused_span, msg)); |
161 | } |
162 | (F::CurrentThread, Some((start_paused, _))) => Some(start_paused), |
163 | (_, None) => None, |
164 | }; |
165 | |
166 | Ok(FinalConfig { |
167 | crate_name: self.crate_name.clone(), |
168 | flavor, |
169 | worker_threads, |
170 | start_paused, |
171 | }) |
172 | } |
173 | } |
174 | |
175 | fn parse_int(int: syn::Lit, span: Span, field: &str) -> Result<usize, syn::Error> { |
176 | match int { |
177 | syn::Lit::Int(lit) => match lit.base10_parse::<usize>() { |
178 | Ok(value) => Ok(value), |
179 | Err(e) => Err(syn::Error::new( |
180 | span, |
181 | format!("Failed to parse value of `{}` as integer: {}" , field, e), |
182 | )), |
183 | }, |
184 | _ => Err(syn::Error::new( |
185 | span, |
186 | format!("Failed to parse value of `{}` as integer." , field), |
187 | )), |
188 | } |
189 | } |
190 | |
191 | fn parse_string(int: syn::Lit, span: Span, field: &str) -> Result<String, syn::Error> { |
192 | match int { |
193 | syn::Lit::Str(s) => Ok(s.value()), |
194 | syn::Lit::Verbatim(s) => Ok(s.to_string()), |
195 | _ => Err(syn::Error::new( |
196 | span, |
197 | format!("Failed to parse value of `{}` as string." , field), |
198 | )), |
199 | } |
200 | } |
201 | |
202 | fn parse_path(lit: syn::Lit, span: Span, field: &str) -> Result<Path, syn::Error> { |
203 | match lit { |
204 | syn::Lit::Str(s) => { |
205 | let err = syn::Error::new( |
206 | span, |
207 | format!( |
208 | "Failed to parse value of `{}` as path: \"{} \"" , |
209 | field, |
210 | s.value() |
211 | ), |
212 | ); |
213 | s.parse::<syn::Path>().map_err(|_| err.clone()) |
214 | } |
215 | _ => Err(syn::Error::new( |
216 | span, |
217 | format!("Failed to parse value of `{}` as path." , field), |
218 | )), |
219 | } |
220 | } |
221 | |
222 | fn parse_bool(bool: syn::Lit, span: Span, field: &str) -> Result<bool, syn::Error> { |
223 | match bool { |
224 | syn::Lit::Bool(b) => Ok(b.value), |
225 | _ => Err(syn::Error::new( |
226 | span, |
227 | format!("Failed to parse value of `{}` as bool." , field), |
228 | )), |
229 | } |
230 | } |
231 | |
232 | fn build_config( |
233 | input: &ItemFn, |
234 | args: AttributeArgs, |
235 | is_test: bool, |
236 | rt_multi_thread: bool, |
237 | ) -> Result<FinalConfig, syn::Error> { |
238 | if input.sig.asyncness.is_none() { |
239 | let msg = "the `async` keyword is missing from the function declaration" ; |
240 | return Err(syn::Error::new_spanned(input.sig.fn_token, msg)); |
241 | } |
242 | |
243 | let mut config = Configuration::new(is_test, rt_multi_thread); |
244 | let macro_name = config.macro_name(); |
245 | |
246 | for arg in args { |
247 | match arg { |
248 | syn::Meta::NameValue(namevalue) => { |
249 | let ident = namevalue |
250 | .path |
251 | .get_ident() |
252 | .ok_or_else(|| { |
253 | syn::Error::new_spanned(&namevalue, "Must have specified ident" ) |
254 | })? |
255 | .to_string() |
256 | .to_lowercase(); |
257 | let lit = match &namevalue.value { |
258 | syn::Expr::Lit(syn::ExprLit { lit, .. }) => lit, |
259 | expr => return Err(syn::Error::new_spanned(expr, "Must be a literal" )), |
260 | }; |
261 | match ident.as_str() { |
262 | "worker_threads" => { |
263 | config.set_worker_threads(lit.clone(), syn::spanned::Spanned::span(lit))?; |
264 | } |
265 | "flavor" => { |
266 | config.set_flavor(lit.clone(), syn::spanned::Spanned::span(lit))?; |
267 | } |
268 | "start_paused" => { |
269 | config.set_start_paused(lit.clone(), syn::spanned::Spanned::span(lit))?; |
270 | } |
271 | "core_threads" => { |
272 | let msg = "Attribute `core_threads` is renamed to `worker_threads`" ; |
273 | return Err(syn::Error::new_spanned(namevalue, msg)); |
274 | } |
275 | "crate" => { |
276 | config.set_crate_name(lit.clone(), syn::spanned::Spanned::span(lit))?; |
277 | } |
278 | name => { |
279 | let msg = format!( |
280 | "Unknown attribute {} is specified; expected one of: `flavor`, `worker_threads`, `start_paused`, `crate`" , |
281 | name, |
282 | ); |
283 | return Err(syn::Error::new_spanned(namevalue, msg)); |
284 | } |
285 | } |
286 | } |
287 | syn::Meta::Path(path) => { |
288 | let name = path |
289 | .get_ident() |
290 | .ok_or_else(|| syn::Error::new_spanned(&path, "Must have specified ident" ))? |
291 | .to_string() |
292 | .to_lowercase(); |
293 | let msg = match name.as_str() { |
294 | "threaded_scheduler" | "multi_thread" => { |
295 | format!( |
296 | "Set the runtime flavor with #[{}(flavor = \"multi_thread \")]." , |
297 | macro_name |
298 | ) |
299 | } |
300 | "basic_scheduler" | "current_thread" | "single_threaded" => { |
301 | format!( |
302 | "Set the runtime flavor with #[{}(flavor = \"current_thread \")]." , |
303 | macro_name |
304 | ) |
305 | } |
306 | "flavor" | "worker_threads" | "start_paused" => { |
307 | format!("The `{}` attribute requires an argument." , name) |
308 | } |
309 | name => { |
310 | format!("Unknown attribute {} is specified; expected one of: `flavor`, `worker_threads`, `start_paused`, `crate`" , name) |
311 | } |
312 | }; |
313 | return Err(syn::Error::new_spanned(path, msg)); |
314 | } |
315 | other => { |
316 | return Err(syn::Error::new_spanned( |
317 | other, |
318 | "Unknown attribute inside the macro" , |
319 | )); |
320 | } |
321 | } |
322 | } |
323 | |
324 | config.build() |
325 | } |
326 | |
/// Expands the annotated `async fn` into a synchronous function that builds a
/// Tokio runtime according to `config` and blocks on the original body.
///
/// When `is_test` is true, the output also gets a
/// `#[::core::prelude::v1::test]` attribute. The body future is then pinned
/// on the stack and erased to `Pin<&mut dyn Future<Output = ...>>` before it
/// is handed to `block_on`.
fn parse_knobs(mut input: ItemFn, is_test: bool, config: FinalConfig) -> TokenStream {
    // The emitted function is synchronous; the original body is wrapped in an
    // `async` block further down instead.
    input.sig.asyncness = None;

    // If type mismatch occurs, the current rustc points to the last statement.
    let (last_stmt_start_span, last_stmt_end_span) = {
        let mut last_stmt = input.stmts.last().cloned().unwrap_or_default().into_iter();

        // `Span` on stable Rust has a limitation that only points to the first
        // token, not the whole tokens. We can work around this limitation by
        // using the first/last span of the tokens like
        // `syn::Error::new_spanned` does.
        let start = last_stmt.next().map_or_else(Span::call_site, |t| t.span());
        let end = last_stmt.last().map_or(start, |t| t.span());
        (start, end)
    };

    // Path prefix for the runtime builder: the user's `crate = "..."` override
    // if given, otherwise a plain `tokio` identifier.
    let crate_path = config
        .crate_name
        .map(ToTokens::into_token_stream)
        .unwrap_or_else(|| Ident::new("tokio", last_stmt_start_span).into_token_stream());

    // Builder expression for the selected flavor, with the optional
    // `worker_threads` / `start_paused` calls chained onto it.
    let mut rt = match config.flavor {
        RuntimeFlavor::CurrentThread => quote_spanned! {last_stmt_start_span=>
            #crate_path::runtime::Builder::new_current_thread()
        },
        RuntimeFlavor::Threaded => quote_spanned! {last_stmt_start_span=>
            #crate_path::runtime::Builder::new_multi_thread()
        },
    };
    if let Some(v) = config.worker_threads {
        rt = quote_spanned! {last_stmt_start_span=> #rt.worker_threads(#v) };
    }
    if let Some(v) = config.start_paused {
        rt = quote_spanned! {last_stmt_start_span=> #rt.start_paused(#v) };
    }

    let header = if is_test {
        quote! {
            #[::core::prelude::v1::test]
        }
    } else {
        quote! {}
    };

    // Final statement of the generated function: build the runtime and
    // `block_on` the `body` future declared by the `body` tokens below.
    let body_ident = quote! { body };
    let last_block = quote_spanned! {last_stmt_end_span=>
        #[allow(clippy::expect_used, clippy::diverging_sub_expression)]
        {
            return #rt
                .enable_all()
                .build()
                .expect("Failed building the Runtime")
                .block_on(#body_ident);
        }
    };

    let body = input.body();

    // For test functions pin the body to the stack and use `Pin<&mut dyn
    // Future>` to reduce the amount of `Runtime::block_on` (and related
    // functions) copies we generate during compilation due to the generic
    // parameter `F` (the future to block on). This could have an impact on
    // performance, but because it's only for testing it's unlikely to be very
    // large.
    //
    // We don't do this for the main function as it should only be used once so
    // there will be no benefit.
    let body = if is_test {
        let output_type = match &input.sig.output {
            // For functions with no return value syn doesn't print anything,
            // but that doesn't work as `Output` for our boxed `Future`, so
            // default to `()` (the same type as the function output).
            syn::ReturnType::Default => quote! { () },
            syn::ReturnType::Type(_, ret_type) => quote! { #ret_type },
        };
        quote! {
            let body = async #body;
            #crate_path::pin!(body);
            let body: ::core::pin::Pin<&mut dyn ::core::future::Future<Output = #output_type>> = body;
        }
    } else {
        quote! {
            let body = async #body;
        }
    };

    input.into_tokens(header, body, last_block)
}
415 | |
416 | fn token_stream_with_error(mut tokens: TokenStream, error: syn::Error) -> TokenStream { |
417 | tokens.extend(error.into_compile_error()); |
418 | tokens |
419 | } |
420 | |
421 | #[cfg (not(test))] // Work around for rust-lang/rust#62127 |
422 | pub(crate) fn main(args: TokenStream, item: TokenStream, rt_multi_thread: bool) -> TokenStream { |
423 | // If any of the steps for this macro fail, we still want to expand to an item that is as close |
424 | // to the expected output as possible. This helps out IDEs such that completions and other |
425 | // related features keep working. |
426 | let input: ItemFn = match syn::parse2(item.clone()) { |
427 | Ok(it) => it, |
428 | Err(e) => return token_stream_with_error(item, e), |
429 | }; |
430 | |
431 | let config = if input.sig.ident == "main" && !input.sig.inputs.is_empty() { |
432 | let msg = "the main function cannot accept arguments" ; |
433 | Err(syn::Error::new_spanned(&input.sig.ident, msg)) |
434 | } else { |
435 | AttributeArgs::parse_terminated |
436 | .parse2(args) |
437 | .and_then(|args| build_config(&input, args, false, rt_multi_thread)) |
438 | }; |
439 | |
440 | match config { |
441 | Ok(config) => parse_knobs(input, false, config), |
442 | Err(e) => token_stream_with_error(parse_knobs(input, false, DEFAULT_ERROR_CONFIG), e), |
443 | } |
444 | } |
445 | |
446 | pub(crate) fn test(args: TokenStream, item: TokenStream, rt_multi_thread: bool) -> TokenStream { |
447 | // If any of the steps for this macro fail, we still want to expand to an item that is as close |
448 | // to the expected output as possible. This helps out IDEs such that completions and other |
449 | // related features keep working. |
450 | let input: ItemFn = match syn::parse2(item.clone()) { |
451 | Ok(it) => it, |
452 | Err(e) => return token_stream_with_error(item, e), |
453 | }; |
454 | let config = if let Some(attr) = input.attrs().find(|attr| attr.meta.path().is_ident("test" )) { |
455 | let msg = "second test attribute is supplied" ; |
456 | Err(syn::Error::new_spanned(attr, msg)) |
457 | } else { |
458 | AttributeArgs::parse_terminated |
459 | .parse2(args) |
460 | .and_then(|args| build_config(&input, args, true, rt_multi_thread)) |
461 | }; |
462 | |
463 | match config { |
464 | Ok(config) => parse_knobs(input, true, config), |
465 | Err(e) => token_stream_with_error(parse_knobs(input, true, DEFAULT_ERROR_CONFIG), e), |
466 | } |
467 | } |
468 | |
469 | struct ItemFn { |
470 | outer_attrs: Vec<Attribute>, |
471 | vis: Visibility, |
472 | sig: Signature, |
473 | brace_token: syn::token::Brace, |
474 | inner_attrs: Vec<Attribute>, |
475 | stmts: Vec<proc_macro2::TokenStream>, |
476 | } |
477 | |
478 | impl ItemFn { |
479 | /// Access all attributes of the function item. |
480 | fn attrs(&self) -> impl Iterator<Item = &Attribute> { |
481 | self.outer_attrs.iter().chain(self.inner_attrs.iter()) |
482 | } |
483 | |
484 | /// Get the body of the function item in a manner so that it can be |
485 | /// conveniently used with the `quote!` macro. |
486 | fn body(&self) -> Body<'_> { |
487 | Body { |
488 | brace_token: self.brace_token, |
489 | stmts: &self.stmts, |
490 | } |
491 | } |
492 | |
493 | /// Convert our local function item into a token stream. |
494 | fn into_tokens( |
495 | self, |
496 | header: proc_macro2::TokenStream, |
497 | body: proc_macro2::TokenStream, |
498 | last_block: proc_macro2::TokenStream, |
499 | ) -> TokenStream { |
500 | let mut tokens = proc_macro2::TokenStream::new(); |
501 | header.to_tokens(&mut tokens); |
502 | |
503 | // Outer attributes are simply streamed as-is. |
504 | for attr in self.outer_attrs { |
505 | attr.to_tokens(&mut tokens); |
506 | } |
507 | |
508 | // Inner attributes require extra care, since they're not supported on |
509 | // blocks (which is what we're expanded into) we instead lift them |
510 | // outside of the function. This matches the behaviour of `syn`. |
511 | for mut attr in self.inner_attrs { |
512 | attr.style = syn::AttrStyle::Outer; |
513 | attr.to_tokens(&mut tokens); |
514 | } |
515 | |
516 | self.vis.to_tokens(&mut tokens); |
517 | self.sig.to_tokens(&mut tokens); |
518 | |
519 | self.brace_token.surround(&mut tokens, |tokens| { |
520 | body.to_tokens(tokens); |
521 | last_block.to_tokens(tokens); |
522 | }); |
523 | |
524 | tokens |
525 | } |
526 | } |
527 | |
528 | impl Parse for ItemFn { |
529 | #[inline ] |
530 | fn parse(input: ParseStream<'_>) -> syn::Result<Self> { |
531 | // This parse implementation has been largely lifted from `syn`, with |
532 | // the exception of: |
533 | // * We don't have access to the plumbing necessary to parse inner |
534 | // attributes in-place. |
535 | // * We do our own statements parsing to avoid recursively parsing |
536 | // entire statements and only look for the parts we're interested in. |
537 | |
538 | let outer_attrs = input.call(Attribute::parse_outer)?; |
539 | let vis: Visibility = input.parse()?; |
540 | let sig: Signature = input.parse()?; |
541 | |
542 | let content; |
543 | let brace_token = braced!(content in input); |
544 | let inner_attrs = Attribute::parse_inner(&content)?; |
545 | |
546 | let mut buf = proc_macro2::TokenStream::new(); |
547 | let mut stmts = Vec::new(); |
548 | |
549 | while !content.is_empty() { |
550 | if let Some(semi) = content.parse::<Option<syn::Token![;]>>()? { |
551 | semi.to_tokens(&mut buf); |
552 | stmts.push(buf); |
553 | buf = proc_macro2::TokenStream::new(); |
554 | continue; |
555 | } |
556 | |
557 | // Parse a single token tree and extend our current buffer with it. |
558 | // This avoids parsing the entire content of the sub-tree. |
559 | buf.extend([content.parse::<TokenTree>()?]); |
560 | } |
561 | |
562 | if !buf.is_empty() { |
563 | stmts.push(buf); |
564 | } |
565 | |
566 | Ok(Self { |
567 | outer_attrs, |
568 | vis, |
569 | sig, |
570 | brace_token, |
571 | inner_attrs, |
572 | stmts, |
573 | }) |
574 | } |
575 | } |
576 | |
577 | struct Body<'a> { |
578 | brace_token: syn::token::Brace, |
579 | // Statements, with terminating `;`. |
580 | stmts: &'a [TokenStream], |
581 | } |
582 | |
583 | impl ToTokens for Body<'_> { |
584 | fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) { |
585 | self.brace_token.surround(tokens, |tokens| { |
586 | for stmt in self.stmts { |
587 | stmt.to_tokens(tokens); |
588 | } |
589 | }); |
590 | } |
591 | } |
592 | |