1//! [`Attribute`] parsing for items.
2
3use std::{borrow::Cow, ops::Deref};
4
5use proc_macro2::Span;
6use syn::{
7 parse::{discouraged::Speculative, Parse, ParseStream},
8 punctuated::Punctuated,
9 spanned::Spanned,
10 Attribute, Data, Ident, Meta, Path, PredicateType, Result, Token, TraitBound,
11 TraitBoundModifier, Type, TypeParamBound, TypePath, WhereClause, WherePredicate,
12};
13
14use crate::{
15 util::{self, MetaListExt},
16 Error, Incomparable, Item, Skip, SkipGroup, Trait, TraitImpl, DERIVE_WHERE,
17};
18
19/// Attributes on item.
20#[derive(Default)]
21pub struct ItemAttr {
22 /// [`Trait`]s to skip all fields for.
23 pub skip_inner: Skip,
24 /// Comparing item will yield `false` for [`PartialEq`] and [`None`] for
25 /// [`PartialOrd`].
26 pub incomparable: Incomparable,
27 /// [`DeriveWhere`]s on this item.
28 pub derive_wheres: Vec<DeriveWhere>,
29}
30
31impl ItemAttr {
32 /// Create [`ItemAttr`] from [`Attribute`]s.
33 pub fn from_attrs(span: Span, data: &Data, attrs: &[Attribute]) -> Result<Self> {
34 let mut self_ = ItemAttr::default();
35 let mut skip_inners = Vec::new();
36 let mut incomparables = Vec::new();
37
38 for attr in attrs {
39 if attr.path().is_ident(DERIVE_WHERE) {
40 if let Meta::List(list) = &attr.meta {
41 if let Ok(nested) =
42 list.parse_args_with(Punctuated::<Meta, Token![,]>::parse_terminated)
43 {
44 match nested.len() {
45 // Don't allow an empty list.
46 0 => return Err(Error::empty(list.span())),
47 // Check for `skip_inner` if list only has one item.
48 1 => {
49 let meta =
50 nested.into_iter().next().expect("unexpected empty list");
51
52 if meta.path().is_ident(Skip::SKIP_INNER) {
53 // Don't allow `skip_inner` on the item level for enums.
54 if let Data::Enum(_) = data {
55 return Err(Error::option_enum_skip_inner(meta.span()));
56 }
57
58 // Don't parse `Skip` yet, because it needs access to all
59 // `DeriveWhere`s.
60 skip_inners.push(meta);
61 } else if meta.path().is_ident(Incomparable::INCOMPARABLE) {
62 // Needs to be parsed after all traits are known.
63 incomparables.push(meta)
64 } else if meta.path().is_ident("crate") {
65 // Do nothing, we checked this before
66 // already.
67 }
68 // The list can have one item but still not be the `skip_inner`
69 // attribute, continue with parsing `DeriveWhere`.
70 else {
71 self_
72 .derive_wheres
73 .push(DeriveWhere::from_attr(span, data, attr)?);
74 }
75 }
76 _ => self_
77 .derive_wheres
78 .push(DeriveWhere::from_attr(span, data, attr)?),
79 }
80 }
81 // Anything list that isn't using `,` as separator, is because we expect
82 // `A, B; C`.
83 else {
84 self_
85 .derive_wheres
86 .push(DeriveWhere::from_attr(span, data, attr)?)
87 }
88 } else {
89 return Err(Error::option_syntax(attr.meta.span()));
90 }
91 }
92 }
93
94 // Check that we specified at least one `#[derive_where(..)]` with traits.
95 if self_.derive_wheres.is_empty() {
96 return Err(Error::none(span));
97 }
98
99 // Merge `DeriveWhere`s with the same bounds.
100 self_
101 .derive_wheres
102 .dedup_by(|derive_where_1, derive_where_2| {
103 if derive_where_1.generics == derive_where_2.generics {
104 derive_where_2.spans.append(&mut derive_where_1.spans);
105 derive_where_2.traits.append(&mut derive_where_1.traits);
106 true
107 } else {
108 false
109 }
110 });
111
112 // Check for duplicate traits in the same `derive_where` after merging with the
113 // same bounds.
114 for derive_where in &self_.derive_wheres {
115 for (skip, trait_) in (1..).zip(&derive_where.traits) {
116 if let Some((span, _)) = derive_where
117 .spans
118 .iter()
119 .zip(&derive_where.traits)
120 .skip(skip)
121 .find(|(_, other_trait)| *other_trait == trait_)
122 {
123 return Err(Error::trait_duplicate(*span));
124 }
125 }
126 }
127
128 // Delayed parsing of `skip_inner` and `incomparable` to get access to all
129 // traits to be implemented.
130 for meta in skip_inners {
131 self_
132 .skip_inner
133 .add_attribute(&self_.derive_wheres, None, &meta)?;
134 }
135
136 for meta in incomparables {
137 self_
138 .incomparable
139 .add_attribute(&meta, &self_.derive_wheres)?;
140 }
141
142 Ok(self_)
143 }
144}
145
146/// Holds parsed [generics](Generic) and [traits](crate::Trait).
147pub struct DeriveWhere {
148 /// [`Span`]s for each [trait](DeriveTrait).
149 pub spans: Vec<Span>,
150 /// [Traits](DeriveTrait) to implement.
151 pub traits: Vec<DeriveTrait>,
152 /// [Generics](Generic) for where clause.
153 pub generics: Vec<Generic>,
154}
155
156impl DeriveWhere {
157 /// Create [`DeriveWhere`] from [`Attribute`].
158 fn from_attr(span: Span, data: &Data, attr: &Attribute) -> Result<Self> {
159 attr.parse_args_with(|input: ParseStream| {
160 // Parse the attribute input, this should either be:
161 // - Comma separated traits.
162 // - Comma separated traits `;` Comma separated generics.
163
164 let mut spans = Vec::new();
165 let mut traits = Vec::new();
166 let mut generics = Vec::new();
167
168 // Check for an empty list is already done in `ItemAttr::from_attrs`.
169 assert!(!input.is_empty());
170
171 while !input.is_empty() {
172 // Start with parsing a trait.
173 // Not checking for duplicates here, we do that after merging `derive_where`s
174 // with the same bounds.
175 let (span, trait_) = DeriveTrait::from_stream(span, data, input)?;
176 spans.push(span);
177 traits.push(trait_);
178
179 if !input.is_empty() {
180 let mut fork = input.fork();
181
182 // Track `Span` of whatever was found instead of a delimiter. We parse the `,`
183 // first because it's allowed to be followed by a `;`.
184 let no_delimiter_found = match <Token![,]>::parse(&fork) {
185 Ok(_) => {
186 input.advance_to(&fork);
187 None
188 }
189 Err(error) => {
190 // Reset the fork if we didn't find a `,`.
191 fork = input.fork();
192 Some(error.span())
193 }
194 };
195
196 if <Token![;]>::parse(&fork).is_ok() {
197 input.advance_to(&fork);
198
199 // If we found a semi-colon, start parsing generics.
200 if !input.is_empty() {
201 // `parse_terminated` parses everything left, which should end the
202 // while-loop.
203 // Not checking for duplicates here, as even Rust doesn't give a warning
204 // for those: `where T: Clone, T: Clone` produces no error or warning.
205 generics = Punctuated::<Generic, Token![,]>::parse_terminated(input)?
206 .into_iter()
207 .collect();
208 }
209 }
210 // We are here because the input isn't empty, but we also found no delimiter,
211 // something unexpected is here instead.
212 else if let Some(span) = no_delimiter_found {
213 return Err(Error::derive_where_delimiter(span));
214 }
215 }
216 }
217
218 Ok(Self {
219 generics,
220 spans,
221 traits,
222 })
223 })
224 }
225
226 /// Returns `true` if [`Trait`] is present.
227 pub fn contains(&self, trait_: Trait) -> bool {
228 self.traits
229 .iter()
230 .any(|derive_trait| derive_trait == trait_)
231 }
232
233 /// Returns `true` if any [`CustomBound`](Generic::CustomBound) is present.
234 pub fn any_custom_bound(&self) -> bool {
235 self.generics.iter().any(|generic| match generic {
236 Generic::CustomBound(_) => true,
237 Generic::NoBound(_) => false,
238 })
239 }
240
241 /// Returns `true` if the given generic type parameter if present.
242 pub fn has_type_param(&self, type_param: &Ident) -> bool {
243 self.generics.iter().any(|generic| match generic {
244 Generic::NoBound(Type::Path(TypePath { qself: None, path })) => {
245 if let Some(ident) = path.get_ident() {
246 ident == type_param
247 } else {
248 false
249 }
250 }
251 _ => false,
252 })
253 }
254
255 /// Returns `true` if any [`Trait`] supports skipping.
256 pub fn any_skip(&self) -> bool {
257 self.traits
258 .iter()
259 .any(|trait_| SkipGroup::trait_supported(**trait_))
260 }
261
262 /// Create [`WhereClause`] for the given parameters.
263 pub fn where_clause(
264 &self,
265 where_clause: &mut Option<Cow<WhereClause>>,
266 trait_: &DeriveTrait,
267 item: &Item,
268 ) {
269 // Only create a where clause if required
270 if !self.generics.is_empty() {
271 // We use the existing where clause or create a new one if required.
272 let where_clause = where_clause.get_or_insert(Cow::Owned(WhereClause {
273 where_token: <Token![where]>::default(),
274 predicates: Punctuated::default(),
275 }));
276
277 // Insert bounds into the `where` clause.
278 for generic in &self.generics {
279 where_clause
280 .to_mut()
281 .predicates
282 .push(WherePredicate::Type(match generic {
283 Generic::CustomBound(type_bound) => type_bound.clone(),
284 Generic::NoBound(path) => PredicateType {
285 lifetimes: None,
286 bounded_ty: path.clone(),
287 colon_token: <Token![:]>::default(),
288 bounds: trait_.where_bounds(item),
289 },
290 }));
291 }
292 }
293 }
294}
295
296/// Holds a single generic [type](Type) or [type with bound](PredicateType).
297#[derive(Eq, PartialEq)]
298pub enum Generic {
299 /// Generic type with custom [specified bounds](PredicateType).
300 CustomBound(PredicateType),
301 /// Generic [type](Type) which will be bound to the [`DeriveTrait`].
302 NoBound(Type),
303}
304
305impl Parse for Generic {
306 fn parse(input: ParseStream) -> Result<Self> {
307 let fork: ParseBuffer<'_> = input.fork();
308
309 // Try to parse input as a `WherePredicate`. The problem is, both expressions
310 // start with a Type, so starting with the `WherePredicate` is the easiest way
311 // of differentiating them.
312 if let Ok(where_predicate: WherePredicate) = WherePredicate::parse(&fork) {
313 input.advance_to(&fork);
314
315 // Don't allow lifetimes, as it doesn't make sense in the context.
316 if let WherePredicate::Type(path: PredicateType) = where_predicate {
317 Ok(Generic::CustomBound(path))
318 } else {
319 Err(Error::generic(where_predicate.span()))
320 }
321 } else {
322 match Type::parse(input) {
323 Ok(type_: Type) => Ok(Generic::NoBound(type_)),
324 Err(error: Error) => Err(Error::generic_syntax(error.span(), error)),
325 }
326 }
327 }
328}
329
/// Trait to implement, together with any trait-specific options.
#[derive(Eq, PartialEq)]
pub enum DeriveTrait {
    /// [`Clone`].
    Clone,
    /// [`Copy`].
    Copy,
    /// [`Debug`](std::fmt::Debug).
    Debug,
    /// [`Default`].
    Default,
    /// [`Eq`].
    Eq,
    /// [`Hash`](std::hash::Hash).
    Hash,
    /// [`Ord`].
    Ord,
    /// [`PartialEq`].
    PartialEq,
    /// [`PartialOrd`].
    PartialOrd,
    /// [`Zeroize`](https://docs.rs/zeroize/latest/zeroize/trait.Zeroize.html).
    #[cfg(feature = "zeroize")]
    Zeroize {
        /// Custom path to the root of the `zeroize` crate; the trait name is
        /// appended to it. `None` falls back to the default `zeroize` crate path.
        crate_: Option<Path>,
    },
    /// [`ZeroizeOnDrop`](https://docs.rs/zeroize/latest/zeroize/trait.ZeroizeOnDrop.html).
    #[cfg(feature = "zeroize")]
    ZeroizeOnDrop {
        /// Custom path to the root of the `zeroize` crate; the trait name is
        /// appended to it. `None` falls back to the default `zeroize` crate path.
        crate_: Option<Path>,
    },
}
364
365impl Deref for DeriveTrait {
366 type Target = Trait;
367
368 fn deref(&self) -> &Self::Target {
369 use DeriveTrait::*;
370
371 match self {
372 Clone => &Trait::Clone,
373 Copy => &Trait::Copy,
374 Debug => &Trait::Debug,
375 Default => &Trait::Default,
376 Eq => &Trait::Eq,
377 Hash => &Trait::Hash,
378 Ord => &Trait::Ord,
379 PartialEq => &Trait::PartialEq,
380 PartialOrd => &Trait::PartialOrd,
381 #[cfg(feature = "zeroize")]
382 Zeroize { .. } => &Trait::Zeroize,
383 #[cfg(feature = "zeroize")]
384 ZeroizeOnDrop { .. } => &Trait::ZeroizeOnDrop,
385 }
386 }
387}
388
389impl PartialEq<Trait> for &DeriveTrait {
390 fn eq(&self, other: &Trait) -> bool {
391 let trait_: &Trait = self;
392 trait_ == other
393 }
394}
395
396impl DeriveTrait {
397 /// Returns fully qualified [`Path`] for this trait.
398 pub fn path(&self) -> Path {
399 use DeriveTrait::*;
400
401 match self {
402 Clone => util::path_from_root_and_strs(self.crate_(), &["clone", "Clone"]),
403 Copy => util::path_from_root_and_strs(self.crate_(), &["marker", "Copy"]),
404 Debug => util::path_from_root_and_strs(self.crate_(), &["fmt", "Debug"]),
405 Default => util::path_from_root_and_strs(self.crate_(), &["default", "Default"]),
406 Eq => util::path_from_root_and_strs(self.crate_(), &["cmp", "Eq"]),
407 Hash => util::path_from_root_and_strs(self.crate_(), &["hash", "Hash"]),
408 Ord => util::path_from_root_and_strs(self.crate_(), &["cmp", "Ord"]),
409 PartialEq => util::path_from_root_and_strs(self.crate_(), &["cmp", "PartialEq"]),
410 PartialOrd => util::path_from_root_and_strs(self.crate_(), &["cmp", "PartialOrd"]),
411 #[cfg(feature = "zeroize")]
412 Zeroize { .. } => util::path_from_root_and_strs(self.crate_(), &["Zeroize"]),
413 #[cfg(feature = "zeroize")]
414 ZeroizeOnDrop { .. } => util::path_from_root_and_strs(self.crate_(), &["ZeroizeOnDrop"]),
415 }
416 }
417
418 /// Returns the path to the root crate for this trait.
419 pub fn crate_(&self) -> Path {
420 use DeriveTrait::*;
421
422 match self {
423 Clone => util::path_from_strs(&["core"]),
424 Copy => util::path_from_strs(&["core"]),
425 Debug => util::path_from_strs(&["core"]),
426 Default => util::path_from_strs(&["core"]),
427 Eq => util::path_from_strs(&["core"]),
428 Hash => util::path_from_strs(&["core"]),
429 Ord => util::path_from_strs(&["core"]),
430 PartialEq => util::path_from_strs(&["core"]),
431 PartialOrd => util::path_from_strs(&["core"]),
432 #[cfg(feature = "zeroize")]
433 Zeroize { crate_, .. } => {
434 if let Some(crate_) = crate_ {
435 crate_.clone()
436 } else {
437 util::path_from_strs(&["zeroize"])
438 }
439 }
440 #[cfg(feature = "zeroize")]
441 ZeroizeOnDrop { crate_, .. } => {
442 if let Some(crate_) = crate_ {
443 crate_.clone()
444 } else {
445 util::path_from_strs(&["zeroize"])
446 }
447 }
448 }
449 }
450
451 /// Returns where-clause bounds for the trait in respect of the item type.
452 fn where_bounds(&self, data: &Item) -> Punctuated<TypeParamBound, Token![+]> {
453 let mut list = Punctuated::new();
454
455 list.push(TypeParamBound::Trait(TraitBound {
456 paren_token: None,
457 modifier: TraitBoundModifier::None,
458 lifetimes: None,
459 path: self.path(),
460 }));
461
462 // Add bounds specific to the trait.
463 if let Some(bound) = self.additional_where_bounds(data) {
464 list.push(bound)
465 }
466
467 list
468 }
469
470 /// Create [`DeriveTrait`] from [`ParseStream`].
471 fn from_stream(span: Span, data: &Data, input: ParseStream) -> Result<(Span, Self)> {
472 match Meta::parse(input) {
473 Ok(meta) => {
474 let trait_ = Trait::from_path(meta.path())?;
475
476 if let Data::Union(_) = data {
477 // Make sure this `Trait` supports unions.
478 if !trait_.supports_union() {
479 return Err(Error::union(span));
480 }
481 }
482
483 match &meta {
484 Meta::Path(path) => Ok((path.span(), trait_.default_derive_trait())),
485 Meta::List(list) => {
486 let nested = list.parse_non_empty_nested_metas()?;
487
488 // This will return an error if no options are supported.
489 Ok((list.span(), trait_.parse_derive_trait(meta.span(), nested)?))
490 }
491 Meta::NameValue(name_value) => Err(Error::option_syntax(name_value.span())),
492 }
493 }
494 Err(error) => Err(Error::trait_syntax(error.span())),
495 }
496 }
497}
498