1 | use crate::add_helpers::{struct_exprs, tuple_exprs}; |
2 | use crate::utils::{ |
3 | add_extra_type_param_bound_op_output, field_idents, named_to_vec, numbered_vars, |
4 | unnamed_to_vec, |
5 | }; |
6 | use proc_macro2::{Span, TokenStream}; |
7 | use quote::{quote, ToTokens}; |
8 | use std::iter; |
9 | use syn::{Data, DataEnum, DeriveInput, Field, Fields, Ident}; |
10 | |
11 | pub fn expand(input: &DeriveInput, trait_name: &str) -> TokenStream { |
12 | let trait_name = trait_name.trim_end_matches("Self" ); |
13 | let trait_ident = Ident::new(trait_name, Span::call_site()); |
14 | let method_name = trait_name.to_lowercase(); |
15 | let method_ident = Ident::new(&method_name, Span::call_site()); |
16 | let input_type = &input.ident; |
17 | |
18 | let generics = add_extra_type_param_bound_op_output(&input.generics, &trait_ident); |
19 | let (impl_generics, ty_generics, where_clause) = generics.split_for_impl(); |
20 | |
21 | let (output_type, block) = match input.data { |
22 | Data::Struct(ref data_struct) => match data_struct.fields { |
23 | Fields::Unnamed(ref fields) => ( |
24 | quote!(#input_type#ty_generics), |
25 | tuple_content(input_type, &unnamed_to_vec(fields), &method_ident), |
26 | ), |
27 | Fields::Named(ref fields) => ( |
28 | quote!(#input_type#ty_generics), |
29 | struct_content(input_type, &named_to_vec(fields), &method_ident), |
30 | ), |
31 | _ => panic!("Unit structs cannot use derive( {})" , trait_name), |
32 | }, |
33 | Data::Enum(ref data_enum) => ( |
34 | quote!(::core::result::Result<#input_type#ty_generics, &'static str>), |
35 | enum_content(input_type, data_enum, &method_ident), |
36 | ), |
37 | |
38 | _ => panic!("Only structs and enums can use derive( {})" , trait_name), |
39 | }; |
40 | |
41 | quote!( |
42 | impl#impl_generics ::core::ops::#trait_ident for #input_type#ty_generics #where_clause { |
43 | type Output = #output_type; |
44 | #[inline] |
45 | fn #method_ident(self, rhs: #input_type#ty_generics) -> #output_type { |
46 | #block |
47 | } |
48 | } |
49 | ) |
50 | } |
51 | |
52 | fn tuple_content<T: ToTokens>( |
53 | input_type: &T, |
54 | fields: &[&Field], |
55 | method_ident: &Ident, |
56 | ) -> TokenStream { |
57 | let exprs: Vec = tuple_exprs(fields, method_ident); |
58 | quote!(#input_type(#(#exprs),*)) |
59 | } |
60 | |
61 | fn struct_content( |
62 | input_type: &Ident, |
63 | fields: &[&Field], |
64 | method_ident: &Ident, |
65 | ) -> TokenStream { |
66 | // It's safe to unwrap because struct fields always have an identifier |
67 | let exprs: Vec = struct_exprs(fields, method_ident); |
68 | let field_names: Vec<&Ident> = field_idents(fields); |
69 | |
70 | quote!(#input_type{#(#field_names: #exprs),*}) |
71 | } |
72 | |
73 | #[allow (clippy::cognitive_complexity)] |
74 | fn enum_content( |
75 | input_type: &Ident, |
76 | data_enum: &DataEnum, |
77 | method_ident: &Ident, |
78 | ) -> TokenStream { |
79 | let mut matches = vec![]; |
80 | let mut method_iter = iter::repeat(method_ident); |
81 | |
82 | for variant in &data_enum.variants { |
83 | let subtype = &variant.ident; |
84 | let subtype = quote!(#input_type::#subtype); |
85 | |
86 | match variant.fields { |
87 | Fields::Unnamed(ref fields) => { |
88 | // The patern that is outputted should look like this: |
89 | // (Subtype(left_vars), TypePath(right_vars)) => Ok(TypePath(exprs)) |
90 | let size = unnamed_to_vec(fields).len(); |
91 | let l_vars = &numbered_vars(size, "l_" ); |
92 | let r_vars = &numbered_vars(size, "r_" ); |
93 | let method_iter = method_iter.by_ref(); |
94 | let matcher = quote! { |
95 | (#subtype(#(#l_vars),*), |
96 | #subtype(#(#r_vars),*)) => { |
97 | ::core::result::Result::Ok(#subtype(#(#l_vars.#method_iter(#r_vars)),*)) |
98 | } |
99 | }; |
100 | matches.push(matcher); |
101 | } |
102 | Fields::Named(ref fields) => { |
103 | // The patern that is outputted should look like this: |
104 | // (Subtype{a: __l_a, ...}, Subtype{a: __r_a, ...} => { |
105 | // Ok(Subtype{a: __l_a.add(__r_a), ...}) |
106 | // } |
107 | let field_vec = named_to_vec(fields); |
108 | let size = field_vec.len(); |
109 | let field_names = &field_idents(&field_vec); |
110 | let l_vars = &numbered_vars(size, "l_" ); |
111 | let r_vars = &numbered_vars(size, "r_" ); |
112 | let method_iter = method_iter.by_ref(); |
113 | let matcher = quote! { |
114 | (#subtype{#(#field_names: #l_vars),*}, |
115 | #subtype{#(#field_names: #r_vars),*}) => { |
116 | ::core::result::Result::Ok(#subtype{#(#field_names: #l_vars.#method_iter(#r_vars)),*}) |
117 | } |
118 | }; |
119 | matches.push(matcher); |
120 | } |
121 | Fields::Unit => { |
122 | let message = format!("Cannot {}() unit variants" , method_ident); |
123 | matches.push(quote!((#subtype, #subtype) => ::core::result::Result::Err(#message))); |
124 | } |
125 | } |
126 | } |
127 | |
128 | if data_enum.variants.len() > 1 { |
129 | // In the strange case where there's only one enum variant this is would be an unreachable |
130 | // match. |
131 | let message = format!("Trying to {} mismatched enum variants" , method_ident); |
132 | matches.push(quote!(_ => ::core::result::Result::Err(#message))); |
133 | } |
134 | quote!( |
135 | match (self, rhs) { |
136 | #(#matches),* |
137 | } |
138 | ) |
139 | } |
140 | |