1//! Parses [`DeriveInput`] into something more useful.
2
3use proc_macro2::Span;
4use syn::{DeriveInput, GenericParam, Generics, ImplGenerics, Result, TypeGenerics, WhereClause};
5
6#[cfg(feature = "zeroize")]
7use crate::DeriveTrait;
8use crate::{Data, DeriveWhere, Discriminant, Either, Error, Item, ItemAttr, Trait};
9
10/// Parsed input.
11pub struct Input<'a> {
12 /// `derive_where` attributes on the item.
13 pub derive_wheres: Vec<DeriveWhere>,
14 /// Generics necessary to define for an `impl`.
15 pub generics: SplitGenerics<'a>,
16 /// Fields or variants of this item.
17 pub item: Item<'a>,
18}
19
20impl<'a> Input<'a> {
21 /// Create [`Input`] from `proc_macro_derive` parameter.
22 pub fn from_input(
23 span: Span,
24 DeriveInput {
25 attrs,
26 ident,
27 generics,
28 data,
29 ..
30 }: &'a DeriveInput,
31 ) -> Result<Self> {
32 // Parse `Attribute`s on item.
33 let ItemAttr {
34 skip_inner,
35 derive_wheres,
36 incomparable,
37 } = ItemAttr::from_attrs(span, data, attrs)?;
38
39 // Find if `incomparable` is specified on any item/variant.
40 let mut found_incomparable = incomparable.0.is_some();
41
42 // Extract fields and variants of this item.
43 let item = match &data {
44 syn::Data::Struct(data) => Data::from_struct(
45 span,
46 &derive_wheres,
47 skip_inner,
48 incomparable,
49 ident,
50 &data.fields,
51 )
52 .map(Item::Item)?,
53 syn::Data::Enum(data) => {
54 let discriminant = Discriminant::parse(attrs, &data.variants)?;
55
56 let variants = data
57 .variants
58 .iter()
59 .map(|variant| Data::from_variant(ident, &derive_wheres, variant))
60 .collect::<Result<Vec<Data>>>()?;
61
62 // Find if a default option is specified on a variant.
63 let mut found_default = false;
64
65 // While searching for a default option, check for duplicates.
66 for variant in &variants {
67 if let Some(span) = variant.default_span() {
68 if found_default {
69 return Err(Error::default_duplicate(span));
70 } else {
71 found_default = true;
72 }
73 }
74 if let (Some(item), Some(variant)) = (incomparable.0, variant.incomparable.0) {
75 return Err(Error::incomparable_on_item_and_variant(item, variant));
76 }
77 found_incomparable |= variant.is_incomparable();
78 }
79
80 // Make sure a variant has the `option` attribute if `Default` is being
81 // implemented.
82 if !found_default
83 && derive_wheres
84 .iter()
85 .any(|derive_where| derive_where.contains(Trait::Default))
86 {
87 return Err(Error::default_missing(span));
88 }
89
90 // Empty enums aren't allowed unless they implement `Default` or are
91 // incomparable.
92 if !found_default
93 && !found_incomparable
94 && variants.iter().all(|variant| match variant.fields() {
95 Either::Left(fields) => fields.fields.is_empty(),
96 Either::Right(_) => true,
97 }) {
98 return Err(Error::item_empty(span));
99 }
100
101 Item::Enum {
102 discriminant,
103 ident,
104 variants,
105 incomparable,
106 }
107 }
108 syn::Data::Union(data) => Data::from_union(
109 span,
110 &derive_wheres,
111 skip_inner,
112 incomparable,
113 ident,
114 &data.fields,
115 )
116 .map(Item::Item)?,
117 };
118
119 // Don't allow generic constraints be the same as generics on item unless there
120 // is a use-case for it.
121 // Count number of generic type parameters.
122 let generics_len = generics
123 .params
124 .iter()
125 .filter(|generic_param| match generic_param {
126 GenericParam::Type(_) => true,
127 GenericParam::Lifetime(_) | GenericParam::Const(_) => false,
128 })
129 .count();
130
131 'outer: for derive_where in &derive_wheres {
132 // No point in starting to compare both if not even the length is the same.
133 // This can be easily circumvented by doing the following:
134 // `#[derive_where(..; T: Clone)]`, or `#[derive_where(..; T, T)]`, which
135 // apparently is valid Rust syntax: `where T: Clone, T: Clone`, we are only here
136 // to help though.
137 if derive_where.generics.len() != generics_len {
138 continue;
139 }
140
141 // No point in starting to check if there is no use-case if a custom bound was
142 // used, which is a use-case.
143 if derive_where.any_custom_bound() {
144 continue;
145 }
146
147 // Check if every generic type parameter present on the item is defined in this
148 // `DeriveWhere`.
149 for generic_param in &generics.params {
150 // Only check generic type parameters.
151 if let GenericParam::Type(type_param) = generic_param {
152 if !derive_where.has_type_param(&type_param.ident) {
153 continue 'outer;
154 }
155 }
156 }
157
158 // The `for` loop should short-circuit to the `'outer` loop if not all generic
159 // type parameters were found.
160
161 // Don't allow no use-case compared to std `derive`.
162 for (span, trait_) in derive_where.spans.iter().zip(&derive_where.traits) {
163 // `Default` is used on an enum.
164 if trait_ == Trait::Default && item.is_enum() {
165 continue;
166 }
167
168 // Any field is skipped with a corresponding `Trait`.
169 if item.any_skip_trait(**trait_) {
170 continue;
171 }
172
173 // Any variant is marked as incomparable.
174 if found_incomparable {
175 continue;
176 }
177
178 #[cfg(feature = "zeroize")]
179 {
180 // `Zeroize(crate = ..)` or `ZeroizeOnDrop(crate = ..)` is used.
181 if let DeriveTrait::Zeroize { crate_: Some(_) }
182 | DeriveTrait::ZeroizeOnDrop { crate_: Some(_) } = *trait_
183 {
184 continue;
185 }
186
187 // `Zeroize(fqs)` is used on any field.
188 if trait_ == Trait::Zeroize && item.any_fqs() {
189 continue;
190 }
191 }
192
193 return Err(Error::use_case(*span));
194 }
195 }
196
197 let generics = SplitGenerics::new(generics);
198
199 Ok(Self {
200 derive_wheres,
201 generics,
202 item,
203 })
204 }
205}
206
207/// Stores output of [`Generics::split_for_impl()`].
208pub struct SplitGenerics<'a> {
209 /// Necessary generic definitions.
210 pub imp: ImplGenerics<'a>,
211 /// Generics on the type itself.
212 pub ty: TypeGenerics<'a>,
213 /// `where` clause.
214 pub where_clause: Option<&'a WhereClause>,
215}
216
217impl<'a> SplitGenerics<'a> {
218 /// Creates a [`SplitGenerics`] from [`Generics`].
219 fn new(generics: &'a Generics) -> Self {
220 let (imp: ImplGenerics<'_>, ty: TypeGenerics<'_>, where_clause: Option<&WhereClause>) = generics.split_for_impl();
221
222 SplitGenerics {
223 imp,
224 ty,
225 where_clause,
226 }
227 }
228}
229