1//! [`PartialEq`](trait@std::cmp::PartialEq) implementation.
2
3use proc_macro2::TokenStream;
4use quote::quote;
5
6use super::common_ord::build_incomparable_pattern;
7use crate::{Data, DeriveTrait, Item, SimpleType, SplitGenerics, TraitImpl};
8
/// Zero-sized marker type that implements [`TraitImpl`](crate::TraitImpl) for
/// [`PartialEq`](trait@std::cmp::PartialEq).
pub struct PartialEq;
12
13impl TraitImpl for PartialEq {
14 fn as_str(&self) -> &'static str {
15 "PartialEq"
16 }
17
18 fn default_derive_trait(&self) -> DeriveTrait {
19 DeriveTrait::PartialEq
20 }
21
22 fn build_signature(
23 &self,
24 _any_bound: bool,
25 item: &Item,
26 _generics: &SplitGenerics<'_>,
27 _traits: &[DeriveTrait],
28 trait_: &DeriveTrait,
29 body: &TokenStream,
30 ) -> TokenStream {
31 let body = {
32 match item {
33 // If the whole item is incomparable return false
34 item if item.is_incomparable() => {
35 quote! { false }
36 }
37 // If there is more than one variant and not all variants are empty, check for
38 // discriminant and match on variant data.
39 Item::Enum { variants, .. } if variants.len() > 1 && !item.is_empty(**trait_) => {
40 // Return `true` in the rest pattern if there are any empty variants
41 // that are not incomparable.
42 let rest = if variants
43 .iter()
44 .any(|variant| variant.is_empty(**trait_) && !variant.is_incomparable())
45 {
46 quote! { true }
47 } else {
48 #[cfg(not(feature = "safe"))]
49 // This follows the standard implementation.
50 quote! { unsafe { ::core::hint::unreachable_unchecked() } }
51 #[cfg(feature = "safe")]
52 quote! { ::core::unreachable!("comparing variants yielded unexpected results") }
53 };
54
55 // Return `false` for all incomparable variants
56 let incomparable = build_incomparable_pattern(variants).into_iter();
57
58 quote! {
59 if ::core::mem::discriminant(self) == ::core::mem::discriminant(__other) {
60 match (self, __other) {
61 #body
62 #((#incomparable, ..) => false,)*
63 _ => #rest,
64 }
65 } else {
66 false
67 }
68 }
69 }
70 // If there is more than one variant and all are empty, check for
71 // discriminant and simply return `true` if it is not incomparable.
72 Item::Enum { variants, .. } if variants.len() > 1 && item.is_empty(**trait_) => {
73 let incomparable = build_incomparable_pattern(variants).into_iter();
74 quote! {
75 if ::core::mem::discriminant(self) == ::core::mem::discriminant(__other) {
76 #(if ::core::matches!(self, #incomparable) {
77 return false;
78 })*
79 true
80 } else {
81 false
82 }
83 }
84 }
85 // If there is only one variant and it's empty or if the struct is empty, simply
86 // return `true`.
87 item if item.is_empty(**trait_) => {
88 quote! { true }
89 }
90 _ => {
91 quote! {
92 match (self, __other) {
93 #body
94 }
95 }
96 }
97 }
98 };
99
100 quote! {
101 #[inline]
102 fn eq(&self, __other: &Self) -> bool {
103 #body
104 }
105 }
106 }
107
108 fn build_body(
109 &self,
110 _any_bound: bool,
111 _traits: &[DeriveTrait],
112 trait_: &DeriveTrait,
113 data: &Data,
114 ) -> TokenStream {
115 if data.is_empty(**trait_) || data.is_incomparable() {
116 TokenStream::new()
117 } else {
118 match data.simple_type() {
119 SimpleType::Struct(fields) | SimpleType::Tuple(fields) => {
120 let self_pattern = &fields.self_pattern;
121 let other_pattern = &fields.other_pattern;
122 let trait_path = trait_.path();
123 let self_ident = data.iter_self_ident(**trait_);
124 let other_ident = data.iter_other_ident(**trait_);
125
126 quote! {
127 (#self_pattern, #other_pattern) =>
128 true #(&& #trait_path::eq(#self_ident, #other_ident))*,
129 }
130 }
131 SimpleType::Unit(_) => TokenStream::new(),
132 SimpleType::Union(_) => unreachable!("unexpected trait for union"),
133 }
134 }
135 }
136}
137