1// https://github.com/rust-lang/rust/issues/13101
2
3use ast;
4use attr;
5use matcher;
6use paths;
7use proc_macro2;
8use syn;
9use utils;
10
11/// Derive `Eq` for `input`.
12pub fn derive_eq(input: &ast::Input) -> proc_macro2::TokenStream {
13 let name = &input.ident;
14
15 let eq_trait_path = eq_trait_path();
16 let generics = utils::build_impl_generics(
17 input,
18 &eq_trait_path,
19 needs_eq_bound,
20 |field| field.eq_bound(),
21 |input| input.eq_bound(),
22 );
23 let new_where_clause;
24 let (impl_generics, ty_generics, mut where_clause) = generics.split_for_impl();
25
26 if let Some(new_where_clause2) =
27 maybe_add_copy(input, where_clause, |f| !f.attrs.ignore_partial_eq())
28 {
29 new_where_clause = new_where_clause2;
30 where_clause = Some(&new_where_clause);
31 }
32
33 quote! {
34 #[allow(unused_qualifications)]
35 impl #impl_generics #eq_trait_path for #name #ty_generics #where_clause {}
36 }
37}
38
39/// Derive `PartialEq` for `input`.
40pub fn derive_partial_eq(input: &ast::Input) -> proc_macro2::TokenStream {
41 let discriminant_cmp = if let ast::Body::Enum(_) = input.body {
42 let discriminant_path = paths::discriminant_path();
43
44 quote!((#discriminant_path(&*self) == #discriminant_path(&*other)))
45 } else {
46 quote!(true)
47 };
48
49 let name = &input.ident;
50
51 let partial_eq_trait_path = partial_eq_trait_path();
52 let generics = utils::build_impl_generics(
53 input,
54 &partial_eq_trait_path,
55 needs_partial_eq_bound,
56 |field| field.partial_eq_bound(),
57 |input| input.partial_eq_bound(),
58 );
59 let new_where_clause;
60 let (impl_generics, ty_generics, mut where_clause) = generics.split_for_impl();
61
62 let match_fields = if input.is_trivial_enum() {
63 quote!(true)
64 } else {
65 matcher::Matcher::new(matcher::BindingStyle::Ref, input.attrs.is_packed)
66 .with_field_filter(|f: &ast::Field| !f.attrs.ignore_partial_eq())
67 .build_2_arms(
68 (quote!(*self), quote!(*other)),
69 (input, "__self"),
70 (input, "__other"),
71 |_, _, _, (left_variant, right_variant)| {
72 let cmp = left_variant.iter().zip(&right_variant).map(|(o, i)| {
73 let outer_name = &o.expr;
74 let inner_name = &i.expr;
75
76 if o.field.attrs.ignore_partial_eq() {
77 None
78 } else if let Some(compare_fn) = o.field.attrs.partial_eq_compare_with() {
79 Some(quote!(&& #compare_fn(&#outer_name, &#inner_name)))
80 } else {
81 Some(quote!(&& &#outer_name == &#inner_name))
82 }
83 });
84
85 quote!(true #(#cmp)*)
86 },
87 )
88 };
89
90 if let Some(new_where_clause2) =
91 maybe_add_copy(input, where_clause, |f| !f.attrs.ignore_partial_eq())
92 {
93 new_where_clause = new_where_clause2;
94 where_clause = Some(&new_where_clause);
95 }
96
97 quote! {
98 #[allow(unused_qualifications)]
99 #[allow(clippy::unneeded_field_pattern)]
100 impl #impl_generics #partial_eq_trait_path for #name #ty_generics #where_clause {
101 fn eq(&self, other: &Self) -> bool {
102 #discriminant_cmp && #match_fields
103 }
104 }
105 }
106}
107
108/// Derive `PartialOrd` for `input`.
109pub fn derive_partial_ord(
110 input: &ast::Input,
111 errors: &mut proc_macro2::TokenStream,
112) -> proc_macro2::TokenStream {
113 if let ast::Body::Enum(_) = input.body {
114 if !input.attrs.partial_ord_on_enum() {
115 let message = "can't use `#[derivative(PartialOrd)]` on an enumeration without \
116 `feature_allow_slow_enum`; see the documentation for more details";
117 errors.extend(syn::Error::new(input.span, message).to_compile_error());
118 }
119 }
120
121 let option_path = option_path();
122 let ordering_path = ordering_path();
123
124 let body = matcher::Matcher::new(matcher::BindingStyle::Ref, input.attrs.is_packed)
125 .with_field_filter(|f: &ast::Field| !f.attrs.ignore_partial_ord())
126 .build_arms(input, "__self", |_, n, _, _, _, outer_bis| {
127 let body = matcher::Matcher::new(matcher::BindingStyle::Ref, input.attrs.is_packed)
128 .with_field_filter(|f: &ast::Field| !f.attrs.ignore_partial_ord())
129 .build_arms(input, "__other", |_, m, _, _, _, inner_bis| {
130 match n.cmp(&m) {
131 ::std::cmp::Ordering::Less => {
132 quote!(#option_path::Some(#ordering_path::Less))
133 }
134 ::std::cmp::Ordering::Greater => {
135 quote!(#option_path::Some(#ordering_path::Greater))
136 }
137 ::std::cmp::Ordering::Equal => {
138 let equal_path = quote!(#ordering_path::Equal);
139 outer_bis
140 .iter()
141 .rev()
142 .zip(inner_bis.into_iter().rev())
143 .fold(quote!(#option_path::Some(#equal_path)), |acc, (o, i)| {
144 let outer_name = &o.expr;
145 let inner_name = &i.expr;
146
147 if o.field.attrs.ignore_partial_ord() {
148 acc
149 } else {
150 let cmp_fn = o
151 .field
152 .attrs
153 .partial_ord_compare_with()
154 .map(|f| quote!(#f))
155 .unwrap_or_else(|| {
156 let path = partial_ord_trait_path();
157 quote!(#path::partial_cmp)
158 });
159
160 quote!(match #cmp_fn(&#outer_name, &#inner_name) {
161 #option_path::Some(#equal_path) => #acc,
162 __derive_ordering_other => __derive_ordering_other,
163 })
164 }
165 })
166 }
167 }
168 });
169
170 quote! {
171 match *other {
172 #body
173 }
174
175 }
176 });
177
178 let name = &input.ident;
179
180 let partial_ord_trait_path = partial_ord_trait_path();
181 let generics = utils::build_impl_generics(
182 input,
183 &partial_ord_trait_path,
184 needs_partial_ord_bound,
185 |field| field.partial_ord_bound(),
186 |input| input.partial_ord_bound(),
187 );
188 let new_where_clause;
189 let (impl_generics, ty_generics, mut where_clause) = generics.split_for_impl();
190
191 if let Some(new_where_clause2) =
192 maybe_add_copy(input, where_clause, |f| !f.attrs.ignore_partial_ord())
193 {
194 new_where_clause = new_where_clause2;
195 where_clause = Some(&new_where_clause);
196 }
197
198 quote! {
199 #[allow(unused_qualifications)]
200 #[allow(clippy::unneeded_field_pattern)]
201 impl #impl_generics #partial_ord_trait_path for #name #ty_generics #where_clause {
202 fn partial_cmp(&self, other: &Self) -> #option_path<#ordering_path> {
203 match *self {
204 #body
205 }
206 }
207 }
208 }
209}
210
211/// Derive `Ord` for `input`.
212pub fn derive_ord(
213 input: &ast::Input,
214 errors: &mut proc_macro2::TokenStream,
215) -> proc_macro2::TokenStream {
216 if let ast::Body::Enum(_) = input.body {
217 if !input.attrs.ord_on_enum() {
218 let message = "can't use `#[derivative(Ord)]` on an enumeration without \
219 `feature_allow_slow_enum`; see the documentation for more details";
220 errors.extend(syn::Error::new(input.span, message).to_compile_error());
221 }
222 }
223
224 let ordering_path = ordering_path();
225
226 let body = matcher::Matcher::new(matcher::BindingStyle::Ref, input.attrs.is_packed)
227 .with_field_filter(|f: &ast::Field| !f.attrs.ignore_ord())
228 .build_arms(input, "__self", |_, n, _, _, _, outer_bis| {
229 let body = matcher::Matcher::new(matcher::BindingStyle::Ref, input.attrs.is_packed)
230 .with_field_filter(|f: &ast::Field| !f.attrs.ignore_ord())
231 .build_arms(input, "__other", |_, m, _, _, _, inner_bis| {
232 match n.cmp(&m) {
233 ::std::cmp::Ordering::Less => quote!(#ordering_path::Less),
234 ::std::cmp::Ordering::Greater => quote!(#ordering_path::Greater),
235 ::std::cmp::Ordering::Equal => {
236 let equal_path = quote!(#ordering_path::Equal);
237 outer_bis
238 .iter()
239 .rev()
240 .zip(inner_bis.into_iter().rev())
241 .fold(quote!(#equal_path), |acc, (o, i)| {
242 let outer_name = &o.expr;
243 let inner_name = &i.expr;
244
245 if o.field.attrs.ignore_ord() {
246 acc
247 } else {
248 let cmp_fn = o
249 .field
250 .attrs
251 .ord_compare_with()
252 .map(|f| quote!(#f))
253 .unwrap_or_else(|| {
254 let path = ord_trait_path();
255 quote!(#path::cmp)
256 });
257
258 quote!(match #cmp_fn(&#outer_name, &#inner_name) {
259 #equal_path => #acc,
260 __derive_ordering_other => __derive_ordering_other,
261 })
262 }
263 })
264 }
265 }
266 });
267
268 quote! {
269 match *other {
270 #body
271 }
272
273 }
274 });
275
276 let name = &input.ident;
277
278 let ord_trait_path = ord_trait_path();
279 let generics = utils::build_impl_generics(
280 input,
281 &ord_trait_path,
282 needs_ord_bound,
283 |field| field.ord_bound(),
284 |input| input.ord_bound(),
285 );
286 let new_where_clause;
287 let (impl_generics, ty_generics, mut where_clause) = generics.split_for_impl();
288
289 if let Some(new_where_clause2) = maybe_add_copy(input, where_clause, |f| !f.attrs.ignore_ord())
290 {
291 new_where_clause = new_where_clause2;
292 where_clause = Some(&new_where_clause);
293 }
294
295 quote! {
296 #[allow(unused_qualifications)]
297 #[allow(clippy::unneeded_field_pattern)]
298 impl #impl_generics #ord_trait_path for #name #ty_generics #where_clause {
299 fn cmp(&self, other: &Self) -> #ordering_path {
300 match *self {
301 #body
302 }
303 }
304 }
305 }
306}
307
308fn needs_partial_eq_bound(attrs: &attr::Field) -> bool {
309 !attrs.ignore_partial_eq() && attrs.partial_eq_bound().is_none()
310}
311
312fn needs_partial_ord_bound(attrs: &attr::Field) -> bool {
313 !attrs.ignore_partial_ord() && attrs.partial_ord_bound().is_none()
314}
315
316fn needs_ord_bound(attrs: &attr::Field) -> bool {
317 !attrs.ignore_ord() && attrs.ord_bound().is_none()
318}
319
320fn needs_eq_bound(attrs: &attr::Field) -> bool {
321 !attrs.ignore_partial_eq() && attrs.eq_bound().is_none()
322}
323
324/// Return the path of the `Eq` trait, that is `::std::cmp::Eq`.
325fn eq_trait_path() -> syn::Path {
326 if cfg!(feature = "use_core") {
327 parse_quote!(::core::cmp::Eq)
328 } else {
329 parse_quote!(::std::cmp::Eq)
330 }
331}
332
333/// Return the path of the `PartialEq` trait, that is `::std::cmp::PartialEq`.
334fn partial_eq_trait_path() -> syn::Path {
335 if cfg!(feature = "use_core") {
336 parse_quote!(::core::cmp::PartialEq)
337 } else {
338 parse_quote!(::std::cmp::PartialEq)
339 }
340}
341
342/// Return the path of the `PartialOrd` trait, that is `::std::cmp::PartialOrd`.
343fn partial_ord_trait_path() -> syn::Path {
344 if cfg!(feature = "use_core") {
345 parse_quote!(::core::cmp::PartialOrd)
346 } else {
347 parse_quote!(::std::cmp::PartialOrd)
348 }
349}
350
351/// Return the path of the `Ord` trait, that is `::std::cmp::Ord`.
352fn ord_trait_path() -> syn::Path {
353 if cfg!(feature = "use_core") {
354 parse_quote!(::core::cmp::Ord)
355 } else {
356 parse_quote!(::std::cmp::Ord)
357 }
358}
359
360/// Return the path of the `Option` trait, that is `::std::option::Option`.
361fn option_path() -> syn::Path {
362 if cfg!(feature = "use_core") {
363 parse_quote!(::core::option::Option)
364 } else {
365 parse_quote!(::std::option::Option)
366 }
367}
368
369/// Return the path of the `Ordering` trait, that is `::std::cmp::Ordering`.
370fn ordering_path() -> syn::Path {
371 if cfg!(feature = "use_core") {
372 parse_quote!(::core::cmp::Ordering)
373 } else {
374 parse_quote!(::std::cmp::Ordering)
375 }
376}
377
378fn maybe_add_copy(
379 input: &ast::Input,
380 where_clause: Option<&syn::WhereClause>,
381 field_filter: impl Fn(&ast::Field) -> bool,
382) -> Option<syn::WhereClause> {
383 if input.attrs.is_packed && !input.body.is_empty() {
384 let mut new_where_clause = where_clause.cloned().unwrap_or_else(|| syn::WhereClause {
385 where_token: parse_quote!(where),
386 predicates: Default::default(),
387 });
388
389 new_where_clause.predicates.extend(
390 input
391 .body
392 .all_fields()
393 .into_iter()
394 .filter(|f| field_filter(f))
395 .map(|f| {
396 let ty = f.ty;
397
398 let pred: syn::WherePredicate = parse_quote!(#ty: Copy);
399 pred
400 }),
401 );
402
403 Some(new_where_clause)
404 } else {
405 None
406 }
407}
408