1// SPDX-License-Identifier: Apache-2.0 OR MIT
2
3use proc_macro2::TokenStream;
4use quote::ToTokens;
5use syn::{
6 parse_quote, token,
7 visit_mut::{self, VisitMut},
8 Arm, Attribute, Expr, ExprMacro, ExprMatch, ExprReturn, ExprTry, Item, Local, LocalInit,
9 MetaList, Stmt, Token,
10};
11
12use super::{Context, VisitMode, DEFAULT_MARKER, NAME, NESTED, NEVER};
13use crate::utils::{replace_expr, Attrs, Node};
14
/// Flags describing the syntactic context the visitor is currently inside.
///
/// The struct is `Copy` so that visitors can snapshot it before descending
/// into a child node and restore the snapshot afterwards.
#[derive(Default, Clone, Copy)]
struct Scope {
    /// Set while inside a closure.
    closure: bool,
    /// Set while inside a `try` block.
    try_block: bool,
    /// Set while inside a node carrying another `auto_enum` attribute.
    foreign: bool,
}
24
25impl Scope {
26 // check this scope is in closures or try blocks.
27 fn check_expr(&mut self, expr: &Expr) {
28 match expr {
29 Expr::Closure(_) => self.closure = true,
30 // `?` operator in try blocks are not supported.
31 Expr::TryBlock(_) => self.try_block = true,
32 _ => {}
33 }
34 }
35}
36
37// =================================================================================================
38// default visitor
39
40pub(super) struct Visitor<'a> {
41 cx: &'a mut Context,
42 scope: Scope,
43}
44
45impl<'a> Visitor<'a> {
46 pub(super) fn new(cx: &'a mut Context) -> Self {
47 Self { cx, scope: Scope::default() }
48 }
49
50 fn find_remove_attrs(&mut self, attrs: &mut impl Attrs) {
51 if !self.scope.foreign {
52 if let Some(attr) = attrs.find_remove_attr(NEVER) {
53 if let Err(e) = attr.meta.require_path_only() {
54 self.cx.error(e);
55 }
56 }
57
58 // The old annotation `#[rec]` is replaced with `#[nested]`.
59 if let Some(old) = attrs.find_remove_attr("rec") {
60 self.cx.error(format_err!(
61 old,
62 "#[rec] has been removed and replaced with #[{}]",
63 NESTED
64 ));
65 }
66 }
67 }
68
69 /// `return` in functions or closures
70 fn visit_return(&mut self, node: &mut Expr, count: usize) {
71 debug_assert!(self.cx.visit_mode == VisitMode::Return(count));
72
73 if !self.scope.closure && !node.any_empty_attr(NEVER) {
74 // Desugar `return <expr>` into `return Enum::VariantN(<expr>)`.
75 if let Expr::Return(ExprReturn { expr, .. }) = node {
76 // Skip if `<expr>` is a marker macro.
77 if expr.as_ref().map_or(true, |expr| !self.cx.is_marker_expr(expr)) {
78 self.cx.replace_boxed_expr(expr);
79 }
80 }
81 }
82 }
83
84 /// `?` operator in functions or closures
85 fn visit_try(&mut self, node: &mut Expr) {
86 debug_assert!(self.cx.visit_mode == VisitMode::Try);
87
88 if !self.scope.try_block && !self.scope.closure && !node.any_empty_attr(NEVER) {
89 match &node {
90 // https://github.com/rust-lang/rust/blob/1.35.0/src/librustc/hir/lowering.rs#L4578-L4682
91
92 // Desugar `<expr>?`
93 // into:
94 //
95 // match <expr> {
96 // Ok(val) => val,
97 // Err(err) => return Err(Enum::VariantN(err)),
98 // }
99 //
100 // Skip if `<expr>` is a marker macro.
101 Expr::Try(ExprTry { expr, .. }) if !self.cx.is_marker_expr(expr) => {
102 replace_expr(node, |expr| {
103 let ExprTry { attrs, expr, .. } =
104 if let Expr::Try(expr) = expr { expr } else { unreachable!() };
105
106 let err = self.cx.next_expr(parse_quote!(err));
107 let arms = vec![
108 parse_quote! {
109 ::core::result::Result::Ok(val) => val,
110 },
111 parse_quote! {
112 ::core::result::Result::Err(err) => {
113 return ::core::result::Result::Err(#err);
114 }
115 },
116 ];
117
118 Expr::Match(ExprMatch {
119 attrs,
120 match_token: <Token![match]>::default(),
121 expr,
122 brace_token: token::Brace::default(),
123 arms,
124 })
125 });
126 }
127 _ => {}
128 }
129 }
130 }
131
132 /// `#[nested]`
133 fn visit_nested(&mut self, node: &mut Expr, attr: &Attribute) {
134 debug_assert!(!self.scope.foreign);
135
136 if let Err(e) = attr.meta.require_path_only() {
137 self.cx.error(e);
138 } else {
139 super::expr::child_expr(self.cx, node);
140 }
141 }
142
143 /// Expression level marker (`marker!` macro)
144 fn visit_marker_macro(&mut self, node: &mut Expr) {
145 debug_assert!(!self.scope.foreign || self.cx.current_marker != DEFAULT_MARKER);
146
147 match node {
148 // Desugar `marker!(<expr>)` into `Enum::VariantN(<expr>)`.
149 // Skip if `marker!` is not a marker macro.
150 Expr::Macro(ExprMacro { mac, .. }) if self.cx.is_marker_macro_exact(mac) => {
151 replace_expr(node, |expr| {
152 let expr = if let Expr::Macro(expr) = expr { expr } else { unreachable!() };
153 let args = syn::parse2(expr.mac.tokens).unwrap_or_else(|e| {
154 self.cx.error(e);
155 // Generate an expression to fill in where the error occurred during the visit.
156 // These will eventually need to be replaced with the original error message.
157 parse_quote!(compile_error!(
158 "#[auto_enum] failed to generate error message"
159 ))
160 });
161
162 if self.cx.has_error() {
163 args
164 } else {
165 self.cx.next_expr_with_attrs(expr.attrs, args)
166 }
167 });
168 }
169 _ => {}
170 }
171 }
172
173 fn visit_expr(&mut self, node: &mut Expr, has_semi: bool) {
174 debug_assert!(!self.cx.has_error());
175
176 let tmp = self.scope;
177
178 if node.any_attr(NAME) {
179 self.scope.foreign = true;
180 // Record whether other `auto_enum` attribute exists.
181 self.cx.has_child = true;
182 }
183 self.scope.check_expr(node);
184
185 match self.cx.visit_mode {
186 VisitMode::Return(count) => self.visit_return(node, count),
187 VisitMode::Try => self.visit_try(node),
188 VisitMode::Default => {}
189 }
190
191 if !self.scope.foreign {
192 if let Some(attr) = node.find_remove_attr(NESTED) {
193 self.visit_nested(node, &attr);
194 }
195 }
196
197 VisitStmt::visit_expr(self, node, has_semi);
198
199 if !self.scope.foreign || self.cx.current_marker != DEFAULT_MARKER {
200 self.visit_marker_macro(node);
201 self.find_remove_attrs(node);
202 }
203
204 self.scope = tmp;
205 }
206}
207
208impl VisitMut for Visitor<'_> {
209 fn visit_expr_mut(&mut self, node: &mut Expr) {
210 if !self.cx.has_error() {
211 self.visit_expr(node, false);
212 }
213 }
214
215 fn visit_arm_mut(&mut self, node: &mut Arm) {
216 if !self.cx.has_error() {
217 if !self.scope.foreign {
218 if let Some(attr) = node.find_remove_attr(NESTED) {
219 self.visit_nested(&mut node.body, &attr);
220 }
221 }
222
223 visit_mut::visit_arm_mut(self, node);
224
225 self.find_remove_attrs(node);
226 }
227 }
228
229 fn visit_local_mut(&mut self, node: &mut Local) {
230 if !self.cx.has_error() {
231 if !self.scope.foreign {
232 if let Some(attr) = node.find_remove_attr(NESTED) {
233 if let Some(LocalInit { expr, .. }) = &mut node.init {
234 self.visit_nested(expr, &attr);
235 }
236 }
237 }
238
239 visit_mut::visit_local_mut(self, node);
240
241 self.find_remove_attrs(node);
242 }
243 }
244
245 fn visit_stmt_mut(&mut self, node: &mut Stmt) {
246 if !self.cx.has_error() {
247 if let Stmt::Expr(expr, semi) = node {
248 self.visit_expr(expr, semi.is_some());
249 } else {
250 let tmp = self.scope;
251
252 if node.any_attr(NAME) {
253 self.scope.foreign = true;
254 // Record whether other `auto_enum` attribute exists.
255 self.cx.has_child = true;
256 }
257
258 VisitStmt::visit_stmt(self, node);
259
260 self.scope = tmp;
261 }
262 }
263 }
264
265 fn visit_item_mut(&mut self, _: &mut Item) {
266 // Do not recurse into nested items.
267 }
268}
269
270impl VisitStmt for Visitor<'_> {
271 fn cx(&mut self) -> &mut Context {
272 self.cx
273 }
274}
275
276// =================================================================================================
277// dummy visitor
278
279pub(super) struct Dummy<'a> {
280 cx: &'a mut Context,
281}
282
283impl<'a> Dummy<'a> {
284 pub(super) fn new(cx: &'a mut Context) -> Self {
285 Self { cx }
286 }
287}
288
289impl VisitMut for Dummy<'_> {
290 fn visit_stmt_mut(&mut self, node: &mut Stmt) {
291 if !self.cx.has_error() {
292 if node.any_attr(NAME) {
293 self.cx.has_child = true;
294 }
295 VisitStmt::visit_stmt(self, node);
296 }
297 }
298
299 fn visit_expr_mut(&mut self, node: &mut Expr) {
300 if !self.cx.has_error() {
301 if node.any_attr(NAME) {
302 self.cx.has_child = true;
303 }
304 VisitStmt::visit_expr(self, node, has_semi:false);
305 }
306 }
307
308 fn visit_item_mut(&mut self, _: &mut Item) {
309 // Do not recurse into nested items.
310 }
311}
312
313impl VisitStmt for Dummy<'_> {
314 fn cx(&mut self) -> &mut Context {
315 self.cx
316 }
317}
318
319// =================================================================================================
320// VisitStmt
321
322trait VisitStmt: VisitMut {
323 fn cx(&mut self) -> &mut Context;
324
325 fn visit_expr(visitor: &mut Self, node: &mut Expr, has_semi: bool) {
326 let attr = node.find_remove_attr(NAME);
327
328 let res = attr.map(|attr| {
329 attr.meta.require_list().and_then(|MetaList { tokens, .. }| {
330 visitor.cx().make_child(node.to_token_stream(), tokens.clone())
331 })
332 });
333
334 visit_mut::visit_expr_mut(visitor, node);
335
336 match res {
337 Some(Err(e)) => visitor.cx().error(e),
338 Some(Ok(mut cx)) => {
339 super::expand_parent_expr(&mut cx, node, has_semi);
340 visitor.cx().join_child(cx);
341 }
342 None => {}
343 }
344 }
345
346 fn visit_stmt(visitor: &mut Self, node: &mut Stmt) {
347 let attr = match node {
348 Stmt::Expr(expr, semi) => {
349 Self::visit_expr(visitor, expr, semi.is_some());
350 return;
351 }
352 Stmt::Local(local) => local.find_remove_attr(NAME),
353 Stmt::Macro(_) => None,
354 // Do not recurse into nested items.
355 Stmt::Item(_) => return,
356 };
357
358 let res = attr.map(|attr| {
359 let args = match attr.meta {
360 syn::Meta::Path(_) => TokenStream::new(),
361 syn::Meta::List(list) => list.tokens,
362 syn::Meta::NameValue(nv) => bail!(nv.eq_token, "expected list"),
363 };
364 visitor.cx().make_child(node.to_token_stream(), args)
365 });
366
367 visit_mut::visit_stmt_mut(visitor, node);
368
369 match res {
370 Some(Err(e)) => visitor.cx().error(e),
371 Some(Ok(mut cx)) => {
372 super::expand_parent_stmt(&mut cx, node);
373 visitor.cx().join_child(cx);
374 }
375 None => {}
376 }
377 }
378}
379
380// =================================================================================================
381// FindNested
382
383/// Find `#[nested]` attribute.
384pub(super) fn find_nested(node: &mut impl Node) -> bool {
385 struct FindNested {
386 has: bool,
387 }
388
389 impl VisitMut for FindNested {
390 fn visit_expr_mut(&mut self, node: &mut Expr) {
391 if !node.any_attr(NAME) {
392 if node.any_empty_attr(NESTED) {
393 self.has = true;
394 } else {
395 visit_mut::visit_expr_mut(self, node);
396 }
397 }
398 }
399
400 fn visit_arm_mut(&mut self, node: &mut Arm) {
401 if node.any_empty_attr(NESTED) {
402 self.has = true;
403 } else {
404 visit_mut::visit_arm_mut(self, node);
405 }
406 }
407
408 fn visit_local_mut(&mut self, node: &mut Local) {
409 if !node.any_attr(NAME) {
410 if node.any_empty_attr(NESTED) {
411 self.has = true;
412 } else {
413 visit_mut::visit_local_mut(self, node);
414 }
415 }
416 }
417
418 fn visit_item_mut(&mut self, _: &mut Item) {
419 // Do not recurse into nested items.
420 }
421 }
422
423 let mut visitor = FindNested { has: false };
424 node.visited(&mut visitor);
425 visitor.has
426}
427
428// =================================================================================================
429// FnVisitor
430
431#[derive(Default)]
432pub(super) struct FnCount {
433 pub(super) try_: usize,
434 pub(super) return_: usize,
435}
436
437pub(super) fn visit_fn(cx: &Context, node: &mut impl Node) -> FnCount {
438 struct FnVisitor<'a> {
439 cx: &'a Context,
440 scope: Scope,
441 count: FnCount,
442 }
443
444 impl VisitMut for FnVisitor<'_> {
445 fn visit_expr_mut(&mut self, node: &mut Expr) {
446 let tmp = self.scope;
447
448 self.scope.check_expr(node);
449
450 if !self.scope.closure && !node.any_empty_attr(NEVER) {
451 match node {
452 Expr::Try(ExprTry { expr, .. }) => {
453 // Skip if `<expr>` is a marker macro.
454 if !self.cx.is_marker_expr(expr) {
455 self.count.try_ += 1;
456 }
457 }
458 Expr::Return(ExprReturn { expr, .. }) => {
459 // Skip if `<expr>` is a marker macro.
460 if expr.as_ref().map_or(true, |expr| !self.cx.is_marker_expr(expr)) {
461 self.count.return_ += 1;
462 }
463 }
464 _ => {}
465 }
466 }
467
468 if node.any_attr(NAME) {
469 self.scope.foreign = true;
470 }
471
472 visit_mut::visit_expr_mut(self, node);
473
474 self.scope = tmp;
475 }
476
477 fn visit_stmt_mut(&mut self, node: &mut Stmt) {
478 let tmp = self.scope;
479
480 if node.any_attr(NAME) {
481 self.scope.foreign = true;
482 }
483
484 visit_mut::visit_stmt_mut(self, node);
485
486 self.scope = tmp;
487 }
488
489 fn visit_item_mut(&mut self, _: &mut Item) {
490 // Do not recurse into nested items.
491 }
492 }
493
494 let mut visitor = FnVisitor { cx, scope: Scope::default(), count: FnCount::default() };
495 node.visited(&mut visitor);
496 visitor.count
497}
498