1 | use proc_macro2::{Span, TokenStream, TokenTree}; |
2 | use quote::{quote, quote_spanned, ToTokens}; |
3 | use syn::parse::{Parse, ParseStream, Parser}; |
4 | use syn::{braced, Attribute, Ident, Path, Signature, Visibility}; |
5 | |
6 | // syn::AttributeArgs does not implement syn::Parse |
7 | type AttributeArgs = syn::punctuated::Punctuated<syn::Meta, syn::Token![,]>; |
8 | |
/// Runtime flavor selected with the `flavor = "..."` attribute option.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum RuntimeFlavor {
    /// `"current_thread"`: the single threaded runtime.
    CurrentThread,
    /// `"multi_thread"`: requires the `rt-multi-thread` feature.
    Threaded,
}

impl RuntimeFlavor {
    /// Parses the value of the `flavor` option.
    ///
    /// Returns a user-facing error message on failure. Known legacy or
    /// near-miss names get a targeted hint pointing at the correct flavor;
    /// anything else gets a message listing the valid flavors.
    fn from_str(s: &str) -> Result<RuntimeFlavor, String> {
        match s {
            "current_thread" => Ok(RuntimeFlavor::CurrentThread),
            "multi_thread" => Ok(RuntimeFlavor::Threaded),
            // `single_threaded` is also the spelling `build_config` recognises
            // as a bare path argument, so it deserves the same hint here.
            "single_thread" | "single_threaded" => {
                Err("The single threaded runtime flavor is called `current_thread`.".to_string())
            }
            "basic_scheduler" => Err(
                "The `basic_scheduler` runtime flavor has been renamed to `current_thread`."
                    .to_string(),
            ),
            "threaded_scheduler" => Err(
                "The `threaded_scheduler` runtime flavor has been renamed to `multi_thread`."
                    .to_string(),
            ),
            _ => Err(format!(
                "No such runtime flavor `{s}`. The runtime flavors are `current_thread` and `multi_thread`."
            )),
        }
    }
}
27 | |
28 | #[derive (Clone, Copy, PartialEq)] |
29 | enum UnhandledPanic { |
30 | Ignore, |
31 | ShutdownRuntime, |
32 | } |
33 | |
34 | impl UnhandledPanic { |
35 | fn from_str(s: &str) -> Result<UnhandledPanic, String> { |
36 | match s { |
37 | "ignore" => Ok(UnhandledPanic::Ignore), |
38 | "shutdown_runtime" => Ok(UnhandledPanic::ShutdownRuntime), |
39 | _ => Err(format!("No such unhandled panic behavior ` {s}`. The unhandled panic behaviors are `ignore` and `shutdown_runtime`." )), |
40 | } |
41 | } |
42 | |
43 | fn into_tokens(self, crate_path: &TokenStream) -> TokenStream { |
44 | match self { |
45 | UnhandledPanic::Ignore => quote! { #crate_path::runtime::UnhandledPanic::Ignore }, |
46 | UnhandledPanic::ShutdownRuntime => { |
47 | quote! { #crate_path::runtime::UnhandledPanic::ShutdownRuntime } |
48 | } |
49 | } |
50 | } |
51 | } |
52 | |
53 | struct FinalConfig { |
54 | flavor: RuntimeFlavor, |
55 | worker_threads: Option<usize>, |
56 | start_paused: Option<bool>, |
57 | crate_name: Option<Path>, |
58 | unhandled_panic: Option<UnhandledPanic>, |
59 | } |
60 | |
61 | /// Config used in case of the attribute not being able to build a valid config |
62 | const DEFAULT_ERROR_CONFIG: FinalConfig = FinalConfig { |
63 | flavor: RuntimeFlavor::CurrentThread, |
64 | worker_threads: None, |
65 | start_paused: None, |
66 | crate_name: None, |
67 | unhandled_panic: None, |
68 | }; |
69 | |
70 | struct Configuration { |
71 | rt_multi_thread_available: bool, |
72 | default_flavor: RuntimeFlavor, |
73 | flavor: Option<RuntimeFlavor>, |
74 | worker_threads: Option<(usize, Span)>, |
75 | start_paused: Option<(bool, Span)>, |
76 | is_test: bool, |
77 | crate_name: Option<Path>, |
78 | unhandled_panic: Option<(UnhandledPanic, Span)>, |
79 | } |
80 | |
81 | impl Configuration { |
82 | fn new(is_test: bool, rt_multi_thread: bool) -> Self { |
83 | Configuration { |
84 | rt_multi_thread_available: rt_multi_thread, |
85 | default_flavor: match is_test { |
86 | true => RuntimeFlavor::CurrentThread, |
87 | false => RuntimeFlavor::Threaded, |
88 | }, |
89 | flavor: None, |
90 | worker_threads: None, |
91 | start_paused: None, |
92 | is_test, |
93 | crate_name: None, |
94 | unhandled_panic: None, |
95 | } |
96 | } |
97 | |
98 | fn set_flavor(&mut self, runtime: syn::Lit, span: Span) -> Result<(), syn::Error> { |
99 | if self.flavor.is_some() { |
100 | return Err(syn::Error::new(span, "`flavor` set multiple times." )); |
101 | } |
102 | |
103 | let runtime_str = parse_string(runtime, span, "flavor" )?; |
104 | let runtime = |
105 | RuntimeFlavor::from_str(&runtime_str).map_err(|err| syn::Error::new(span, err))?; |
106 | self.flavor = Some(runtime); |
107 | Ok(()) |
108 | } |
109 | |
110 | fn set_worker_threads( |
111 | &mut self, |
112 | worker_threads: syn::Lit, |
113 | span: Span, |
114 | ) -> Result<(), syn::Error> { |
115 | if self.worker_threads.is_some() { |
116 | return Err(syn::Error::new( |
117 | span, |
118 | "`worker_threads` set multiple times." , |
119 | )); |
120 | } |
121 | |
122 | let worker_threads = parse_int(worker_threads, span, "worker_threads" )?; |
123 | if worker_threads == 0 { |
124 | return Err(syn::Error::new(span, "`worker_threads` may not be 0." )); |
125 | } |
126 | self.worker_threads = Some((worker_threads, span)); |
127 | Ok(()) |
128 | } |
129 | |
130 | fn set_start_paused(&mut self, start_paused: syn::Lit, span: Span) -> Result<(), syn::Error> { |
131 | if self.start_paused.is_some() { |
132 | return Err(syn::Error::new(span, "`start_paused` set multiple times." )); |
133 | } |
134 | |
135 | let start_paused = parse_bool(start_paused, span, "start_paused" )?; |
136 | self.start_paused = Some((start_paused, span)); |
137 | Ok(()) |
138 | } |
139 | |
140 | fn set_crate_name(&mut self, name: syn::Lit, span: Span) -> Result<(), syn::Error> { |
141 | if self.crate_name.is_some() { |
142 | return Err(syn::Error::new(span, "`crate` set multiple times." )); |
143 | } |
144 | let name_path = parse_path(name, span, "crate" )?; |
145 | self.crate_name = Some(name_path); |
146 | Ok(()) |
147 | } |
148 | |
149 | fn set_unhandled_panic( |
150 | &mut self, |
151 | unhandled_panic: syn::Lit, |
152 | span: Span, |
153 | ) -> Result<(), syn::Error> { |
154 | if self.unhandled_panic.is_some() { |
155 | return Err(syn::Error::new( |
156 | span, |
157 | "`unhandled_panic` set multiple times." , |
158 | )); |
159 | } |
160 | |
161 | let unhandled_panic = parse_string(unhandled_panic, span, "unhandled_panic" )?; |
162 | let unhandled_panic = |
163 | UnhandledPanic::from_str(&unhandled_panic).map_err(|err| syn::Error::new(span, err))?; |
164 | self.unhandled_panic = Some((unhandled_panic, span)); |
165 | Ok(()) |
166 | } |
167 | |
168 | fn macro_name(&self) -> &'static str { |
169 | if self.is_test { |
170 | "tokio::test" |
171 | } else { |
172 | "tokio::main" |
173 | } |
174 | } |
175 | |
176 | fn build(&self) -> Result<FinalConfig, syn::Error> { |
177 | use RuntimeFlavor as F; |
178 | |
179 | let flavor = self.flavor.unwrap_or(self.default_flavor); |
180 | let worker_threads = match (flavor, self.worker_threads) { |
181 | (F::CurrentThread, Some((_, worker_threads_span))) => { |
182 | let msg = format!( |
183 | "The `worker_threads` option requires the `multi_thread` runtime flavor. Use `#[ {}(flavor = \"multi_thread \")]`" , |
184 | self.macro_name(), |
185 | ); |
186 | return Err(syn::Error::new(worker_threads_span, msg)); |
187 | } |
188 | (F::CurrentThread, None) => None, |
189 | (F::Threaded, worker_threads) if self.rt_multi_thread_available => { |
190 | worker_threads.map(|(val, _span)| val) |
191 | } |
192 | (F::Threaded, _) => { |
193 | let msg = if self.flavor.is_none() { |
194 | "The default runtime flavor is `multi_thread`, but the `rt-multi-thread` feature is disabled." |
195 | } else { |
196 | "The runtime flavor `multi_thread` requires the `rt-multi-thread` feature." |
197 | }; |
198 | return Err(syn::Error::new(Span::call_site(), msg)); |
199 | } |
200 | }; |
201 | |
202 | let start_paused = match (flavor, self.start_paused) { |
203 | (F::Threaded, Some((_, start_paused_span))) => { |
204 | let msg = format!( |
205 | "The `start_paused` option requires the `current_thread` runtime flavor. Use `#[ {}(flavor = \"current_thread \")]`" , |
206 | self.macro_name(), |
207 | ); |
208 | return Err(syn::Error::new(start_paused_span, msg)); |
209 | } |
210 | (F::CurrentThread, Some((start_paused, _))) => Some(start_paused), |
211 | (_, None) => None, |
212 | }; |
213 | |
214 | let unhandled_panic = match (flavor, self.unhandled_panic) { |
215 | (F::Threaded, Some((_, unhandled_panic_span))) => { |
216 | let msg = format!( |
217 | "The `unhandled_panic` option requires the `current_thread` runtime flavor. Use `#[ {}(flavor = \"current_thread \")]`" , |
218 | self.macro_name(), |
219 | ); |
220 | return Err(syn::Error::new(unhandled_panic_span, msg)); |
221 | } |
222 | (F::CurrentThread, Some((unhandled_panic, _))) => Some(unhandled_panic), |
223 | (_, None) => None, |
224 | }; |
225 | |
226 | Ok(FinalConfig { |
227 | crate_name: self.crate_name.clone(), |
228 | flavor, |
229 | worker_threads, |
230 | start_paused, |
231 | unhandled_panic, |
232 | }) |
233 | } |
234 | } |
235 | |
236 | fn parse_int(int: syn::Lit, span: Span, field: &str) -> Result<usize, syn::Error> { |
237 | match int { |
238 | syn::Lit::Int(lit: LitInt) => match lit.base10_parse::<usize>() { |
239 | Ok(value: usize) => Ok(value), |
240 | Err(e: Error) => Err(syn::Error::new( |
241 | span, |
242 | message:format!("Failed to parse value of ` {field}` as integer: {e}" ), |
243 | )), |
244 | }, |
245 | _ => Err(syn::Error::new( |
246 | span, |
247 | message:format!("Failed to parse value of ` {field}` as integer." ), |
248 | )), |
249 | } |
250 | } |
251 | |
252 | fn parse_string(int: syn::Lit, span: Span, field: &str) -> Result<String, syn::Error> { |
253 | match int { |
254 | syn::Lit::Str(s: LitStr) => Ok(s.value()), |
255 | syn::Lit::Verbatim(s: Literal) => Ok(s.to_string()), |
256 | _ => Err(syn::Error::new( |
257 | span, |
258 | message:format!("Failed to parse value of ` {field}` as string." ), |
259 | )), |
260 | } |
261 | } |
262 | |
263 | fn parse_path(lit: syn::Lit, span: Span, field: &str) -> Result<Path, syn::Error> { |
264 | match lit { |
265 | syn::Lit::Str(s: LitStr) => { |
266 | let err: Error = syn::Error::new( |
267 | span, |
268 | message:format!( |
269 | "Failed to parse value of ` {}` as path: \"{}\"" , |
270 | field, |
271 | s.value() |
272 | ), |
273 | ); |
274 | s.parse::<syn::Path>().map_err(|_| err.clone()) |
275 | } |
276 | _ => Err(syn::Error::new( |
277 | span, |
278 | message:format!("Failed to parse value of ` {field}` as path." ), |
279 | )), |
280 | } |
281 | } |
282 | |
283 | fn parse_bool(bool: syn::Lit, span: Span, field: &str) -> Result<bool, syn::Error> { |
284 | match bool { |
285 | syn::Lit::Bool(b: LitBool) => Ok(b.value), |
286 | _ => Err(syn::Error::new( |
287 | span, |
288 | message:format!("Failed to parse value of ` {field}` as bool." ), |
289 | )), |
290 | } |
291 | } |
292 | |
293 | fn build_config( |
294 | input: &ItemFn, |
295 | args: AttributeArgs, |
296 | is_test: bool, |
297 | rt_multi_thread: bool, |
298 | ) -> Result<FinalConfig, syn::Error> { |
299 | if input.sig.asyncness.is_none() { |
300 | let msg = "the `async` keyword is missing from the function declaration" ; |
301 | return Err(syn::Error::new_spanned(input.sig.fn_token, msg)); |
302 | } |
303 | |
304 | let mut config = Configuration::new(is_test, rt_multi_thread); |
305 | let macro_name = config.macro_name(); |
306 | |
307 | for arg in args { |
308 | match arg { |
309 | syn::Meta::NameValue(namevalue) => { |
310 | let ident = namevalue |
311 | .path |
312 | .get_ident() |
313 | .ok_or_else(|| { |
314 | syn::Error::new_spanned(&namevalue, "Must have specified ident" ) |
315 | })? |
316 | .to_string() |
317 | .to_lowercase(); |
318 | let lit = match &namevalue.value { |
319 | syn::Expr::Lit(syn::ExprLit { lit, .. }) => lit, |
320 | expr => return Err(syn::Error::new_spanned(expr, "Must be a literal" )), |
321 | }; |
322 | match ident.as_str() { |
323 | "worker_threads" => { |
324 | config.set_worker_threads(lit.clone(), syn::spanned::Spanned::span(lit))?; |
325 | } |
326 | "flavor" => { |
327 | config.set_flavor(lit.clone(), syn::spanned::Spanned::span(lit))?; |
328 | } |
329 | "start_paused" => { |
330 | config.set_start_paused(lit.clone(), syn::spanned::Spanned::span(lit))?; |
331 | } |
332 | "core_threads" => { |
333 | let msg = "Attribute `core_threads` is renamed to `worker_threads`" ; |
334 | return Err(syn::Error::new_spanned(namevalue, msg)); |
335 | } |
336 | "crate" => { |
337 | config.set_crate_name(lit.clone(), syn::spanned::Spanned::span(lit))?; |
338 | } |
339 | "unhandled_panic" => { |
340 | config |
341 | .set_unhandled_panic(lit.clone(), syn::spanned::Spanned::span(lit))?; |
342 | } |
343 | name => { |
344 | let msg = format!( |
345 | "Unknown attribute {name} is specified; expected one of: `flavor`, `worker_threads`, `start_paused`, `crate`, `unhandled_panic`" , |
346 | ); |
347 | return Err(syn::Error::new_spanned(namevalue, msg)); |
348 | } |
349 | } |
350 | } |
351 | syn::Meta::Path(path) => { |
352 | let name = path |
353 | .get_ident() |
354 | .ok_or_else(|| syn::Error::new_spanned(&path, "Must have specified ident" ))? |
355 | .to_string() |
356 | .to_lowercase(); |
357 | let msg = match name.as_str() { |
358 | "threaded_scheduler" | "multi_thread" => { |
359 | format!( |
360 | "Set the runtime flavor with #[ {macro_name}(flavor = \"multi_thread \")]." |
361 | ) |
362 | } |
363 | "basic_scheduler" | "current_thread" | "single_threaded" => { |
364 | format!( |
365 | "Set the runtime flavor with #[ {macro_name}(flavor = \"current_thread \")]." |
366 | ) |
367 | } |
368 | "flavor" | "worker_threads" | "start_paused" | "crate" | "unhandled_panic" => { |
369 | format!("The ` {name}` attribute requires an argument." ) |
370 | } |
371 | name => { |
372 | format!("Unknown attribute {name} is specified; expected one of: `flavor`, `worker_threads`, `start_paused`, `crate`, `unhandled_panic`." ) |
373 | } |
374 | }; |
375 | return Err(syn::Error::new_spanned(path, msg)); |
376 | } |
377 | other => { |
378 | return Err(syn::Error::new_spanned( |
379 | other, |
380 | "Unknown attribute inside the macro" , |
381 | )); |
382 | } |
383 | } |
384 | } |
385 | |
386 | config.build() |
387 | } |
388 | |
/// Rewrites `input` into a non-`async` function.
///
/// The original body becomes an `async` block. A runtime is built from
/// `config` and blocks on that block.
///
/// When `is_test` is set, `#[::core::prelude::v1::test]` is emitted and the
/// body future is pinned as `Pin<&mut dyn Future>` (see the comment further
/// down). Generated runtime code is spanned at the body's last statement so
/// that type errors are reported there.
fn parse_knobs(mut input: ItemFn, is_test: bool, config: FinalConfig) -> TokenStream {
    // The generated function itself is synchronous.
    input.sig.asyncness = None;

    // If type mismatch occurs, the current rustc points to the last statement.
    let (last_stmt_start_span, last_stmt_end_span) = {
        // An empty body yields an empty stream, so both spans fall back to
        // `Span::call_site()`.
        let mut last_stmt = input.stmts.last().cloned().unwrap_or_default().into_iter();

        // `Span` on stable Rust has a limitation that only points to the first
        // token, not the whole tokens. We can work around this limitation by
        // using the first/last span of the tokens like
        // `syn::Error::new_spanned` does.
        let start = last_stmt.next().map_or_else(Span::call_site, |t| t.span());
        let end = last_stmt.last().map_or(start, |t| t.span());
        (start, end)
    };

    // Path used to reach the runtime: the user's `crate = "..."` override,
    // otherwise `tokio` spanned at the last statement.
    let crate_path = config
        .crate_name
        .map(ToTokens::into_token_stream)
        .unwrap_or_else(|| Ident::new("tokio", last_stmt_start_span).into_token_stream());

    // Runtime builder expression; each configured option appends one method
    // call to the chain.
    let mut rt = match config.flavor {
        RuntimeFlavor::CurrentThread => quote_spanned! {last_stmt_start_span=>
            #crate_path::runtime::Builder::new_current_thread()
        },
        RuntimeFlavor::Threaded => quote_spanned! {last_stmt_start_span=>
            #crate_path::runtime::Builder::new_multi_thread()
        },
    };
    if let Some(v) = config.worker_threads {
        rt = quote_spanned! {last_stmt_start_span=> #rt.worker_threads(#v) };
    }
    if let Some(v) = config.start_paused {
        rt = quote_spanned! {last_stmt_start_span=> #rt.start_paused(#v) };
    }
    if let Some(v) = config.unhandled_panic {
        let unhandled_panic = v.into_tokens(&crate_path);
        rt = quote_spanned! {last_stmt_start_span=> #rt.unhandled_panic(#unhandled_panic) };
    }

    // Tests get the standard `#[test]` attribute; `main` gets nothing extra.
    let generated_attrs = if is_test {
        quote! {
            #[::core::prelude::v1::test]
        }
    } else {
        quote! {}
    };

    let body_ident = quote! { body };
    // This explicit `return` is intentional. See tokio-rs/tokio#4636
    let last_block = quote_spanned! {last_stmt_end_span=>
        #[allow(clippy::expect_used, clippy::diverging_sub_expression, clippy::needless_return)]
        {
            return #rt
                .enable_all()
                .build()
                .expect("Failed building the Runtime")
                .block_on(#body_ident);
        }
    };

    let body = input.body();

    // For test functions pin the body to the stack and use `Pin<&mut dyn
    // Future>` to reduce the amount of `Runtime::block_on` (and related
    // functions) copies we generate during compilation due to the generic
    // parameter `F` (the future to block on). This could have an impact on
    // performance, but because it's only for testing it's unlikely to be very
    // large.
    //
    // We don't do this for the main function as it should only be used once so
    // there will be no benefit.
    let body = if is_test {
        let output_type = match &input.sig.output {
            // For functions with no return value syn doesn't print anything,
            // but that doesn't work as `Output` for our boxed `Future`, so
            // default to `()` (the same type as the function output).
            syn::ReturnType::Default => quote! { () },
            syn::ReturnType::Type(_, ret_type) => quote! { #ret_type },
        };
        quote! {
            let body = async #body;
            #crate_path::pin!(body);
            let body: ::core::pin::Pin<&mut dyn ::core::future::Future<Output = #output_type>> = body;
        }
    } else {
        quote! {
            let body = async #body;
        }
    };

    // Reassemble: attributes, signature, then `{ <body binding> <runtime block> }`.
    input.into_tokens(generated_attrs, body, last_block)
}
482 | |
483 | fn token_stream_with_error(mut tokens: TokenStream, error: syn::Error) -> TokenStream { |
484 | tokens.extend(iter:error.into_compile_error()); |
485 | tokens |
486 | } |
487 | |
488 | pub(crate) fn main(args: TokenStream, item: TokenStream, rt_multi_thread: bool) -> TokenStream { |
489 | // If any of the steps for this macro fail, we still want to expand to an item that is as close |
490 | // to the expected output as possible. This helps out IDEs such that completions and other |
491 | // related features keep working. |
492 | let input: ItemFn = match syn::parse2(tokens:item.clone()) { |
493 | Ok(it: ItemFn) => it, |
494 | Err(e: Error) => return token_stream_with_error(tokens:item, error:e), |
495 | }; |
496 | |
497 | let config: Result = if input.sig.ident == "main" && !input.sig.inputs.is_empty() { |
498 | let msg: &'static str = "the main function cannot accept arguments" ; |
499 | Err(syn::Error::new_spanned(&input.sig.ident, message:msg)) |
500 | } else { |
501 | AttributeArgs::parse_terminated |
502 | .parse2(args) |
503 | .and_then(|args: Punctuated| build_config(&input, args, is_test:false, rt_multi_thread)) |
504 | }; |
505 | |
506 | match config { |
507 | Ok(config: FinalConfig) => parse_knobs(input, is_test:false, config), |
508 | Err(e: Error) => token_stream_with_error(tokens:parse_knobs(input, false, DEFAULT_ERROR_CONFIG), error:e), |
509 | } |
510 | } |
511 | |
512 | // Check whether given attribute is a test attribute of forms: |
513 | // * `#[test]` |
514 | // * `#[core::prelude::*::test]` or `#[::core::prelude::*::test]` |
515 | // * `#[std::prelude::*::test]` or `#[::std::prelude::*::test]` |
516 | fn is_test_attribute(attr: &Attribute) -> bool { |
517 | let path: &Path = match &attr.meta { |
518 | syn::Meta::Path(path: &Path) => path, |
519 | _ => return false, |
520 | }; |
521 | let candidates: [[&'static str; 4]; 2] = [ |
522 | ["core" , "prelude" , "*" , "test" ], |
523 | ["std" , "prelude" , "*" , "test" ], |
524 | ]; |
525 | if path.leading_colon.is_none() |
526 | && path.segments.len() == 1 |
527 | && path.segments[0].arguments.is_none() |
528 | && path.segments[0].ident == "test" |
529 | { |
530 | return true; |
531 | } else if path.segments.len() != candidates[0].len() { |
532 | return false; |
533 | } |
534 | candidates.into_iter().any(|segments: [&'static str; 4]| { |
535 | path.segments.iter().zip(segments).all(|(segment: &PathSegment, path: &'static str)| { |
536 | segment.arguments.is_none() && (path == "*" || segment.ident == path) |
537 | }) |
538 | }) |
539 | } |
540 | |
541 | pub(crate) fn test(args: TokenStream, item: TokenStream, rt_multi_thread: bool) -> TokenStream { |
542 | // If any of the steps for this macro fail, we still want to expand to an item that is as close |
543 | // to the expected output as possible. This helps out IDEs such that completions and other |
544 | // related features keep working. |
545 | let input: ItemFn = match syn::parse2(tokens:item.clone()) { |
546 | Ok(it: ItemFn) => it, |
547 | Err(e: Error) => return token_stream_with_error(tokens:item, error:e), |
548 | }; |
549 | let config: Result = if let Some(attr: &Attribute) = input.attrs().find(|attr: &&Attribute| is_test_attribute(attr)) { |
550 | let msg: &'static str = "second test attribute is supplied, consider removing or changing the order of your test attributes" ; |
551 | Err(syn::Error::new_spanned(tokens:attr, message:msg)) |
552 | } else { |
553 | AttributeArgs::parse_terminated |
554 | .parse2(args) |
555 | .and_then(|args: Punctuated| build_config(&input, args, is_test:true, rt_multi_thread)) |
556 | }; |
557 | |
558 | match config { |
559 | Ok(config: FinalConfig) => parse_knobs(input, is_test:true, config), |
560 | Err(e: Error) => token_stream_with_error(tokens:parse_knobs(input, true, DEFAULT_ERROR_CONFIG), error:e), |
561 | } |
562 | } |
563 | |
564 | struct ItemFn { |
565 | outer_attrs: Vec<Attribute>, |
566 | vis: Visibility, |
567 | sig: Signature, |
568 | brace_token: syn::token::Brace, |
569 | inner_attrs: Vec<Attribute>, |
570 | stmts: Vec<proc_macro2::TokenStream>, |
571 | } |
572 | |
573 | impl ItemFn { |
574 | /// Access all attributes of the function item. |
575 | fn attrs(&self) -> impl Iterator<Item = &Attribute> { |
576 | self.outer_attrs.iter().chain(self.inner_attrs.iter()) |
577 | } |
578 | |
579 | /// Get the body of the function item in a manner so that it can be |
580 | /// conveniently used with the `quote!` macro. |
581 | fn body(&self) -> Body<'_> { |
582 | Body { |
583 | brace_token: self.brace_token, |
584 | stmts: &self.stmts, |
585 | } |
586 | } |
587 | |
588 | /// Convert our local function item into a token stream. |
589 | fn into_tokens( |
590 | self, |
591 | generated_attrs: proc_macro2::TokenStream, |
592 | body: proc_macro2::TokenStream, |
593 | last_block: proc_macro2::TokenStream, |
594 | ) -> TokenStream { |
595 | let mut tokens = proc_macro2::TokenStream::new(); |
596 | // Outer attributes are simply streamed as-is. |
597 | for attr in self.outer_attrs { |
598 | attr.to_tokens(&mut tokens); |
599 | } |
600 | |
601 | // Inner attributes require extra care, since they're not supported on |
602 | // blocks (which is what we're expanded into) we instead lift them |
603 | // outside of the function. This matches the behavior of `syn`. |
604 | for mut attr in self.inner_attrs { |
605 | attr.style = syn::AttrStyle::Outer; |
606 | attr.to_tokens(&mut tokens); |
607 | } |
608 | |
609 | // Add generated macros at the end, so macros processed later are aware of them. |
610 | generated_attrs.to_tokens(&mut tokens); |
611 | |
612 | self.vis.to_tokens(&mut tokens); |
613 | self.sig.to_tokens(&mut tokens); |
614 | |
615 | self.brace_token.surround(&mut tokens, |tokens| { |
616 | body.to_tokens(tokens); |
617 | last_block.to_tokens(tokens); |
618 | }); |
619 | |
620 | tokens |
621 | } |
622 | } |
623 | |
624 | impl Parse for ItemFn { |
625 | #[inline ] |
626 | fn parse(input: ParseStream<'_>) -> syn::Result<Self> { |
627 | // This parse implementation has been largely lifted from `syn`, with |
628 | // the exception of: |
629 | // * We don't have access to the plumbing necessary to parse inner |
630 | // attributes in-place. |
631 | // * We do our own statements parsing to avoid recursively parsing |
632 | // entire statements and only look for the parts we're interested in. |
633 | |
634 | let outer_attrs = input.call(Attribute::parse_outer)?; |
635 | let vis: Visibility = input.parse()?; |
636 | let sig: Signature = input.parse()?; |
637 | |
638 | let content; |
639 | let brace_token = braced!(content in input); |
640 | let inner_attrs = Attribute::parse_inner(&content)?; |
641 | |
642 | let mut buf = proc_macro2::TokenStream::new(); |
643 | let mut stmts = Vec::new(); |
644 | |
645 | while !content.is_empty() { |
646 | if let Some(semi) = content.parse::<Option<syn::Token![;]>>()? { |
647 | semi.to_tokens(&mut buf); |
648 | stmts.push(buf); |
649 | buf = proc_macro2::TokenStream::new(); |
650 | continue; |
651 | } |
652 | |
653 | // Parse a single token tree and extend our current buffer with it. |
654 | // This avoids parsing the entire content of the sub-tree. |
655 | buf.extend([content.parse::<TokenTree>()?]); |
656 | } |
657 | |
658 | if !buf.is_empty() { |
659 | stmts.push(buf); |
660 | } |
661 | |
662 | Ok(Self { |
663 | outer_attrs, |
664 | vis, |
665 | sig, |
666 | brace_token, |
667 | inner_attrs, |
668 | stmts, |
669 | }) |
670 | } |
671 | } |
672 | |
673 | struct Body<'a> { |
674 | brace_token: syn::token::Brace, |
675 | // Statements, with terminating `;`. |
676 | stmts: &'a [TokenStream], |
677 | } |
678 | |
679 | impl ToTokens for Body<'_> { |
680 | fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) { |
681 | self.brace_token.surround(tokens, |tokens: &mut TokenStream| { |
682 | for stmt: &TokenStream in self.stmts { |
683 | stmt.to_tokens(tokens); |
684 | } |
685 | }); |
686 | } |
687 | } |
688 | |