1 | use crate::utils::{ |
2 | add_extra_type_param_bound_op_output, named_to_vec, unnamed_to_vec, |
3 | }; |
4 | use proc_macro2::{Span, TokenStream}; |
5 | use quote::{quote, ToTokens}; |
6 | use std::iter; |
7 | use syn::{Data, DataEnum, DeriveInput, Field, Fields, Ident, Index}; |
8 | |
9 | pub fn expand(input: &DeriveInput, trait_name: &str) -> TokenStream { |
10 | let trait_ident = Ident::new(trait_name, Span::call_site()); |
11 | let method_name = trait_name.to_lowercase(); |
12 | let method_ident = &Ident::new(&method_name, Span::call_site()); |
13 | let input_type = &input.ident; |
14 | |
15 | let generics = add_extra_type_param_bound_op_output(&input.generics, &trait_ident); |
16 | let (impl_generics, ty_generics, where_clause) = generics.split_for_impl(); |
17 | |
18 | let (output_type, block) = match input.data { |
19 | Data::Struct(ref data_struct) => match data_struct.fields { |
20 | Fields::Unnamed(ref fields) => ( |
21 | quote!(#input_type#ty_generics), |
22 | tuple_content(input_type, &unnamed_to_vec(fields), method_ident), |
23 | ), |
24 | Fields::Named(ref fields) => ( |
25 | quote!(#input_type#ty_generics), |
26 | struct_content(input_type, &named_to_vec(fields), method_ident), |
27 | ), |
28 | _ => panic!("Unit structs cannot use derive( {})" , trait_name), |
29 | }, |
30 | Data::Enum(ref data_enum) => { |
31 | enum_output_type_and_content(input, data_enum, method_ident) |
32 | } |
33 | |
34 | _ => panic!("Only structs and enums can use derive( {})" , trait_name), |
35 | }; |
36 | |
37 | quote!( |
38 | impl#impl_generics ::core::ops::#trait_ident for #input_type#ty_generics #where_clause { |
39 | type Output = #output_type; |
40 | #[inline] |
41 | fn #method_ident(self) -> #output_type { |
42 | #block |
43 | } |
44 | } |
45 | ) |
46 | } |
47 | |
48 | fn tuple_content<T: ToTokens>( |
49 | input_type: &T, |
50 | fields: &[&Field], |
51 | method_ident: &Ident, |
52 | ) -> TokenStream { |
53 | let mut exprs: Vec = vec![]; |
54 | |
55 | for i: usize in 0..fields.len() { |
56 | let i: Index = Index::from(i); |
57 | // generates `self.0.add()` |
58 | let expr: TokenStream = quote!(self.#i.#method_ident()); |
59 | exprs.push(expr); |
60 | } |
61 | |
62 | quote!(#input_type(#(#exprs),*)) |
63 | } |
64 | |
65 | fn struct_content( |
66 | input_type: &Ident, |
67 | fields: &[&Field], |
68 | method_ident: &Ident, |
69 | ) -> TokenStream { |
70 | let mut exprs: Vec = vec![]; |
71 | |
72 | for field: &&Field in fields { |
73 | // It's safe to unwrap because struct fields always have an identifier |
74 | let field_id: Option<&Ident> = field.ident.as_ref(); |
75 | // generates `x: self.x.not()` |
76 | let expr: TokenStream = quote!(#field_id: self.#field_id.#method_ident()); |
77 | exprs.push(expr) |
78 | } |
79 | |
80 | quote!(#input_type{#(#exprs),*}) |
81 | } |
82 | |
83 | fn enum_output_type_and_content( |
84 | input: &DeriveInput, |
85 | data_enum: &DataEnum, |
86 | method_ident: &Ident, |
87 | ) -> (TokenStream, TokenStream) { |
88 | let input_type = &input.ident; |
89 | let (_, ty_generics, _) = input.generics.split_for_impl(); |
90 | let mut matches = vec![]; |
91 | let mut method_iter = iter::repeat(method_ident); |
92 | // If the enum contains unit types that means it can error. |
93 | let has_unit_type = data_enum.variants.iter().any(|v| v.fields == Fields::Unit); |
94 | |
95 | for variant in &data_enum.variants { |
96 | let subtype = &variant.ident; |
97 | let subtype = quote!(#input_type::#subtype); |
98 | |
99 | match variant.fields { |
100 | Fields::Unnamed(ref fields) => { |
101 | // The patern that is outputted should look like this: |
102 | // (Subtype(vars)) => Ok(TypePath(exprs)) |
103 | let size = unnamed_to_vec(fields).len(); |
104 | let vars: &Vec<_> = &(0..size) |
105 | .map(|i| Ident::new(&format!("__ {}" , i), Span::call_site())) |
106 | .collect(); |
107 | let method_iter = method_iter.by_ref(); |
108 | let mut body = quote!(#subtype(#(#vars.#method_iter()),*)); |
109 | if has_unit_type { |
110 | body = quote!(::core::result::Result::Ok(#body)) |
111 | } |
112 | let matcher = quote! { |
113 | #subtype(#(#vars),*) => { |
114 | #body |
115 | } |
116 | }; |
117 | matches.push(matcher); |
118 | } |
119 | Fields::Named(ref fields) => { |
120 | // The patern that is outputted should look like this: |
121 | // (Subtype{a: __l_a, ...} => { |
122 | // Ok(Subtype{a: __l_a.neg(__r_a), ...}) |
123 | // } |
124 | let field_vec = named_to_vec(fields); |
125 | let size = field_vec.len(); |
126 | let field_names: &Vec<_> = &field_vec |
127 | .iter() |
128 | .map(|f| f.ident.as_ref().unwrap()) |
129 | .collect(); |
130 | let vars: &Vec<_> = &(0..size) |
131 | .map(|i| Ident::new(&format!("__ {}" , i), Span::call_site())) |
132 | .collect(); |
133 | let method_iter = method_iter.by_ref(); |
134 | let mut body = |
135 | quote!(#subtype{#(#field_names: #vars.#method_iter()),*}); |
136 | if has_unit_type { |
137 | body = quote!(::core::result::Result::Ok(#body)) |
138 | } |
139 | let matcher = quote! { |
140 | #subtype{#(#field_names: #vars),*} => { |
141 | #body |
142 | } |
143 | }; |
144 | matches.push(matcher); |
145 | } |
146 | Fields::Unit => { |
147 | let message = format!("Cannot {}() unit variants" , method_ident); |
148 | matches.push(quote!(#subtype => ::core::result::Result::Err(#message))); |
149 | } |
150 | } |
151 | } |
152 | |
153 | let body = quote!( |
154 | match self { |
155 | #(#matches),* |
156 | } |
157 | ); |
158 | |
159 | let output_type = if has_unit_type { |
160 | quote!(::core::result::Result<#input_type#ty_generics, &'static str>) |
161 | } else { |
162 | quote!(#input_type#ty_generics) |
163 | }; |
164 | |
165 | (output_type, body) |
166 | } |
167 | |