1use crate::add_helpers::{struct_exprs, tuple_exprs};
2use crate::utils::{
3 add_extra_type_param_bound_op_output, field_idents, named_to_vec, numbered_vars,
4 unnamed_to_vec,
5};
6use proc_macro2::{Span, TokenStream};
7use quote::{quote, ToTokens};
8use std::iter;
9use syn::{Data, DataEnum, DeriveInput, Field, Fields, Ident};
10
11pub fn expand(input: &DeriveInput, trait_name: &str) -> TokenStream {
12 let trait_name = trait_name.trim_end_matches("Self");
13 let trait_ident = Ident::new(trait_name, Span::call_site());
14 let method_name = trait_name.to_lowercase();
15 let method_ident = Ident::new(&method_name, Span::call_site());
16 let input_type = &input.ident;
17
18 let generics = add_extra_type_param_bound_op_output(&input.generics, &trait_ident);
19 let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
20
21 let (output_type, block) = match input.data {
22 Data::Struct(ref data_struct) => match data_struct.fields {
23 Fields::Unnamed(ref fields) => (
24 quote!(#input_type#ty_generics),
25 tuple_content(input_type, &unnamed_to_vec(fields), &method_ident),
26 ),
27 Fields::Named(ref fields) => (
28 quote!(#input_type#ty_generics),
29 struct_content(input_type, &named_to_vec(fields), &method_ident),
30 ),
31 _ => panic!("Unit structs cannot use derive({})", trait_name),
32 },
33 Data::Enum(ref data_enum) => (
34 quote!(::core::result::Result<#input_type#ty_generics, &'static str>),
35 enum_content(input_type, data_enum, &method_ident),
36 ),
37
38 _ => panic!("Only structs and enums can use derive({})", trait_name),
39 };
40
41 quote!(
42 impl#impl_generics ::core::ops::#trait_ident for #input_type#ty_generics #where_clause {
43 type Output = #output_type;
44 #[inline]
45 fn #method_ident(self, rhs: #input_type#ty_generics) -> #output_type {
46 #block
47 }
48 }
49 )
50}
51
52fn tuple_content<T: ToTokens>(
53 input_type: &T,
54 fields: &[&Field],
55 method_ident: &Ident,
56) -> TokenStream {
57 let exprs: Vec = tuple_exprs(fields, method_ident);
58 quote!(#input_type(#(#exprs),*))
59}
60
61fn struct_content(
62 input_type: &Ident,
63 fields: &[&Field],
64 method_ident: &Ident,
65) -> TokenStream {
66 // It's safe to unwrap because struct fields always have an identifier
67 let exprs: Vec = struct_exprs(fields, method_ident);
68 let field_names: Vec<&Ident> = field_idents(fields);
69
70 quote!(#input_type{#(#field_names: #exprs),*})
71}
72
73#[allow(clippy::cognitive_complexity)]
74fn enum_content(
75 input_type: &Ident,
76 data_enum: &DataEnum,
77 method_ident: &Ident,
78) -> TokenStream {
79 let mut matches = vec![];
80 let mut method_iter = iter::repeat(method_ident);
81
82 for variant in &data_enum.variants {
83 let subtype = &variant.ident;
84 let subtype = quote!(#input_type::#subtype);
85
86 match variant.fields {
87 Fields::Unnamed(ref fields) => {
88 // The patern that is outputted should look like this:
89 // (Subtype(left_vars), TypePath(right_vars)) => Ok(TypePath(exprs))
90 let size = unnamed_to_vec(fields).len();
91 let l_vars = &numbered_vars(size, "l_");
92 let r_vars = &numbered_vars(size, "r_");
93 let method_iter = method_iter.by_ref();
94 let matcher = quote! {
95 (#subtype(#(#l_vars),*),
96 #subtype(#(#r_vars),*)) => {
97 ::core::result::Result::Ok(#subtype(#(#l_vars.#method_iter(#r_vars)),*))
98 }
99 };
100 matches.push(matcher);
101 }
102 Fields::Named(ref fields) => {
103 // The patern that is outputted should look like this:
104 // (Subtype{a: __l_a, ...}, Subtype{a: __r_a, ...} => {
105 // Ok(Subtype{a: __l_a.add(__r_a), ...})
106 // }
107 let field_vec = named_to_vec(fields);
108 let size = field_vec.len();
109 let field_names = &field_idents(&field_vec);
110 let l_vars = &numbered_vars(size, "l_");
111 let r_vars = &numbered_vars(size, "r_");
112 let method_iter = method_iter.by_ref();
113 let matcher = quote! {
114 (#subtype{#(#field_names: #l_vars),*},
115 #subtype{#(#field_names: #r_vars),*}) => {
116 ::core::result::Result::Ok(#subtype{#(#field_names: #l_vars.#method_iter(#r_vars)),*})
117 }
118 };
119 matches.push(matcher);
120 }
121 Fields::Unit => {
122 let message = format!("Cannot {}() unit variants", method_ident);
123 matches.push(quote!((#subtype, #subtype) => ::core::result::Result::Err(#message)));
124 }
125 }
126 }
127
128 if data_enum.variants.len() > 1 {
129 // In the strange case where there's only one enum variant this is would be an unreachable
130 // match.
131 let message = format!("Trying to {} mismatched enum variants", method_ident);
132 matches.push(quote!(_ => ::core::result::Result::Err(#message)));
133 }
134 quote!(
135 match (self, rhs) {
136 #(#matches),*
137 }
138 )
139}
140