1use proc_macro2::{Span, TokenStream};
2use quote::{format_ident, quote};
3use syn::{spanned::Spanned, Data, DeriveInput, Fields};
4
5use crate::helpers::{non_enum_error, snakify, HasStrumVariantProperties};
6
7pub fn enum_table_inner(ast: &DeriveInput) -> syn::Result<TokenStream> {
8 let name = &ast.ident;
9 let gen = &ast.generics;
10 let vis = &ast.vis;
11 let mut doc_comment = format!("A map over the variants of `{}`", name);
12
13 if gen.lifetimes().count() > 0 {
14 return Err(syn::Error::new(
15 Span::call_site(),
16 "`EnumTable` doesn't support enums with lifetimes.",
17 ));
18 }
19
20 let variants = match &ast.data {
21 Data::Enum(v) => &v.variants,
22 _ => return Err(non_enum_error()),
23 };
24
25 let table_name = format_ident!("{}Table", name);
26
27 // the identifiers of each variant, in PascalCase
28 let mut pascal_idents = Vec::new();
29 // the identifiers of each struct field, in snake_case
30 let mut snake_idents = Vec::new();
31 // match arms in the form `MyEnumTable::Variant => &self.variant,`
32 let mut get_matches = Vec::new();
33 // match arms in the form `MyEnumTable::Variant => &mut self.variant,`
34 let mut get_matches_mut = Vec::new();
35 // match arms in the form `MyEnumTable::Variant => self.variant = new_value`
36 let mut set_matches = Vec::new();
37 // struct fields of the form `variant: func(MyEnum::Variant),*
38 let mut closure_fields = Vec::new();
39 // struct fields of the form `variant: func(MyEnum::Variant, self.variant),`
40 let mut transform_fields = Vec::new();
41
42 // identifiers for disabled variants
43 let mut disabled_variants = Vec::new();
44 // match arms for disabled variants
45 let mut disabled_matches = Vec::new();
46
47 for variant in variants {
48 // skip disabled variants
49 if variant.get_variant_properties()?.disabled.is_some() {
50 let disabled_ident = &variant.ident;
51 let panic_message = format!(
52 "Can't use `{}` with `{}` - variant is disabled for Strum features",
53 disabled_ident, table_name
54 );
55 disabled_variants.push(disabled_ident);
56 disabled_matches.push(quote!(#name::#disabled_ident => panic!(#panic_message),));
57 continue;
58 }
59
60 // Error on variants with data
61 if variant.fields != Fields::Unit {
62 return Err(syn::Error::new(
63 variant.fields.span(),
64 "`EnumTable` doesn't support enums with non-unit variants",
65 ));
66 };
67
68 let pascal_case = &variant.ident;
69 let snake_case = format_ident!("_{}", snakify(&pascal_case.to_string()));
70
71 get_matches.push(quote! {#name::#pascal_case => &self.#snake_case,});
72 get_matches_mut.push(quote! {#name::#pascal_case => &mut self.#snake_case,});
73 set_matches.push(quote! {#name::#pascal_case => self.#snake_case = new_value,});
74 closure_fields.push(quote! {#snake_case: func(#name::#pascal_case),});
75 transform_fields.push(quote! {#snake_case: func(#name::#pascal_case, &self.#snake_case),});
76 pascal_idents.push(pascal_case);
77 snake_idents.push(snake_case);
78 }
79
80 // Error on empty enums
81 if pascal_idents.is_empty() {
82 return Err(syn::Error::new(
83 variants.span(),
84 "`EnumTable` requires at least one non-disabled variant",
85 ));
86 }
87
88 // if the index operation can panic, add that to the documentation
89 if !disabled_variants.is_empty() {
90 doc_comment.push_str(&format!(
91 "\n# Panics\nIndexing `{}` with any of the following variants will cause a panic:",
92 table_name
93 ));
94 for variant in disabled_variants {
95 doc_comment.push_str(&format!("\n\n- `{}::{}`", name, variant));
96 }
97 }
98
99 let doc_new = format!(
100 "Create a new {} with a value for each variant of {}",
101 table_name, name
102 );
103 let doc_closure = format!(
104 "Create a new {} by running a function on each variant of `{}`",
105 table_name, name
106 );
107 let doc_transform = format!("Create a new `{}` by running a function on each variant of `{}` and the corresponding value in the current `{0}`", table_name, name);
108 let doc_filled = format!(
109 "Create a new `{}` with the same value in each field.",
110 table_name
111 );
112 let doc_option_all = format!("Converts `{}<Option<T>>` into `Option<{0}<T>>`. Returns `Some` if all fields are `Some`, otherwise returns `None`.", table_name);
113 let doc_result_all_ok = format!("Converts `{}<Result<T, E>>` into `Result<{0}<T>, E>`. Returns `Ok` if all fields are `Ok`, otherwise returns `Err`.", table_name);
114
115 Ok(quote! {
116 #[doc = #doc_comment]
117 #[allow(
118 missing_copy_implementations,
119 )]
120 #[derive(Debug, Clone, Default, PartialEq, Eq, Hash)]
121 #vis struct #table_name<T> {
122 #(#snake_idents: T,)*
123 }
124
125 impl<T: Clone> #table_name<T> {
126 #[doc = #doc_filled]
127 #vis fn filled(value: T) -> #table_name<T> {
128 #table_name {
129 #(#snake_idents: value.clone(),)*
130 }
131 }
132 }
133
134 impl<T> #table_name<T> {
135 #[doc = #doc_new]
136 #vis fn new(
137 #(#snake_idents: T,)*
138 ) -> #table_name<T> {
139 #table_name {
140 #(#snake_idents,)*
141 }
142 }
143
144 #[doc = #doc_closure]
145 #vis fn from_closure<F: Fn(#name)->T>(func: F) -> #table_name<T> {
146 #table_name {
147 #(#closure_fields)*
148 }
149 }
150
151 #[doc = #doc_transform]
152 #vis fn transform<U, F: Fn(#name, &T)->U>(&self, func: F) -> #table_name<U> {
153 #table_name {
154 #(#transform_fields)*
155 }
156 }
157
158 }
159
160 impl<T> ::core::ops::Index<#name> for #table_name<T> {
161 type Output = T;
162
163 fn index(&self, idx: #name) -> &T {
164 match idx {
165 #(#get_matches)*
166 #(#disabled_matches)*
167 }
168 }
169 }
170
171 impl<T> ::core::ops::IndexMut<#name> for #table_name<T> {
172 fn index_mut(&mut self, idx: #name) -> &mut T {
173 match idx {
174 #(#get_matches_mut)*
175 #(#disabled_matches)*
176 }
177 }
178 }
179
180 impl<T> #table_name<::core::option::Option<T>> {
181 #[doc = #doc_option_all]
182 #vis fn all(self) -> ::core::option::Option<#table_name<T>> {
183 if let #table_name {
184 #(#snake_idents: ::core::option::Option::Some(#snake_idents),)*
185 } = self {
186 ::core::option::Option::Some(#table_name {
187 #(#snake_idents,)*
188 })
189 } else {
190 ::core::option::Option::None
191 }
192 }
193 }
194
195 impl<T, E> #table_name<::core::result::Result<T, E>> {
196 #[doc = #doc_result_all_ok]
197 #vis fn all_ok(self) -> ::core::result::Result<#table_name<T>, E> {
198 ::core::result::Result::Ok(#table_name {
199 #(#snake_idents: self.#snake_idents?,)*
200 })
201 }
202 }
203 })
204}
205