1 | use crate::utils::{ |
2 | add_extra_ty_param_bound, add_extra_where_clauses, MultiFieldData, State, |
3 | }; |
4 | use proc_macro2::{Span, TokenStream}; |
5 | use quote::quote; |
6 | use syn::{DeriveInput, Ident, Result}; |
7 | |
8 | pub fn expand(input: &DeriveInput, trait_name: &'static str) -> Result<TokenStream> { |
9 | let state = State::new( |
10 | input, |
11 | trait_name, |
12 | quote!(::core::iter), |
13 | trait_name.to_lowercase(), |
14 | )?; |
15 | let multi_field_data = state.enabled_fields_data(); |
16 | let MultiFieldData { |
17 | input_type, |
18 | field_types, |
19 | trait_path, |
20 | method_ident, |
21 | .. |
22 | } = multi_field_data.clone(); |
23 | |
24 | let op_trait_name = if trait_name == "Sum" { "Add" } else { "Mul" }; |
25 | let op_trait_ident = Ident::new(op_trait_name, Span::call_site()); |
26 | let op_path = quote!(::core::ops::#op_trait_ident); |
27 | let op_method_ident = |
28 | Ident::new(&(op_trait_name.to_lowercase()), Span::call_site()); |
29 | let has_type_params = input.generics.type_params().next().is_none(); |
30 | let generics = if has_type_params { |
31 | input.generics.clone() |
32 | } else { |
33 | let (_, ty_generics, _) = input.generics.split_for_impl(); |
34 | let generics = add_extra_ty_param_bound(&input.generics, trait_path); |
35 | let operator_where_clause = quote! { |
36 | where #input_type#ty_generics: #op_path<Output=#input_type#ty_generics> |
37 | }; |
38 | add_extra_where_clauses(&generics, operator_where_clause) |
39 | }; |
40 | let (impl_generics, ty_generics, where_clause) = generics.split_for_impl(); |
41 | |
42 | let initializers: Vec<_> = field_types |
43 | .iter() |
44 | .map(|field_type| quote!(#trait_path::#method_ident(::core::iter::empty::<#field_type>()))) |
45 | .collect(); |
46 | let identity = multi_field_data.initializer(&initializers); |
47 | |
48 | Ok(quote!( |
49 | impl#impl_generics #trait_path for #input_type#ty_generics #where_clause { |
50 | #[inline] |
51 | fn #method_ident<I: ::core::iter::Iterator<Item = Self>>(iter: I) -> Self { |
52 | iter.fold(#identity, #op_path::#op_method_ident) |
53 | } |
54 | } |
55 | )) |
56 | } |
57 | |