1 | // SPDX-License-Identifier: Apache-2.0 OR MIT |
2 | |
3 | use proc_macro2::TokenStream; |
4 | use quote::ToTokens; |
5 | use syn::{ |
6 | parse_quote, token, |
7 | visit_mut::{self, VisitMut}, |
8 | Arm, Attribute, Expr, ExprMacro, ExprMatch, ExprReturn, ExprTry, Item, Local, LocalInit, |
9 | MetaList, Stmt, Token, |
10 | }; |
11 | |
12 | use super::{Context, VisitMode, DEFAULT_MARKER, NAME, NESTED, NEVER}; |
13 | use crate::utils::{replace_expr, Attrs, Node}; |
14 | |
/// Flags describing the syntactic context the visitor is currently inside.
#[derive(Clone, Copy, Default)]
struct Scope {
    /// `true` while inside a closure.
    closure: bool,
    /// `true` while inside a `try` block.
    try_block: bool,
    /// `true` while inside a node that carries another `auto_enum` attribute.
    foreign: bool,
}
24 | |
25 | impl Scope { |
26 | // check this scope is in closures or try blocks. |
27 | fn check_expr(&mut self, expr: &Expr) { |
28 | match expr { |
29 | Expr::Closure(_) => self.closure = true, |
30 | // `?` operator in try blocks are not supported. |
31 | Expr::TryBlock(_) => self.try_block = true, |
32 | _ => {} |
33 | } |
34 | } |
35 | } |
36 | |
37 | // ================================================================================================= |
38 | // default visitor |
39 | |
40 | pub(super) struct Visitor<'a> { |
41 | cx: &'a mut Context, |
42 | scope: Scope, |
43 | } |
44 | |
45 | impl<'a> Visitor<'a> { |
46 | pub(super) fn new(cx: &'a mut Context) -> Self { |
47 | Self { cx, scope: Scope::default() } |
48 | } |
49 | |
50 | fn find_remove_attrs(&mut self, attrs: &mut impl Attrs) { |
51 | if !self.scope.foreign { |
52 | if let Some(attr) = attrs.find_remove_attr(NEVER) { |
53 | if let Err(e) = attr.meta.require_path_only() { |
54 | self.cx.error(e); |
55 | } |
56 | } |
57 | |
58 | // The old annotation `#[rec]` is replaced with `#[nested]`. |
59 | if let Some(old) = attrs.find_remove_attr("rec" ) { |
60 | self.cx.error(format_err!( |
61 | old, |
62 | "#[rec] has been removed and replaced with #[ {}]" , |
63 | NESTED |
64 | )); |
65 | } |
66 | } |
67 | } |
68 | |
69 | /// `return` in functions or closures |
70 | fn visit_return(&mut self, node: &mut Expr, count: usize) { |
71 | debug_assert!(self.cx.visit_mode == VisitMode::Return(count)); |
72 | |
73 | if !self.scope.closure && !node.any_empty_attr(NEVER) { |
74 | // Desugar `return <expr>` into `return Enum::VariantN(<expr>)`. |
75 | if let Expr::Return(ExprReturn { expr, .. }) = node { |
76 | // Skip if `<expr>` is a marker macro. |
77 | if expr.as_ref().map_or(true, |expr| !self.cx.is_marker_expr(expr)) { |
78 | self.cx.replace_boxed_expr(expr); |
79 | } |
80 | } |
81 | } |
82 | } |
83 | |
84 | /// `?` operator in functions or closures |
85 | fn visit_try(&mut self, node: &mut Expr) { |
86 | debug_assert!(self.cx.visit_mode == VisitMode::Try); |
87 | |
88 | if !self.scope.try_block && !self.scope.closure && !node.any_empty_attr(NEVER) { |
89 | match &node { |
90 | // https://github.com/rust-lang/rust/blob/1.35.0/src/librustc/hir/lowering.rs#L4578-L4682 |
91 | |
92 | // Desugar `<expr>?` |
93 | // into: |
94 | // |
95 | // match <expr> { |
96 | // Ok(val) => val, |
97 | // Err(err) => return Err(Enum::VariantN(err)), |
98 | // } |
99 | // |
100 | // Skip if `<expr>` is a marker macro. |
101 | Expr::Try(ExprTry { expr, .. }) if !self.cx.is_marker_expr(expr) => { |
102 | replace_expr(node, |expr| { |
103 | let ExprTry { attrs, expr, .. } = |
104 | if let Expr::Try(expr) = expr { expr } else { unreachable!() }; |
105 | |
106 | let err = self.cx.next_expr(parse_quote!(err)); |
107 | let arms = vec![ |
108 | parse_quote! { |
109 | ::core::result::Result::Ok(val) => val, |
110 | }, |
111 | parse_quote! { |
112 | ::core::result::Result::Err(err) => { |
113 | return ::core::result::Result::Err(#err); |
114 | } |
115 | }, |
116 | ]; |
117 | |
118 | Expr::Match(ExprMatch { |
119 | attrs, |
120 | match_token: <Token![match]>::default(), |
121 | expr, |
122 | brace_token: token::Brace::default(), |
123 | arms, |
124 | }) |
125 | }); |
126 | } |
127 | _ => {} |
128 | } |
129 | } |
130 | } |
131 | |
132 | /// `#[nested]` |
133 | fn visit_nested(&mut self, node: &mut Expr, attr: &Attribute) { |
134 | debug_assert!(!self.scope.foreign); |
135 | |
136 | if let Err(e) = attr.meta.require_path_only() { |
137 | self.cx.error(e); |
138 | } else { |
139 | super::expr::child_expr(self.cx, node); |
140 | } |
141 | } |
142 | |
143 | /// Expression level marker (`marker!` macro) |
144 | fn visit_marker_macro(&mut self, node: &mut Expr) { |
145 | debug_assert!(!self.scope.foreign || self.cx.current_marker != DEFAULT_MARKER); |
146 | |
147 | match node { |
148 | // Desugar `marker!(<expr>)` into `Enum::VariantN(<expr>)`. |
149 | // Skip if `marker!` is not a marker macro. |
150 | Expr::Macro(ExprMacro { mac, .. }) if self.cx.is_marker_macro_exact(mac) => { |
151 | replace_expr(node, |expr| { |
152 | let expr = if let Expr::Macro(expr) = expr { expr } else { unreachable!() }; |
153 | let args = syn::parse2(expr.mac.tokens).unwrap_or_else(|e| { |
154 | self.cx.error(e); |
155 | // Generate an expression to fill in where the error occurred during the visit. |
156 | // These will eventually need to be replaced with the original error message. |
157 | parse_quote!(compile_error!( |
158 | "#[auto_enum] failed to generate error message" |
159 | )) |
160 | }); |
161 | |
162 | if self.cx.has_error() { |
163 | args |
164 | } else { |
165 | self.cx.next_expr_with_attrs(expr.attrs, args) |
166 | } |
167 | }); |
168 | } |
169 | _ => {} |
170 | } |
171 | } |
172 | |
173 | fn visit_expr(&mut self, node: &mut Expr, has_semi: bool) { |
174 | debug_assert!(!self.cx.has_error()); |
175 | |
176 | let tmp = self.scope; |
177 | |
178 | if node.any_attr(NAME) { |
179 | self.scope.foreign = true; |
180 | // Record whether other `auto_enum` attribute exists. |
181 | self.cx.has_child = true; |
182 | } |
183 | self.scope.check_expr(node); |
184 | |
185 | match self.cx.visit_mode { |
186 | VisitMode::Return(count) => self.visit_return(node, count), |
187 | VisitMode::Try => self.visit_try(node), |
188 | VisitMode::Default => {} |
189 | } |
190 | |
191 | if !self.scope.foreign { |
192 | if let Some(attr) = node.find_remove_attr(NESTED) { |
193 | self.visit_nested(node, &attr); |
194 | } |
195 | } |
196 | |
197 | VisitStmt::visit_expr(self, node, has_semi); |
198 | |
199 | if !self.scope.foreign || self.cx.current_marker != DEFAULT_MARKER { |
200 | self.visit_marker_macro(node); |
201 | self.find_remove_attrs(node); |
202 | } |
203 | |
204 | self.scope = tmp; |
205 | } |
206 | } |
207 | |
208 | impl VisitMut for Visitor<'_> { |
209 | fn visit_expr_mut(&mut self, node: &mut Expr) { |
210 | if !self.cx.has_error() { |
211 | self.visit_expr(node, false); |
212 | } |
213 | } |
214 | |
215 | fn visit_arm_mut(&mut self, node: &mut Arm) { |
216 | if !self.cx.has_error() { |
217 | if !self.scope.foreign { |
218 | if let Some(attr) = node.find_remove_attr(NESTED) { |
219 | self.visit_nested(&mut node.body, &attr); |
220 | } |
221 | } |
222 | |
223 | visit_mut::visit_arm_mut(self, node); |
224 | |
225 | self.find_remove_attrs(node); |
226 | } |
227 | } |
228 | |
229 | fn visit_local_mut(&mut self, node: &mut Local) { |
230 | if !self.cx.has_error() { |
231 | if !self.scope.foreign { |
232 | if let Some(attr) = node.find_remove_attr(NESTED) { |
233 | if let Some(LocalInit { expr, .. }) = &mut node.init { |
234 | self.visit_nested(expr, &attr); |
235 | } |
236 | } |
237 | } |
238 | |
239 | visit_mut::visit_local_mut(self, node); |
240 | |
241 | self.find_remove_attrs(node); |
242 | } |
243 | } |
244 | |
245 | fn visit_stmt_mut(&mut self, node: &mut Stmt) { |
246 | if !self.cx.has_error() { |
247 | if let Stmt::Expr(expr, semi) = node { |
248 | self.visit_expr(expr, semi.is_some()); |
249 | } else { |
250 | let tmp = self.scope; |
251 | |
252 | if node.any_attr(NAME) { |
253 | self.scope.foreign = true; |
254 | // Record whether other `auto_enum` attribute exists. |
255 | self.cx.has_child = true; |
256 | } |
257 | |
258 | VisitStmt::visit_stmt(self, node); |
259 | |
260 | self.scope = tmp; |
261 | } |
262 | } |
263 | } |
264 | |
265 | fn visit_item_mut(&mut self, _: &mut Item) { |
266 | // Do not recurse into nested items. |
267 | } |
268 | } |
269 | |
270 | impl VisitStmt for Visitor<'_> { |
271 | fn cx(&mut self) -> &mut Context { |
272 | self.cx |
273 | } |
274 | } |
275 | |
276 | // ================================================================================================= |
277 | // dummy visitor |
278 | |
279 | pub(super) struct Dummy<'a> { |
280 | cx: &'a mut Context, |
281 | } |
282 | |
283 | impl<'a> Dummy<'a> { |
284 | pub(super) fn new(cx: &'a mut Context) -> Self { |
285 | Self { cx } |
286 | } |
287 | } |
288 | |
289 | impl VisitMut for Dummy<'_> { |
290 | fn visit_stmt_mut(&mut self, node: &mut Stmt) { |
291 | if !self.cx.has_error() { |
292 | if node.any_attr(NAME) { |
293 | self.cx.has_child = true; |
294 | } |
295 | VisitStmt::visit_stmt(self, node); |
296 | } |
297 | } |
298 | |
299 | fn visit_expr_mut(&mut self, node: &mut Expr) { |
300 | if !self.cx.has_error() { |
301 | if node.any_attr(NAME) { |
302 | self.cx.has_child = true; |
303 | } |
304 | VisitStmt::visit_expr(self, node, has_semi:false); |
305 | } |
306 | } |
307 | |
308 | fn visit_item_mut(&mut self, _: &mut Item) { |
309 | // Do not recurse into nested items. |
310 | } |
311 | } |
312 | |
313 | impl VisitStmt for Dummy<'_> { |
314 | fn cx(&mut self) -> &mut Context { |
315 | self.cx |
316 | } |
317 | } |
318 | |
319 | // ================================================================================================= |
320 | // VisitStmt |
321 | |
322 | trait VisitStmt: VisitMut { |
323 | fn cx(&mut self) -> &mut Context; |
324 | |
325 | fn visit_expr(visitor: &mut Self, node: &mut Expr, has_semi: bool) { |
326 | let attr = node.find_remove_attr(NAME); |
327 | |
328 | let res = attr.map(|attr| { |
329 | attr.meta.require_list().and_then(|MetaList { tokens, .. }| { |
330 | visitor.cx().make_child(node.to_token_stream(), tokens.clone()) |
331 | }) |
332 | }); |
333 | |
334 | visit_mut::visit_expr_mut(visitor, node); |
335 | |
336 | match res { |
337 | Some(Err(e)) => visitor.cx().error(e), |
338 | Some(Ok(mut cx)) => { |
339 | super::expand_parent_expr(&mut cx, node, has_semi); |
340 | visitor.cx().join_child(cx); |
341 | } |
342 | None => {} |
343 | } |
344 | } |
345 | |
346 | fn visit_stmt(visitor: &mut Self, node: &mut Stmt) { |
347 | let attr = match node { |
348 | Stmt::Expr(expr, semi) => { |
349 | Self::visit_expr(visitor, expr, semi.is_some()); |
350 | return; |
351 | } |
352 | Stmt::Local(local) => local.find_remove_attr(NAME), |
353 | Stmt::Macro(_) => None, |
354 | // Do not recurse into nested items. |
355 | Stmt::Item(_) => return, |
356 | }; |
357 | |
358 | let res = attr.map(|attr| { |
359 | let args = match attr.meta { |
360 | syn::Meta::Path(_) => TokenStream::new(), |
361 | syn::Meta::List(list) => list.tokens, |
362 | syn::Meta::NameValue(nv) => bail!(nv.eq_token, "expected list" ), |
363 | }; |
364 | visitor.cx().make_child(node.to_token_stream(), args) |
365 | }); |
366 | |
367 | visit_mut::visit_stmt_mut(visitor, node); |
368 | |
369 | match res { |
370 | Some(Err(e)) => visitor.cx().error(e), |
371 | Some(Ok(mut cx)) => { |
372 | super::expand_parent_stmt(&mut cx, node); |
373 | visitor.cx().join_child(cx); |
374 | } |
375 | None => {} |
376 | } |
377 | } |
378 | } |
379 | |
380 | // ================================================================================================= |
381 | // FindNested |
382 | |
383 | /// Find `#[nested]` attribute. |
384 | pub(super) fn find_nested(node: &mut impl Node) -> bool { |
385 | struct FindNested { |
386 | has: bool, |
387 | } |
388 | |
389 | impl VisitMut for FindNested { |
390 | fn visit_expr_mut(&mut self, node: &mut Expr) { |
391 | if !node.any_attr(NAME) { |
392 | if node.any_empty_attr(NESTED) { |
393 | self.has = true; |
394 | } else { |
395 | visit_mut::visit_expr_mut(self, node); |
396 | } |
397 | } |
398 | } |
399 | |
400 | fn visit_arm_mut(&mut self, node: &mut Arm) { |
401 | if node.any_empty_attr(NESTED) { |
402 | self.has = true; |
403 | } else { |
404 | visit_mut::visit_arm_mut(self, node); |
405 | } |
406 | } |
407 | |
408 | fn visit_local_mut(&mut self, node: &mut Local) { |
409 | if !node.any_attr(NAME) { |
410 | if node.any_empty_attr(NESTED) { |
411 | self.has = true; |
412 | } else { |
413 | visit_mut::visit_local_mut(self, node); |
414 | } |
415 | } |
416 | } |
417 | |
418 | fn visit_item_mut(&mut self, _: &mut Item) { |
419 | // Do not recurse into nested items. |
420 | } |
421 | } |
422 | |
423 | let mut visitor = FindNested { has: false }; |
424 | node.visited(&mut visitor); |
425 | visitor.has |
426 | } |
427 | |
428 | // ================================================================================================= |
429 | // FnVisitor |
430 | |
431 | #[derive (Default)] |
432 | pub(super) struct FnCount { |
433 | pub(super) try_: usize, |
434 | pub(super) return_: usize, |
435 | } |
436 | |
437 | pub(super) fn visit_fn(cx: &Context, node: &mut impl Node) -> FnCount { |
438 | struct FnVisitor<'a> { |
439 | cx: &'a Context, |
440 | scope: Scope, |
441 | count: FnCount, |
442 | } |
443 | |
444 | impl VisitMut for FnVisitor<'_> { |
445 | fn visit_expr_mut(&mut self, node: &mut Expr) { |
446 | let tmp = self.scope; |
447 | |
448 | self.scope.check_expr(node); |
449 | |
450 | if !self.scope.closure && !node.any_empty_attr(NEVER) { |
451 | match node { |
452 | Expr::Try(ExprTry { expr, .. }) => { |
453 | // Skip if `<expr>` is a marker macro. |
454 | if !self.cx.is_marker_expr(expr) { |
455 | self.count.try_ += 1; |
456 | } |
457 | } |
458 | Expr::Return(ExprReturn { expr, .. }) => { |
459 | // Skip if `<expr>` is a marker macro. |
460 | if expr.as_ref().map_or(true, |expr| !self.cx.is_marker_expr(expr)) { |
461 | self.count.return_ += 1; |
462 | } |
463 | } |
464 | _ => {} |
465 | } |
466 | } |
467 | |
468 | if node.any_attr(NAME) { |
469 | self.scope.foreign = true; |
470 | } |
471 | |
472 | visit_mut::visit_expr_mut(self, node); |
473 | |
474 | self.scope = tmp; |
475 | } |
476 | |
477 | fn visit_stmt_mut(&mut self, node: &mut Stmt) { |
478 | let tmp = self.scope; |
479 | |
480 | if node.any_attr(NAME) { |
481 | self.scope.foreign = true; |
482 | } |
483 | |
484 | visit_mut::visit_stmt_mut(self, node); |
485 | |
486 | self.scope = tmp; |
487 | } |
488 | |
489 | fn visit_item_mut(&mut self, _: &mut Item) { |
490 | // Do not recurse into nested items. |
491 | } |
492 | } |
493 | |
494 | let mut visitor = FnVisitor { cx, scope: Scope::default(), count: FnCount::default() }; |
495 | node.visited(&mut visitor); |
496 | visitor.count |
497 | } |
498 | |