1 | //! Common implementation help for [`PartialOrd`] and [`Ord`]. |
2 | |
3 | #[cfg (not(feature = "nightly" ))] |
4 | use std::{borrow::Cow, ops::Deref}; |
5 | |
6 | use proc_macro2::TokenStream; |
7 | #[cfg (not(feature = "nightly" ))] |
8 | use proc_macro2::{Literal, Span}; |
9 | #[cfg (not(feature = "nightly" ))] |
10 | use quote::format_ident; |
11 | use quote::quote; |
12 | #[cfg (not(feature = "nightly" ))] |
13 | use syn::{parse_quote, Expr, ExprLit, LitInt, Path}; |
14 | |
15 | #[cfg (not(feature = "nightly" ))] |
16 | use crate::{item::Representation, Discriminant, Trait}; |
17 | use crate::{Data, DeriveTrait, Item, SimpleType, SplitGenerics}; |
18 | |
19 | /// Build signature for [`PartialOrd`] and [`Ord`]. |
20 | pub fn build_ord_signature( |
21 | item: &Item, |
22 | #[cfg_attr (feature = "nightly" , allow(unused_variables))] generics: &SplitGenerics<'_>, |
23 | #[cfg_attr (feature = "nightly" , allow(unused_variables))] traits: &[DeriveTrait], |
24 | trait_: &DeriveTrait, |
25 | body: &TokenStream, |
26 | ) -> TokenStream { |
27 | let mut equal = quote! { ::core::cmp::Ordering::Equal }; |
28 | |
29 | // Add `Option` to `Ordering` if we are implementing `PartialOrd`. |
30 | if let DeriveTrait::PartialOrd = trait_ { |
31 | equal = quote! { ::core::option::Option::Some(#equal) }; |
32 | } |
33 | |
34 | match item { |
35 | // If the item is incomparable return `None` |
36 | item if item.is_incomparable() => { |
37 | quote! { ::core::option::Option::None } |
38 | } |
39 | // If there is more than one variant, check for the discriminant. |
40 | Item::Enum { |
41 | #[cfg (not(feature = "nightly" ))] |
42 | discriminant, |
43 | variants, |
44 | .. |
45 | } if variants.len() > 1 => { |
46 | // In case the discriminant matches: |
47 | // If all variants are empty, return `Equal`. |
48 | let body_equal = if item.is_empty(**trait_) { |
49 | None |
50 | } |
51 | // Compare variant data and return `Equal` in the rest pattern if there are any empty |
52 | // variants that are comparable. |
53 | else if variants |
54 | .iter() |
55 | .any(|variant| variant.is_empty(**trait_) && !variant.is_incomparable()) |
56 | { |
57 | Some(quote! { |
58 | match (self, __other) { |
59 | #body |
60 | _ => #equal, |
61 | } |
62 | }) |
63 | } |
64 | // Insert `unreachable!` in the rest pattern if no variants are empty. |
65 | else { |
66 | #[cfg (not(feature = "safe" ))] |
67 | // This follows the standard implementation. |
68 | let rest = quote! { unsafe { ::core::hint::unreachable_unchecked() } }; |
69 | #[cfg (feature = "safe" )] |
70 | let rest = quote! { ::core::unreachable!("comparing variants yielded unexpected results" ) }; |
71 | |
72 | Some(quote! { |
73 | match (self, __other) { |
74 | #body |
75 | _ => #rest, |
76 | } |
77 | }) |
78 | }; |
79 | |
80 | let incomparable = build_incomparable_pattern(variants); |
81 | |
82 | // If there is only one comparable variant, it has to be it when it is non |
83 | // incomparable. |
84 | let mut comparable = variants.iter().filter(|v| !v.is_incomparable()); |
85 | // Takes the first value from the iterator, but only when there is only one |
86 | // (second yields none). |
87 | if let (Some(comparable), None) = (comparable.next(), comparable.next()) { |
88 | let incomparable = incomparable.expect("there should be > 1 variants" ); |
89 | // Either compare the single variant or return `Equal` when it is empty |
90 | let equal = if comparable.is_empty(**trait_) { |
91 | equal |
92 | } else { |
93 | body_equal.unwrap_or(equal) |
94 | }; |
95 | quote! { |
96 | if ::core::matches!(self, #incomparable) || ::core::matches!(__other, #incomparable) { |
97 | ::core::option::Option::None |
98 | } else { |
99 | #equal |
100 | } |
101 | } |
102 | } else { |
103 | let incomparable = incomparable.into_iter(); |
104 | let incomparable = quote! { |
105 | #(if ::core::matches!(self, #incomparable) || ::core::matches!(__other, #incomparable) { |
106 | return ::core::option::Option::None; |
107 | })* |
108 | }; |
109 | |
110 | let path = trait_.path(); |
111 | let method = match trait_ { |
112 | DeriveTrait::PartialOrd => quote! { partial_cmp }, |
113 | DeriveTrait::Ord => quote! { cmp }, |
114 | _ => unreachable!("unsupported trait in `prepare_ord`" ), |
115 | }; |
116 | |
117 | // Nightly implementation. |
118 | #[cfg (feature = "nightly" )] |
119 | if let Some(body_equal) = body_equal { |
120 | quote! { |
121 | #incomparable |
122 | |
123 | let __self_disc = ::core::intrinsics::discriminant_value(self); |
124 | let __other_disc = ::core::intrinsics::discriminant_value(__other); |
125 | |
126 | if __self_disc == __other_disc { |
127 | #body_equal |
128 | } else { |
129 | #path::#method(&__self_disc, &__other_disc) |
130 | } |
131 | } |
132 | } else { |
133 | quote! { |
134 | #incomparable |
135 | |
136 | #path::#method( |
137 | &::core::intrinsics::discriminant_value(self), |
138 | &::core::intrinsics::discriminant_value(__other), |
139 | ) |
140 | } |
141 | } |
142 | |
143 | #[cfg (not(feature = "nightly" ))] |
144 | { |
145 | let body_else = match discriminant { |
146 | Discriminant::Single => { |
147 | unreachable!("we should only generate this code with multiple variants" ) |
148 | } |
149 | Discriminant::Unit => { |
150 | let mut discriminants = None; |
151 | |
152 | // Validation is only needed if custom discriminants are defined. |
153 | let validate = (variants |
154 | .iter() |
155 | .any(|variant| variant.discriminant.is_some())) |
156 | .then(|| { |
157 | let discriminants = |
158 | discriminants.insert(build_discriminants(variants)); |
159 | let discriminants = discriminants.iter().zip(variants).map( |
160 | |(discriminant, variant)| { |
161 | let name = |
162 | format_ident!("__VALIDATE_ISIZE_ {}" , variant.ident); |
163 | let discriminant = discriminant.deref(); |
164 | |
165 | quote! { |
166 | const #name: isize = #discriminant; |
167 | } |
168 | }, |
169 | ); |
170 | |
171 | quote! { |
172 | #(#discriminants)* |
173 | } |
174 | }); |
175 | |
176 | if traits.iter().any(|trait_| trait_ == Trait::Copy) { |
177 | quote! { |
178 | #validate |
179 | |
180 | #path::#method(&(*self as isize), &(*__other as isize)) |
181 | } |
182 | } else if traits.iter().any(|trait_| trait_ == Trait::Clone) { |
183 | let clone = DeriveTrait::Clone.path(); |
184 | quote! { |
185 | #validate |
186 | |
187 | #path::#method(&(#clone::clone(self) as isize), &(#clone::clone(__other) as isize)) |
188 | } |
189 | } else { |
190 | let discriminants = discriminants |
191 | .get_or_insert_with(|| build_discriminants(variants)); |
192 | build_discriminant_comparison( |
193 | None, |
194 | validate, |
195 | item, |
196 | generics, |
197 | variants, |
198 | discriminants, |
199 | &path, |
200 | &method, |
201 | ) |
202 | } |
203 | } |
204 | Discriminant::Data => { |
205 | let discriminants = build_discriminants(variants); |
206 | build_discriminant_comparison( |
207 | None, |
208 | None, |
209 | item, |
210 | generics, |
211 | variants, |
212 | &discriminants, |
213 | &path, |
214 | &method, |
215 | ) |
216 | } |
217 | Discriminant::UnitRepr(repr) => { |
218 | if traits.iter().any(|trait_| trait_ == Trait::Copy) { |
219 | quote! { |
220 | #path::#method(&(*self as #repr), &(*__other as #repr)) |
221 | } |
222 | } else if traits.iter().any(|trait_| trait_ == Trait::Clone) { |
223 | let clone = DeriveTrait::Clone.path(); |
224 | quote! { |
225 | #path::#method(&(#clone::clone(self) as #repr), &(#clone::clone(__other) as #repr)) |
226 | } |
227 | } else { |
228 | #[cfg (feature = "safe" )] |
229 | let body_else = { |
230 | let discriminants = build_discriminants(variants); |
231 | build_discriminant_comparison( |
232 | Some(*repr), |
233 | None, |
234 | item, |
235 | generics, |
236 | variants, |
237 | &discriminants, |
238 | &path, |
239 | &method, |
240 | ) |
241 | }; |
242 | #[cfg (not(feature = "safe" ))] |
243 | let body_else = quote! { |
244 | #path::#method( |
245 | &unsafe { *<*const _>::from(self).cast::<#repr>() }, |
246 | &unsafe { *<*const _>::from(__other).cast::<#repr>() }, |
247 | ) |
248 | }; |
249 | |
250 | body_else |
251 | } |
252 | } |
253 | #[cfg (not(feature = "safe" ))] |
254 | Discriminant::DataRepr(repr) => { |
255 | quote! { |
256 | #path::#method( |
257 | &unsafe { *<*const _>::from(self).cast::<#repr>() }, |
258 | &unsafe { *<*const _>::from(__other).cast::<#repr>() }, |
259 | ) |
260 | } |
261 | } |
262 | #[cfg (feature = "safe" )] |
263 | Discriminant::DataRepr(repr) => { |
264 | let discriminants = build_discriminants(variants); |
265 | build_discriminant_comparison( |
266 | Some(*repr), |
267 | None, |
268 | item, |
269 | generics, |
270 | variants, |
271 | &discriminants, |
272 | &path, |
273 | &method, |
274 | ) |
275 | } |
276 | }; |
277 | |
278 | if let Some(body_equal) = body_equal { |
279 | quote! { |
280 | #incomparable |
281 | |
282 | let __self_disc = ::core::mem::discriminant(self); |
283 | let __other_disc = ::core::mem::discriminant(__other); |
284 | |
285 | if __self_disc == __other_disc { |
286 | #body_equal |
287 | } else { |
288 | #body_else |
289 | } |
290 | } |
291 | } else { |
292 | quote! { |
293 | #incomparable |
294 | |
295 | #body_else |
296 | } |
297 | } |
298 | } |
299 | } |
300 | } |
301 | // If there is only one variant and it's empty or if the struct is empty, simply |
302 | // return `Equal`. |
303 | item if item.is_empty(**trait_) => { |
304 | quote! { #equal } |
305 | } |
306 | _ => { |
307 | quote! { |
308 | match (self, __other) { |
309 | #body |
310 | } |
311 | } |
312 | } |
313 | } |
314 | } |
315 | |
316 | /// Builds list of discriminant values for all variants. |
317 | #[cfg (not(feature = "nightly" ))] |
318 | fn build_discriminants<'a>(variants: &'a [Data<'_>]) -> Vec<Cow<'a, Expr>> { |
319 | let mut discriminants = Vec::<Cow<Expr>>::with_capacity(variants.len()); |
320 | let mut last_expression: Option<(Option<usize>, usize)> = None; |
321 | |
322 | for variant in variants { |
323 | let discriminant = if let Some(discriminant) = variant.discriminant { |
324 | last_expression = Some((Some(discriminants.len()), 0)); |
325 | Cow::Borrowed(discriminant) |
326 | } else { |
327 | let discriminant = match &mut last_expression { |
328 | Some((Some(expr_index), counter)) => { |
329 | let expr = &discriminants[*expr_index]; |
330 | *counter += 1; |
331 | let counter = Literal::usize_unsuffixed(*counter); |
332 | parse_quote! { (#expr) + #counter } |
333 | } |
334 | Some((None, counter)) => { |
335 | *counter += 1; |
336 | |
337 | ExprLit { |
338 | attrs: Vec::new(), |
339 | lit: LitInt::new(&counter.to_string(), Span::call_site()).into(), |
340 | } |
341 | .into() |
342 | } |
343 | None => { |
344 | last_expression = Some((None, 0)); |
345 | ExprLit { |
346 | attrs: Vec::new(), |
347 | lit: LitInt::new("0" , Span::call_site()).into(), |
348 | } |
349 | .into() |
350 | } |
351 | }; |
352 | |
353 | Cow::Owned(discriminant) |
354 | }; |
355 | |
356 | discriminants.push(discriminant); |
357 | } |
358 | |
359 | discriminants |
360 | } |
361 | |
362 | /// Uses list of discriminant values to compare variants. |
363 | #[cfg (not(feature = "nightly" ))] |
364 | #[allow (clippy::too_many_arguments)] |
365 | fn build_discriminant_comparison( |
366 | repr: Option<Representation>, |
367 | validate: Option<TokenStream>, |
368 | item: &Item, |
369 | generics: &SplitGenerics<'_>, |
370 | variants: &[Data<'_>], |
371 | discriminants: &[Cow<'_, Expr>], |
372 | path: &Path, |
373 | method: &TokenStream, |
374 | ) -> TokenStream { |
375 | let variants = variants |
376 | .iter() |
377 | .zip(discriminants) |
378 | .map(|(variant, discriminant)| { |
379 | let pattern = variant.self_pattern(); |
380 | |
381 | if validate.is_some() { |
382 | let discriminant = format_ident!("__VALIDATE_ISIZE_ {}" , variant.ident); |
383 | |
384 | quote! { |
385 | #pattern => #discriminant |
386 | } |
387 | } else { |
388 | quote! { |
389 | #pattern => #discriminant |
390 | } |
391 | } |
392 | }); |
393 | |
394 | // `isize` is currently used by Rust as the default representation when none is |
395 | // defined. |
396 | let repr = repr.unwrap_or(Representation::ISize).to_token(); |
397 | |
398 | let item = item.ident(); |
399 | let SplitGenerics { |
400 | imp, |
401 | ty, |
402 | where_clause, |
403 | } = generics; |
404 | |
405 | quote! { |
406 | const fn __discriminant #imp(__this: &#item #ty) -> #repr #where_clause { |
407 | #validate |
408 | |
409 | match __this { |
410 | #(#variants),* |
411 | } |
412 | } |
413 | |
414 | #path::#method(&__discriminant(self), &__discriminant(__other)) |
415 | } |
416 | } |
417 | |
418 | /// Build `match` arms for [`PartialOrd`] and [`Ord`]. |
419 | pub fn build_ord_body(trait_: &DeriveTrait, data: &Data) -> TokenStream { |
420 | let path = trait_.path(); |
421 | let mut equal = quote! { ::core::cmp::Ordering::Equal }; |
422 | |
423 | // Add `Option` to `Ordering` if we are implementing `PartialOrd`. |
424 | let method = match trait_ { |
425 | DeriveTrait::PartialOrd => { |
426 | equal = quote! { ::core::option::Option::Some(#equal) }; |
427 | quote! { partial_cmp } |
428 | } |
429 | DeriveTrait::Ord => quote! { cmp }, |
430 | _ => unreachable!("unsupported trait in `build_ord`" ), |
431 | }; |
432 | |
433 | // The match arm starts with `Ordering::Equal`. This will become the |
434 | // whole `match` arm if no fields are present. |
435 | let mut body = quote! { #equal }; |
436 | |
437 | // Builds `match` arms backwards, using the `match` arm of the field coming |
438 | // afterwards. `rev` has to be called twice separately because it can't be |
439 | // called on `zip` |
440 | for (field_temp, field_other) in data |
441 | .iter_self_ident(**trait_) |
442 | .rev() |
443 | .zip(data.iter_other_ident(**trait_).rev()) |
444 | { |
445 | body = quote! { |
446 | match #path::#method(#field_temp, #field_other) { |
447 | #equal => #body, |
448 | __cmp => __cmp, |
449 | } |
450 | }; |
451 | } |
452 | |
453 | body |
454 | } |
455 | |
456 | /// Generate a match arm that returns `body` for all incomparable `variants` |
457 | pub fn build_incomparable_pattern(variants: &[Data]) -> Option<TokenStream> { |
458 | let mut incomparable: impl Iterator = variantsimpl Iterator |
459 | .iter() |
460 | .filter(|variant: &&Data<'_>| variant.is_incomparable()) |
461 | .map(|variant: &Data<'_> @ Data { path: &Path, .. }| match variant.simple_type() { |
462 | SimpleType::Struct(_) => quote!(#path{..}), |
463 | SimpleType::Tuple(_) => quote!(#path(..)), |
464 | SimpleType::Union(_) => unreachable!("enum variants cannot be unions" ), |
465 | SimpleType::Unit(_) => quote!(#path), |
466 | }) |
467 | .peekable(); |
468 | if incomparable.peek().is_some() { |
469 | Some(quote! { |
470 | #(#incomparable)|* |
471 | }) |
472 | } else { |
473 | None |
474 | } |
475 | } |
476 | |