1use proc_macro2::TokenStream;
2use quote::quote;
3use syn::{Data, DeriveInput, Fields};
4
5use crate::helpers::{
6 non_enum_error, occurrence_error, HasStrumVariantProperties, HasTypeProperties,
7};
8
9pub fn from_string_inner(ast: &DeriveInput) -> syn::Result<TokenStream> {
10 let name = &ast.ident;
11 let (impl_generics, ty_generics, where_clause) = ast.generics.split_for_impl();
12 let variants = match &ast.data {
13 Data::Enum(v) => &v.variants,
14 _ => return Err(non_enum_error()),
15 };
16
17 let type_properties = ast.get_type_properties()?;
18 let strum_module_path = type_properties.crate_module_path();
19
20 let mut default_kw = None;
21 let mut default =
22 quote! { ::core::result::Result::Err(#strum_module_path::ParseError::VariantNotFound) };
23
24 let mut phf_exact_match_arms = Vec::new();
25 let mut standard_match_arms = Vec::new();
26 for variant in variants {
27 let ident = &variant.ident;
28 let variant_properties = variant.get_variant_properties()?;
29
30 if variant_properties.disabled.is_some() {
31 continue;
32 }
33
34 if let Some(kw) = variant_properties.default {
35 if let Some(fst_kw) = default_kw {
36 return Err(occurrence_error(fst_kw, kw, "default"));
37 }
38
39 match &variant.fields {
40 Fields::Unnamed(fields) if fields.unnamed.len() == 1 => {}
41 _ => {
42 return Err(syn::Error::new_spanned(
43 variant,
44 "Default only works on newtype structs with a single String field",
45 ))
46 }
47 }
48
49 default_kw = Some(kw);
50 default = quote! {
51 ::core::result::Result::Ok(#name::#ident(s.into()))
52 };
53 continue;
54 }
55
56 let params = match &variant.fields {
57 Fields::Unit => quote! {},
58 Fields::Unnamed(fields) => {
59 let defaults =
60 ::core::iter::repeat(quote!(Default::default())).take(fields.unnamed.len());
61 quote! { (#(#defaults),*) }
62 }
63 Fields::Named(fields) => {
64 let fields = fields
65 .named
66 .iter()
67 .map(|field| field.ident.as_ref().unwrap());
68 quote! { {#(#fields: Default::default()),*} }
69 }
70 };
71
72 let is_ascii_case_insensitive = variant_properties
73 .ascii_case_insensitive
74 .unwrap_or(type_properties.ascii_case_insensitive);
75
76 // If we don't have any custom variants, add the default serialized name.
77 for serialization in variant_properties.get_serializations(type_properties.case_style) {
78 if type_properties.use_phf {
79 phf_exact_match_arms.push(quote! { #serialization => #name::#ident #params, });
80
81 if is_ascii_case_insensitive {
82 // Store the lowercase and UPPERCASE variants in the phf map to capture
83 let ser_string = serialization.value();
84
85 let lower =
86 syn::LitStr::new(&ser_string.to_ascii_lowercase(), serialization.span());
87 let upper =
88 syn::LitStr::new(&ser_string.to_ascii_uppercase(), serialization.span());
89 phf_exact_match_arms.push(quote! { #lower => #name::#ident #params, });
90 phf_exact_match_arms.push(quote! { #upper => #name::#ident #params, });
91 standard_match_arms.push(quote! { s if s.eq_ignore_ascii_case(#serialization) => #name::#ident #params, });
92 }
93 } else {
94 standard_match_arms.push(if !is_ascii_case_insensitive {
95 quote! { #serialization => #name::#ident #params, }
96 } else {
97 quote! { s if s.eq_ignore_ascii_case(#serialization) => #name::#ident #params, }
98 });
99 }
100 }
101 }
102
103 let phf_body = if phf_exact_match_arms.is_empty() {
104 quote!()
105 } else {
106 quote! {
107 use #strum_module_path::_private_phf_reexport_for_macro_if_phf_feature as phf;
108 static PHF: phf::Map<&'static str, #name> = phf::phf_map! {
109 #(#phf_exact_match_arms)*
110 };
111 if let Some(value) = PHF.get(s).cloned() {
112 return ::core::result::Result::Ok(value);
113 }
114 }
115 };
116 let standard_match_body = if standard_match_arms.is_empty() {
117 default
118 } else {
119 quote! {
120 ::core::result::Result::Ok(match s {
121 #(#standard_match_arms)*
122 _ => return #default,
123 })
124 }
125 };
126
127 let from_str = quote! {
128 #[allow(clippy::use_self)]
129 impl #impl_generics ::core::str::FromStr for #name #ty_generics #where_clause {
130 type Err = #strum_module_path::ParseError;
131 fn from_str(s: &str) -> ::core::result::Result< #name #ty_generics , <Self as ::core::str::FromStr>::Err> {
132 #phf_body
133 #standard_match_body
134 }
135 }
136 };
137
138 let try_from_str = try_from_str(
139 name,
140 &impl_generics,
141 &ty_generics,
142 where_clause,
143 &strum_module_path,
144 );
145
146 Ok(quote! {
147 #from_str
148 #try_from_str
149 })
150}
151
152#[rustversion::before(1.34)]
153fn try_from_str(
154 _name: &proc_macro2::Ident,
155 _impl_generics: &syn::ImplGenerics,
156 _ty_generics: &syn::TypeGenerics,
157 _where_clause: Option<&syn::WhereClause>,
158 _strum_module_path: &syn::Path,
159) -> TokenStream {
160 Default::default()
161}
162
163#[rustversion::since(1.34)]
164fn try_from_str(
165 name: &proc_macro2::Ident,
166 impl_generics: &syn::ImplGenerics,
167 ty_generics: &syn::TypeGenerics,
168 where_clause: Option<&syn::WhereClause>,
169 strum_module_path: &syn::Path,
170) -> TokenStream {
171 quote! {
172 #[allow(clippy::use_self)]
173 impl #impl_generics ::core::convert::TryFrom<&str> for #name #ty_generics #where_clause {
174 type Error = #strum_module_path::ParseError;
175 fn try_from(s: &str) -> ::core::result::Result< #name #ty_generics , <Self as ::core::convert::TryFrom<&str>>::Error> {
176 ::core::str::FromStr::from_str(s)
177 }
178 }
179 }
180}
181