1 | use proc_macro2::{Ident, TokenStream}; |
2 | use quote::quote; |
3 | use syn::{punctuated::Punctuated, Data, DeriveInput, Fields, LitStr, Token}; |
4 | |
5 | use crate::helpers::{non_enum_error, HasStrumVariantProperties, HasTypeProperties}; |
6 | |
7 | pub fn display_inner(ast: &DeriveInput) -> syn::Result<TokenStream> { |
8 | let name = &ast.ident; |
9 | let (impl_generics, ty_generics, where_clause) = ast.generics.split_for_impl(); |
10 | let variants = match &ast.data { |
11 | Data::Enum(v) => &v.variants, |
12 | _ => return Err(non_enum_error()), |
13 | }; |
14 | |
15 | let type_properties = ast.get_type_properties()?; |
16 | |
17 | let mut arms = Vec::new(); |
18 | for variant in variants { |
19 | let ident = &variant.ident; |
20 | let variant_properties = variant.get_variant_properties()?; |
21 | |
22 | if variant_properties.disabled.is_some() { |
23 | continue; |
24 | } |
25 | |
26 | // Look at all the serialize attributes. |
27 | let output = variant_properties |
28 | .get_preferred_name(type_properties.case_style, type_properties.prefix.as_ref()); |
29 | |
30 | let params = match variant.fields { |
31 | Fields::Unit => quote! {}, |
32 | Fields::Unnamed(..) => quote! { (..) }, |
33 | Fields::Named(ref field_names) => { |
34 | // Transform named params '{ name: String, age: u8 }' to '{ ref name, ref age }' |
35 | let names: Punctuated<TokenStream, Token!(,)> = field_names |
36 | .named |
37 | .iter() |
38 | .map(|field| { |
39 | let ident = field.ident.as_ref().unwrap(); |
40 | quote! { ref #ident } |
41 | }) |
42 | .collect(); |
43 | |
44 | quote! { {#names} } |
45 | } |
46 | }; |
47 | |
48 | if variant_properties.to_string.is_none() && variant_properties.default.is_some() { |
49 | match &variant.fields { |
50 | Fields::Unnamed(fields) if fields.unnamed.len() == 1 => { |
51 | arms.push(quote! { #name::#ident(ref s) => ::core::fmt::Display::fmt(s, f) }); |
52 | } |
53 | _ => { |
54 | return Err(syn::Error::new_spanned( |
55 | variant, |
56 | "Default only works on newtype structs with a single String field" , |
57 | )) |
58 | } |
59 | } |
60 | } else { |
61 | let arm = if let Fields::Named(ref field_names) = variant.fields { |
62 | let used_vars = capture_format_string_idents(&output)?; |
63 | if used_vars.is_empty() { |
64 | quote! { #name::#ident #params => ::core::fmt::Display::fmt(#output, f) } |
65 | } else { |
66 | // Create args like 'name = name, age = age' for format macro |
67 | let args: Punctuated<_, Token!(,)> = field_names |
68 | .named |
69 | .iter() |
70 | .filter_map(|field| { |
71 | let ident = field.ident.as_ref().unwrap(); |
72 | // Only contain variables that are used in format string |
73 | if !used_vars.contains(ident) { |
74 | None |
75 | } else { |
76 | Some(quote! { #ident = #ident }) |
77 | } |
78 | }) |
79 | .collect(); |
80 | |
81 | quote! { |
82 | #[allow(unused_variables)] |
83 | #name::#ident #params => ::core::fmt::Display::fmt(&format!(#output, #args), f) |
84 | } |
85 | } |
86 | } else { |
87 | quote! { #name::#ident #params => ::core::fmt::Display::fmt(#output, f) } |
88 | }; |
89 | |
90 | arms.push(arm); |
91 | } |
92 | } |
93 | |
94 | if arms.len() < variants.len() { |
95 | arms.push(quote! { _ => panic!("fmt() called on disabled variant." ) }); |
96 | } |
97 | |
98 | Ok(quote! { |
99 | impl #impl_generics ::core::fmt::Display for #name #ty_generics #where_clause { |
100 | fn fmt(&self, f: &mut ::core::fmt::Formatter) -> ::core::result::Result<(), ::core::fmt::Error> { |
101 | match *self { |
102 | #(#arms),* |
103 | } |
104 | } |
105 | } |
106 | }) |
107 | } |
108 | |
109 | fn capture_format_string_idents(string_literal: &LitStr) -> syn::Result<Vec<Ident>> { |
110 | // Remove escaped brackets |
111 | let format_str = string_literal.value().replace("{{" , "" ).replace("}}" , "" ); |
112 | |
113 | let mut new_var_start_index: Option<usize> = None; |
114 | let mut var_used: Vec<Ident> = Vec::new(); |
115 | |
116 | for (i, chr) in format_str.bytes().enumerate() { |
117 | if chr == b'{' { |
118 | if new_var_start_index.is_some() { |
119 | return Err(syn::Error::new_spanned( |
120 | string_literal, |
121 | "Bracket opened without closing previous bracket" , |
122 | )); |
123 | } |
124 | new_var_start_index = Some(i); |
125 | continue; |
126 | } |
127 | |
128 | if chr == b'}' { |
129 | let start_index = new_var_start_index.take().ok_or(syn::Error::new_spanned( |
130 | string_literal, |
131 | "Bracket closed without previous opened bracket" , |
132 | ))?; |
133 | |
134 | let inside_brackets = &format_str[start_index + 1..i]; |
135 | let ident_str = inside_brackets.split(":" ).next().unwrap(); |
136 | let ident = syn::parse_str::<Ident>(ident_str).map_err(|_| { |
137 | syn::Error::new_spanned( |
138 | string_literal, |
139 | "Invalid identifier inside format string bracket" , |
140 | ) |
141 | })?; |
142 | var_used.push(ident); |
143 | } |
144 | } |
145 | |
146 | Ok(var_used) |
147 | } |
148 | |