1use std::iter;
2
3use proc_macro2::{Span, TokenStream};
4use quote::{quote, ToTokens};
5use syn::{parse::Result, DeriveInput, Ident, Index};
6
7use crate::utils::{
8 add_where_clauses_for_new_ident, AttrParams, DeriveType, HashMap, MultiFieldData,
9 RefType, State,
10};
11
12/// Provides the hook to expand `#[derive(From)]` into an implementation of `From`
13pub fn expand(input: &DeriveInput, trait_name: &'static str) -> Result<TokenStream> {
14 let state: State<'_> = State::with_attr_params(
15 input,
16 trait_name,
17 trait_module:quote!(::core::convert),
18 trait_attr:trait_name.to_lowercase(),
19 allowed_attr_params:AttrParams {
20 enum_: vec!["forward", "ignore"],
21 variant: vec!["forward", "ignore", "types"],
22 struct_: vec!["forward", "types"],
23 field: vec!["forward"],
24 },
25 )?;
26 if state.derive_type == DeriveType::Enum {
27 Ok(enum_from(input, state))
28 } else {
29 Ok(struct_from(input, &state))
30 }
31}
32
33pub fn struct_from(input: &DeriveInput, state: &State) -> TokenStream {
34 let multi_field_data = state.enabled_fields_data();
35 let MultiFieldData {
36 fields,
37 variant_info,
38 infos,
39 input_type,
40 trait_path,
41 ..
42 } = multi_field_data.clone();
43
44 let additional_types = variant_info.additional_types(RefType::No);
45 let mut impls = Vec::with_capacity(additional_types.len() + 1);
46 for explicit_type in iter::once(None).chain(additional_types.iter().map(Some)) {
47 let mut new_generics = input.generics.clone();
48
49 let mut initializers = Vec::with_capacity(infos.len());
50 let mut from_types = Vec::with_capacity(infos.len());
51 for (i, (info, field)) in infos.iter().zip(fields.iter()).enumerate() {
52 let field_type = &field.ty;
53 let variable = if fields.len() == 1 {
54 quote! { original }
55 } else {
56 let tuple_index = Index::from(i);
57 quote! { original.#tuple_index }
58 };
59 if let Some(type_) = explicit_type {
60 initializers.push(quote! {
61 <#field_type as #trait_path<#type_>>::from(#variable)
62 });
63 from_types.push(quote! { #type_ });
64 } else if info.forward {
65 let type_param =
66 &Ident::new(&format!("__FromT{}", i), Span::call_site());
67 let sub_trait_path = quote! { #trait_path<#type_param> };
68 let type_where_clauses = quote! {
69 where #field_type: #sub_trait_path
70 };
71 new_generics = add_where_clauses_for_new_ident(
72 &new_generics,
73 &[field],
74 type_param,
75 type_where_clauses,
76 true,
77 );
78 let casted_trait = quote! { <#field_type as #sub_trait_path> };
79 initializers.push(quote! { #casted_trait::from(#variable) });
80 from_types.push(quote! { #type_param });
81 } else {
82 initializers.push(variable);
83 from_types.push(quote! { #field_type });
84 }
85 }
86
87 let body = multi_field_data.initializer(&initializers);
88 let (impl_generics, _, where_clause) = new_generics.split_for_impl();
89 let (_, ty_generics, _) = input.generics.split_for_impl();
90
91 impls.push(quote! {
92 #[automatically_derived]
93 impl#impl_generics #trait_path<(#(#from_types),*)> for
94 #input_type#ty_generics #where_clause {
95
96 #[inline]
97 fn from(original: (#(#from_types),*)) -> #input_type#ty_generics {
98 #body
99 }
100 }
101 });
102 }
103
104 quote! { #( #impls )* }
105}
106
107fn enum_from(input: &DeriveInput, state: State) -> TokenStream {
108 let mut tokens = TokenStream::new();
109
110 let mut variants_per_types = HashMap::default();
111 for variant_state in state.enabled_variant_data().variant_states {
112 let multi_field_data = variant_state.enabled_fields_data();
113 let MultiFieldData { field_types, .. } = multi_field_data.clone();
114 variants_per_types
115 .entry(field_types.clone())
116 .or_insert_with(Vec::new)
117 .push(variant_state);
118 }
119 for (ref field_types, ref variant_states) in variants_per_types {
120 for variant_state in variant_states {
121 let multi_field_data = variant_state.enabled_fields_data();
122 let MultiFieldData {
123 variant_info,
124 infos,
125 ..
126 } = multi_field_data.clone();
127 // If there would be a conflict on a empty tuple derive, ignore the
128 // variants that are not explicitly enabled or have explicitly enabled
129 // or disabled fields
130 if field_types.is_empty()
131 && variant_states.len() > 1
132 && !std::iter::once(variant_info)
133 .chain(infos)
134 .any(|info| info.info.enabled.is_some())
135 {
136 continue;
137 }
138 struct_from(input, variant_state).to_tokens(&mut tokens);
139 }
140 }
141 tokens
142}
143