1 | use std::iter; |
2 | |
3 | use proc_macro2::{Span, TokenStream}; |
4 | use quote::{quote, ToTokens}; |
5 | use syn::{parse::Result, DeriveInput, Ident, Index}; |
6 | |
7 | use crate::utils::{ |
8 | add_where_clauses_for_new_ident, AttrParams, DeriveType, HashMap, MultiFieldData, |
9 | RefType, State, |
10 | }; |
11 | |
12 | /// Provides the hook to expand `#[derive(From)]` into an implementation of `From` |
13 | pub fn expand(input: &DeriveInput, trait_name: &'static str) -> Result<TokenStream> { |
14 | let state: State<'_> = State::with_attr_params( |
15 | input, |
16 | trait_name, |
17 | trait_module:quote!(::core::convert), |
18 | trait_attr:trait_name.to_lowercase(), |
19 | allowed_attr_params:AttrParams { |
20 | enum_: vec!["forward" , "ignore" ], |
21 | variant: vec!["forward" , "ignore" , "types" ], |
22 | struct_: vec!["forward" , "types" ], |
23 | field: vec!["forward" ], |
24 | }, |
25 | )?; |
26 | if state.derive_type == DeriveType::Enum { |
27 | Ok(enum_from(input, state)) |
28 | } else { |
29 | Ok(struct_from(input, &state)) |
30 | } |
31 | } |
32 | |
33 | pub fn struct_from(input: &DeriveInput, state: &State) -> TokenStream { |
34 | let multi_field_data = state.enabled_fields_data(); |
35 | let MultiFieldData { |
36 | fields, |
37 | variant_info, |
38 | infos, |
39 | input_type, |
40 | trait_path, |
41 | .. |
42 | } = multi_field_data.clone(); |
43 | |
44 | let additional_types = variant_info.additional_types(RefType::No); |
45 | let mut impls = Vec::with_capacity(additional_types.len() + 1); |
46 | for explicit_type in iter::once(None).chain(additional_types.iter().map(Some)) { |
47 | let mut new_generics = input.generics.clone(); |
48 | |
49 | let mut initializers = Vec::with_capacity(infos.len()); |
50 | let mut from_types = Vec::with_capacity(infos.len()); |
51 | for (i, (info, field)) in infos.iter().zip(fields.iter()).enumerate() { |
52 | let field_type = &field.ty; |
53 | let variable = if fields.len() == 1 { |
54 | quote! { original } |
55 | } else { |
56 | let tuple_index = Index::from(i); |
57 | quote! { original.#tuple_index } |
58 | }; |
59 | if let Some(type_) = explicit_type { |
60 | initializers.push(quote! { |
61 | <#field_type as #trait_path<#type_>>::from(#variable) |
62 | }); |
63 | from_types.push(quote! { #type_ }); |
64 | } else if info.forward { |
65 | let type_param = |
66 | &Ident::new(&format!("__FromT {}" , i), Span::call_site()); |
67 | let sub_trait_path = quote! { #trait_path<#type_param> }; |
68 | let type_where_clauses = quote! { |
69 | where #field_type: #sub_trait_path |
70 | }; |
71 | new_generics = add_where_clauses_for_new_ident( |
72 | &new_generics, |
73 | &[field], |
74 | type_param, |
75 | type_where_clauses, |
76 | true, |
77 | ); |
78 | let casted_trait = quote! { <#field_type as #sub_trait_path> }; |
79 | initializers.push(quote! { #casted_trait::from(#variable) }); |
80 | from_types.push(quote! { #type_param }); |
81 | } else { |
82 | initializers.push(variable); |
83 | from_types.push(quote! { #field_type }); |
84 | } |
85 | } |
86 | |
87 | let body = multi_field_data.initializer(&initializers); |
88 | let (impl_generics, _, where_clause) = new_generics.split_for_impl(); |
89 | let (_, ty_generics, _) = input.generics.split_for_impl(); |
90 | |
91 | impls.push(quote! { |
92 | #[automatically_derived] |
93 | impl#impl_generics #trait_path<(#(#from_types),*)> for |
94 | #input_type#ty_generics #where_clause { |
95 | |
96 | #[inline] |
97 | fn from(original: (#(#from_types),*)) -> #input_type#ty_generics { |
98 | #body |
99 | } |
100 | } |
101 | }); |
102 | } |
103 | |
104 | quote! { #( #impls )* } |
105 | } |
106 | |
107 | fn enum_from(input: &DeriveInput, state: State) -> TokenStream { |
108 | let mut tokens = TokenStream::new(); |
109 | |
110 | let mut variants_per_types = HashMap::default(); |
111 | for variant_state in state.enabled_variant_data().variant_states { |
112 | let multi_field_data = variant_state.enabled_fields_data(); |
113 | let MultiFieldData { field_types, .. } = multi_field_data.clone(); |
114 | variants_per_types |
115 | .entry(field_types.clone()) |
116 | .or_insert_with(Vec::new) |
117 | .push(variant_state); |
118 | } |
119 | for (ref field_types, ref variant_states) in variants_per_types { |
120 | for variant_state in variant_states { |
121 | let multi_field_data = variant_state.enabled_fields_data(); |
122 | let MultiFieldData { |
123 | variant_info, |
124 | infos, |
125 | .. |
126 | } = multi_field_data.clone(); |
127 | // If there would be a conflict on a empty tuple derive, ignore the |
128 | // variants that are not explicitly enabled or have explicitly enabled |
129 | // or disabled fields |
130 | if field_types.is_empty() |
131 | && variant_states.len() > 1 |
132 | && !std::iter::once(variant_info) |
133 | .chain(infos) |
134 | .any(|info| info.info.enabled.is_some()) |
135 | { |
136 | continue; |
137 | } |
138 | struct_from(input, variant_state).to_tokens(&mut tokens); |
139 | } |
140 | } |
141 | tokens |
142 | } |
143 | |