1 | use proc_macro2::{Span, TokenStream}; |
2 | use quote::{format_ident, quote}; |
3 | use syn::{spanned::Spanned, Data, DeriveInput, Fields}; |
4 | |
5 | use crate::helpers::{non_enum_error, snakify, HasStrumVariantProperties}; |
6 | |
7 | pub fn enum_table_inner(ast: &DeriveInput) -> syn::Result<TokenStream> { |
8 | let name = &ast.ident; |
9 | let gen = &ast.generics; |
10 | let vis = &ast.vis; |
11 | let mut doc_comment = format!("A map over the variants of ` {}`" , name); |
12 | |
13 | if gen.lifetimes().count() > 0 { |
14 | return Err(syn::Error::new( |
15 | Span::call_site(), |
16 | "`EnumTable` doesn't support enums with lifetimes." , |
17 | )); |
18 | } |
19 | |
20 | let variants = match &ast.data { |
21 | Data::Enum(v) => &v.variants, |
22 | _ => return Err(non_enum_error()), |
23 | }; |
24 | |
25 | let table_name = format_ident!(" {}Table" , name); |
26 | |
27 | // the identifiers of each variant, in PascalCase |
28 | let mut pascal_idents = Vec::new(); |
29 | // the identifiers of each struct field, in snake_case |
30 | let mut snake_idents = Vec::new(); |
31 | // match arms in the form `MyEnumTable::Variant => &self.variant,` |
32 | let mut get_matches = Vec::new(); |
33 | // match arms in the form `MyEnumTable::Variant => &mut self.variant,` |
34 | let mut get_matches_mut = Vec::new(); |
35 | // match arms in the form `MyEnumTable::Variant => self.variant = new_value` |
36 | let mut set_matches = Vec::new(); |
37 | // struct fields of the form `variant: func(MyEnum::Variant),* |
38 | let mut closure_fields = Vec::new(); |
39 | // struct fields of the form `variant: func(MyEnum::Variant, self.variant),` |
40 | let mut transform_fields = Vec::new(); |
41 | |
42 | // identifiers for disabled variants |
43 | let mut disabled_variants = Vec::new(); |
44 | // match arms for disabled variants |
45 | let mut disabled_matches = Vec::new(); |
46 | |
47 | for variant in variants { |
48 | // skip disabled variants |
49 | if variant.get_variant_properties()?.disabled.is_some() { |
50 | let disabled_ident = &variant.ident; |
51 | let panic_message = format!( |
52 | "Can't use ` {}` with ` {}` - variant is disabled for Strum features" , |
53 | disabled_ident, table_name |
54 | ); |
55 | disabled_variants.push(disabled_ident); |
56 | disabled_matches.push(quote!(#name::#disabled_ident => panic!(#panic_message),)); |
57 | continue; |
58 | } |
59 | |
60 | // Error on variants with data |
61 | if variant.fields != Fields::Unit { |
62 | return Err(syn::Error::new( |
63 | variant.fields.span(), |
64 | "`EnumTable` doesn't support enums with non-unit variants" , |
65 | )); |
66 | }; |
67 | |
68 | let pascal_case = &variant.ident; |
69 | let snake_case = format_ident!("_ {}" , snakify(&pascal_case.to_string())); |
70 | |
71 | get_matches.push(quote! {#name::#pascal_case => &self.#snake_case,}); |
72 | get_matches_mut.push(quote! {#name::#pascal_case => &mut self.#snake_case,}); |
73 | set_matches.push(quote! {#name::#pascal_case => self.#snake_case = new_value,}); |
74 | closure_fields.push(quote! {#snake_case: func(#name::#pascal_case),}); |
75 | transform_fields.push(quote! {#snake_case: func(#name::#pascal_case, &self.#snake_case),}); |
76 | pascal_idents.push(pascal_case); |
77 | snake_idents.push(snake_case); |
78 | } |
79 | |
80 | // Error on empty enums |
81 | if pascal_idents.is_empty() { |
82 | return Err(syn::Error::new( |
83 | variants.span(), |
84 | "`EnumTable` requires at least one non-disabled variant" , |
85 | )); |
86 | } |
87 | |
88 | // if the index operation can panic, add that to the documentation |
89 | if !disabled_variants.is_empty() { |
90 | doc_comment.push_str(&format!( |
91 | " \n# Panics \nIndexing ` {}` with any of the following variants will cause a panic:" , |
92 | table_name |
93 | )); |
94 | for variant in disabled_variants { |
95 | doc_comment.push_str(&format!(" \n\n- ` {}:: {}`" , name, variant)); |
96 | } |
97 | } |
98 | |
99 | let doc_new = format!( |
100 | "Create a new {} with a value for each variant of {}" , |
101 | table_name, name |
102 | ); |
103 | let doc_closure = format!( |
104 | "Create a new {} by running a function on each variant of ` {}`" , |
105 | table_name, name |
106 | ); |
107 | let doc_transform = format!("Create a new ` {}` by running a function on each variant of ` {}` and the corresponding value in the current ` {0}`" , table_name, name); |
108 | let doc_filled = format!( |
109 | "Create a new ` {}` with the same value in each field." , |
110 | table_name |
111 | ); |
112 | let doc_option_all = format!("Converts ` {}<Option<T>>` into `Option< {0}<T>>`. Returns `Some` if all fields are `Some`, otherwise returns `None`." , table_name); |
113 | let doc_result_all_ok = format!("Converts ` {}<Result<T, E>>` into `Result< {0}<T>, E>`. Returns `Ok` if all fields are `Ok`, otherwise returns `Err`." , table_name); |
114 | |
115 | Ok(quote! { |
116 | #[doc = #doc_comment] |
117 | #[allow( |
118 | missing_copy_implementations, |
119 | )] |
120 | #[derive(Debug, Clone, Default, PartialEq, Eq, Hash)] |
121 | #vis struct #table_name<T> { |
122 | #(#snake_idents: T,)* |
123 | } |
124 | |
125 | impl<T: Clone> #table_name<T> { |
126 | #[doc = #doc_filled] |
127 | #vis fn filled(value: T) -> #table_name<T> { |
128 | #table_name { |
129 | #(#snake_idents: value.clone(),)* |
130 | } |
131 | } |
132 | } |
133 | |
134 | impl<T> #table_name<T> { |
135 | #[doc = #doc_new] |
136 | #vis fn new( |
137 | #(#snake_idents: T,)* |
138 | ) -> #table_name<T> { |
139 | #table_name { |
140 | #(#snake_idents,)* |
141 | } |
142 | } |
143 | |
144 | #[doc = #doc_closure] |
145 | #vis fn from_closure<F: Fn(#name)->T>(func: F) -> #table_name<T> { |
146 | #table_name { |
147 | #(#closure_fields)* |
148 | } |
149 | } |
150 | |
151 | #[doc = #doc_transform] |
152 | #vis fn transform<U, F: Fn(#name, &T)->U>(&self, func: F) -> #table_name<U> { |
153 | #table_name { |
154 | #(#transform_fields)* |
155 | } |
156 | } |
157 | |
158 | } |
159 | |
160 | impl<T> ::core::ops::Index<#name> for #table_name<T> { |
161 | type Output = T; |
162 | |
163 | fn index(&self, idx: #name) -> &T { |
164 | match idx { |
165 | #(#get_matches)* |
166 | #(#disabled_matches)* |
167 | } |
168 | } |
169 | } |
170 | |
171 | impl<T> ::core::ops::IndexMut<#name> for #table_name<T> { |
172 | fn index_mut(&mut self, idx: #name) -> &mut T { |
173 | match idx { |
174 | #(#get_matches_mut)* |
175 | #(#disabled_matches)* |
176 | } |
177 | } |
178 | } |
179 | |
180 | impl<T> #table_name<::core::option::Option<T>> { |
181 | #[doc = #doc_option_all] |
182 | #vis fn all(self) -> ::core::option::Option<#table_name<T>> { |
183 | if let #table_name { |
184 | #(#snake_idents: ::core::option::Option::Some(#snake_idents),)* |
185 | } = self { |
186 | ::core::option::Option::Some(#table_name { |
187 | #(#snake_idents,)* |
188 | }) |
189 | } else { |
190 | ::core::option::Option::None |
191 | } |
192 | } |
193 | } |
194 | |
195 | impl<T, E> #table_name<::core::result::Result<T, E>> { |
196 | #[doc = #doc_result_all_ok] |
197 | #vis fn all_ok(self) -> ::core::result::Result<#table_name<T>, E> { |
198 | ::core::result::Result::Ok(#table_name { |
199 | #(#snake_idents: self.#snake_idents?,)* |
200 | }) |
201 | } |
202 | } |
203 | }) |
204 | } |
205 | |