1use proc_macro2::{Span, TokenStream};
2use quote::{format_ident, quote};
3use syn::{Data, DeriveInput, Fields, Type};
4
5use crate::helpers::{non_enum_error, HasStrumVariantProperties, HasTypeProperties};
6
7pub fn from_repr_inner(ast: &DeriveInput) -> syn::Result<TokenStream> {
8 let name = &ast.ident;
9 let gen = &ast.generics;
10 let (impl_generics, ty_generics, where_clause) = gen.split_for_impl();
11 let vis = &ast.vis;
12
13 let mut discriminant_type: Type = syn::parse("usize".parse().unwrap()).unwrap();
14 if let Some(type_path) = ast
15 .get_type_properties()
16 .ok()
17 .and_then(|tp| tp.enum_repr)
18 .and_then(|repr_ts| syn::parse2::<Type>(repr_ts).ok())
19 {
20 if let Type::Path(path) = type_path.clone() {
21 if let Some(seg) = path.path.segments.last() {
22 for t in &[
23 "u8", "u16", "u32", "u64", "usize", "i8", "i16", "i32", "i64", "isize",
24 ] {
25 if seg.ident == t {
26 discriminant_type = type_path;
27 break;
28 }
29 }
30 }
31 }
32 }
33
34 if gen.lifetimes().count() > 0 {
35 return Err(syn::Error::new(
36 Span::call_site(),
37 "This macro doesn't support enums with lifetimes. \
38 The resulting enums would be unbounded.",
39 ));
40 }
41
42 let variants = match &ast.data {
43 Data::Enum(v) => &v.variants,
44 _ => return Err(non_enum_error()),
45 };
46
47 let mut arms = Vec::new();
48 let mut constant_defs = Vec::new();
49 let mut has_additional_data = false;
50 let mut prev_const_var_ident = None;
51 for variant in variants {
52 if variant.get_variant_properties()?.disabled.is_some() {
53 continue;
54 }
55
56 let ident = &variant.ident;
57 let params = match &variant.fields {
58 Fields::Unit => quote! {},
59 Fields::Unnamed(fields) => {
60 has_additional_data = true;
61 let defaults = ::core::iter::repeat(quote!(::core::default::Default::default()))
62 .take(fields.unnamed.len());
63 quote! { (#(#defaults),*) }
64 }
65 Fields::Named(fields) => {
66 has_additional_data = true;
67 let fields = fields
68 .named
69 .iter()
70 .map(|field| field.ident.as_ref().unwrap());
71 quote! { {#(#fields: ::core::default::Default::default()),*} }
72 }
73 };
74
75 let const_var_str = format!("{}_DISCRIMINANT", variant.ident);
76 let const_var_ident = format_ident!("{}", const_var_str);
77
78 let const_val_expr = match &variant.discriminant {
79 Some((_, expr)) => quote! { #expr },
80 None => match &prev_const_var_ident {
81 Some(prev) => quote! { #prev + 1 },
82 None => quote! { 0 },
83 },
84 };
85
86 constant_defs.push(quote! {
87 #[allow(non_upper_case_globals)]
88 const #const_var_ident: #discriminant_type = #const_val_expr;
89 });
90 arms.push(quote! {v if v == #const_var_ident => ::core::option::Option::Some(#name::#ident #params)});
91
92 prev_const_var_ident = Some(const_var_ident);
93 }
94
95 arms.push(quote! { _ => ::core::option::Option::None });
96
97 let const_if_possible = if has_additional_data {
98 quote! {}
99 } else {
100 #[rustversion::before(1.46)]
101 fn filter_by_rust_version(_: TokenStream) -> TokenStream {
102 quote! {}
103 }
104
105 #[rustversion::since(1.46)]
106 fn filter_by_rust_version(s: TokenStream) -> TokenStream {
107 s
108 }
109 filter_by_rust_version(quote! { const })
110 };
111
112 Ok(quote! {
113 #[allow(clippy::use_self)]
114 impl #impl_generics #name #ty_generics #where_clause {
115 #[doc = "Try to create [Self] from the raw representation"]
116 #vis #const_if_possible fn from_repr(discriminant: #discriminant_type) -> Option<#name #ty_generics> {
117 #(#constant_defs)*
118 match discriminant {
119 #(#arms),*
120 }
121 }
122 }
123 })
124}
125