1// SPDX-License-Identifier: Apache-2.0 OR MIT
2
3mod context;
4mod expr;
5#[cfg(feature = "type_analysis")]
6mod type_analysis;
7mod visitor;
8
9use proc_macro2::TokenStream;
10use quote::ToTokens;
11#[cfg(feature = "type_analysis")]
12use syn::Pat;
13use syn::{
14 AngleBracketedGenericArguments, Error, Expr, ExprClosure, GenericArgument, Item, ItemEnum,
15 ItemFn, Local, LocalInit, PathArguments, ReturnType, Stmt, Type, TypePath,
16};
17
18use self::{
19 context::{Context, VisitLastMode, VisitMode, DEFAULT_MARKER},
20 expr::child_expr,
21};
22use crate::utils::{block, expr_block, path_eq, replace_expr};
23
/// Name of this attribute macro (the `auto_enum` in `#[auto_enum]`).
const NAME: &str = "auto_enum";
/// Annotation name that requests recursive parsing of the annotated expression.
const NESTED: &str = "nested";
/// Annotation name that marks a branch to be skipped.
const NEVER: &str = "never";
30
31pub(crate) fn attribute(args: TokenStream, input: TokenStream) -> TokenStream {
32 let mut cx = match Context::root(input.clone(), args) {
33 Ok(cx) => cx,
34 Err(e) => return e.to_compile_error(),
35 };
36
37 match syn::parse2::<Stmt>(input.clone()) {
38 Ok(mut stmt) => {
39 expand_parent_stmt(&mut cx, &mut stmt);
40 cx.check().map(|()| stmt.into_token_stream())
41 }
42 Err(e) => match syn::parse2::<Expr>(input) {
43 Err(_e) => {
44 cx.error(e);
45 cx.error(format_err!(
46 cx.span,
47 "may only be used on expression, statement, or function"
48 ));
49 cx.check().map(|()| unreachable!())
50 }
51 Ok(mut expr) => {
52 expand_parent_expr(&mut cx, &mut expr, false);
53 cx.check().map(|()| expr.into_token_stream())
54 }
55 },
56 }
57 .unwrap_or_else(Error::into_compile_error)
58}
59
60fn expand_expr(cx: &mut Context, expr: &mut Expr) {
61 let expr = match expr {
62 Expr::Closure(ExprClosure { body, .. }) if cx.visit_last() => {
63 let count = visitor::visit_fn(cx, &mut **body);
64 if count.try_ >= 2 {
65 cx.visit_mode = VisitMode::Try;
66 } else {
67 cx.visit_mode = VisitMode::Return(count.return_);
68 }
69 &mut **body
70 }
71 _ => expr,
72 };
73
74 child_expr(cx, expr);
75
76 #[cfg(feature = "type_analysis")]
77 {
78 if let VisitMode::Return(count) = cx.visit_mode {
79 if cx.args.is_empty() && cx.variant_is_empty() && count < 2 {
80 cx.dummy(expr);
81 return;
82 }
83 }
84 }
85
86 cx.visitor(expr);
87}
88
89fn build_expr(expr: &mut Expr, item: ItemEnum) {
90 replace_expr(this:expr, |expr: Expr| {
91 expr_block(block(stmts:vec![Stmt::Item(item.into()), Stmt::Expr(expr, None)]))
92 });
93}
94
95// =================================================================================================
96// Expand statement or expression in which `#[auto_enum]` was directly used.
97
98fn expand_parent_stmt(cx: &mut Context, stmt: &mut Stmt) {
99 match stmt {
100 Stmt::Expr(expr: &mut Expr, semi: &mut Option) => expand_parent_expr(cx, expr, has_semi:semi.is_some()),
101 Stmt::Local(local: &mut Local) => expand_parent_local(cx, local),
102 Stmt::Item(Item::Fn(item: &mut ItemFn)) => expand_parent_item_fn(cx, item),
103 Stmt::Item(item: &mut Item) => {
104 cx.error(message:format_err!(item, "may only be used on expression, statement, or function"));
105 }
106 Stmt::Macro(_) => {}
107 }
108}
109
110fn expand_parent_expr(cx: &mut Context, expr: &mut Expr, has_semi: bool) {
111 if has_semi {
112 cx.visit_last_mode = VisitLastMode::Never;
113 }
114
115 if cx.is_dummy() {
116 cx.dummy(node:expr);
117 return;
118 }
119
120 expand_expr(cx, expr);
121
122 cx.build(|item: ItemEnum| build_expr(expr, item));
123}
124
125fn expand_parent_local(cx: &mut Context, local: &mut Local) {
126 #[cfg(feature = "type_analysis")]
127 {
128 if let Pat::Type(pat) = &mut local.pat {
129 if cx.collect_impl_trait(&mut pat.ty) {
130 local.pat = (*pat.pat).clone();
131 }
132 }
133 }
134
135 if cx.is_dummy() {
136 cx.dummy(local);
137 return;
138 }
139
140 let expr = if let Some(LocalInit { expr, .. }) = &mut local.init {
141 &mut **expr
142 } else {
143 cx.error(format_err!(
144 local,
145 "the `#[auto_enum]` attribute is not supported uninitialized let statement"
146 ));
147 return;
148 };
149
150 expand_expr(cx, expr);
151
152 cx.build(|item| build_expr(expr, item));
153}
154
155fn expand_parent_item_fn(cx: &mut Context, item: &mut ItemFn) {
156 let ItemFn { sig, block, .. } = item;
157 if let ReturnType::Type(_, ty) = &mut sig.output {
158 match &**ty {
159 // `return`
160 Type::ImplTrait(_) if cx.visit_last_mode != VisitLastMode::Never => {
161 let count = visitor::visit_fn(cx, &mut **block);
162 cx.visit_mode = VisitMode::Return(count.return_);
163 }
164
165 // `?` operator
166 Type::Path(TypePath { qself: None, path })
167 if cx.visit_last_mode != VisitLastMode::Never =>
168 {
169 let ty = path.segments.last().unwrap();
170 match &ty.arguments {
171 // `Result<T, impl Trait>`
172 PathArguments::AngleBracketed(AngleBracketedGenericArguments {
173 colon2_token: None,
174 args,
175 ..
176 }) if args.len() == 2
177 && path_eq(path, &["std", "core"], &["result", "Result"]) =>
178 {
179 if let (
180 GenericArgument::Type(_),
181 GenericArgument::Type(Type::ImplTrait(_)),
182 ) = (&args[0], &args[1])
183 {
184 let count = visitor::visit_fn(cx, &mut **block);
185 if count.try_ >= 2 {
186 cx.visit_mode = VisitMode::Try;
187 }
188 }
189 }
190 _ => {}
191 }
192 }
193
194 _ => {}
195 }
196
197 #[cfg(feature = "type_analysis")]
198 cx.collect_impl_trait(&mut *ty);
199 }
200
201 if cx.is_dummy() {
202 cx.dummy(item);
203 return;
204 }
205
206 match item.block.stmts.last_mut() {
207 Some(Stmt::Expr(expr, None)) => child_expr(cx, expr),
208 Some(_) => {}
209 None => cx.error(format_err!(
210 item.block,
211 "the `#[auto_enum]` attribute is not supported empty functions"
212 )),
213 }
214
215 #[cfg(feature = "type_analysis")]
216 {
217 if let VisitMode::Return(count) = cx.visit_mode {
218 if cx.args.is_empty() && cx.variant_is_empty() && count < 2 {
219 cx.dummy(item);
220 return;
221 }
222 }
223 }
224
225 cx.visitor(item);
226
227 cx.build(|i| item.block.stmts.insert(0, Stmt::Item(i.into())));
228}
229