//! Common implementation helpers for [`PartialOrd`] and [`Ord`].
2
3#[cfg(not(feature = "nightly"))]
4use std::{borrow::Cow, ops::Deref};
5
6use proc_macro2::TokenStream;
7#[cfg(not(feature = "nightly"))]
8use proc_macro2::{Literal, Span};
9#[cfg(not(feature = "nightly"))]
10use quote::format_ident;
11use quote::quote;
12#[cfg(not(feature = "nightly"))]
13use syn::{parse_quote, Expr, ExprLit, LitInt, Path};
14
15#[cfg(not(feature = "nightly"))]
16use crate::{item::Representation, Discriminant, Trait};
17use crate::{Data, DeriveTrait, Item, SimpleType, SplitGenerics};
18
19/// Build signature for [`PartialOrd`] and [`Ord`].
20pub fn build_ord_signature(
21 item: &Item,
22 #[cfg_attr(feature = "nightly", allow(unused_variables))] generics: &SplitGenerics<'_>,
23 #[cfg_attr(feature = "nightly", allow(unused_variables))] traits: &[DeriveTrait],
24 trait_: &DeriveTrait,
25 body: &TokenStream,
26) -> TokenStream {
27 let mut equal = quote! { ::core::cmp::Ordering::Equal };
28
29 // Add `Option` to `Ordering` if we are implementing `PartialOrd`.
30 if let DeriveTrait::PartialOrd = trait_ {
31 equal = quote! { ::core::option::Option::Some(#equal) };
32 }
33
34 match item {
35 // If the item is incomparable return `None`
36 item if item.is_incomparable() => {
37 quote! { ::core::option::Option::None }
38 }
39 // If there is more than one variant, check for the discriminant.
40 Item::Enum {
41 #[cfg(not(feature = "nightly"))]
42 discriminant,
43 variants,
44 ..
45 } if variants.len() > 1 => {
46 // In case the discriminant matches:
47 // If all variants are empty, return `Equal`.
48 let body_equal = if item.is_empty(**trait_) {
49 None
50 }
51 // Compare variant data and return `Equal` in the rest pattern if there are any empty
52 // variants that are comparable.
53 else if variants
54 .iter()
55 .any(|variant| variant.is_empty(**trait_) && !variant.is_incomparable())
56 {
57 Some(quote! {
58 match (self, __other) {
59 #body
60 _ => #equal,
61 }
62 })
63 }
64 // Insert `unreachable!` in the rest pattern if no variants are empty.
65 else {
66 #[cfg(not(feature = "safe"))]
67 // This follows the standard implementation.
68 let rest = quote! { unsafe { ::core::hint::unreachable_unchecked() } };
69 #[cfg(feature = "safe")]
70 let rest = quote! { ::core::unreachable!("comparing variants yielded unexpected results") };
71
72 Some(quote! {
73 match (self, __other) {
74 #body
75 _ => #rest,
76 }
77 })
78 };
79
80 let incomparable = build_incomparable_pattern(variants);
81
82 // If there is only one comparable variant, it has to be it when it is non
83 // incomparable.
84 let mut comparable = variants.iter().filter(|v| !v.is_incomparable());
85 // Takes the first value from the iterator, but only when there is only one
86 // (second yields none).
87 if let (Some(comparable), None) = (comparable.next(), comparable.next()) {
88 let incomparable = incomparable.expect("there should be > 1 variants");
89 // Either compare the single variant or return `Equal` when it is empty
90 let equal = if comparable.is_empty(**trait_) {
91 equal
92 } else {
93 body_equal.unwrap_or(equal)
94 };
95 quote! {
96 if ::core::matches!(self, #incomparable) || ::core::matches!(__other, #incomparable) {
97 ::core::option::Option::None
98 } else {
99 #equal
100 }
101 }
102 } else {
103 let incomparable = incomparable.into_iter();
104 let incomparable = quote! {
105 #(if ::core::matches!(self, #incomparable) || ::core::matches!(__other, #incomparable) {
106 return ::core::option::Option::None;
107 })*
108 };
109
110 let path = trait_.path();
111 let method = match trait_ {
112 DeriveTrait::PartialOrd => quote! { partial_cmp },
113 DeriveTrait::Ord => quote! { cmp },
114 _ => unreachable!("unsupported trait in `prepare_ord`"),
115 };
116
117 // Nightly implementation.
118 #[cfg(feature = "nightly")]
119 if let Some(body_equal) = body_equal {
120 quote! {
121 #incomparable
122
123 let __self_disc = ::core::intrinsics::discriminant_value(self);
124 let __other_disc = ::core::intrinsics::discriminant_value(__other);
125
126 if __self_disc == __other_disc {
127 #body_equal
128 } else {
129 #path::#method(&__self_disc, &__other_disc)
130 }
131 }
132 } else {
133 quote! {
134 #incomparable
135
136 #path::#method(
137 &::core::intrinsics::discriminant_value(self),
138 &::core::intrinsics::discriminant_value(__other),
139 )
140 }
141 }
142
143 #[cfg(not(feature = "nightly"))]
144 {
145 let body_else = match discriminant {
146 Discriminant::Single => {
147 unreachable!("we should only generate this code with multiple variants")
148 }
149 Discriminant::Unit => {
150 let mut discriminants = None;
151
152 // Validation is only needed if custom discriminants are defined.
153 let validate = (variants
154 .iter()
155 .any(|variant| variant.discriminant.is_some()))
156 .then(|| {
157 let discriminants =
158 discriminants.insert(build_discriminants(variants));
159 let discriminants = discriminants.iter().zip(variants).map(
160 |(discriminant, variant)| {
161 let name =
162 format_ident!("__VALIDATE_ISIZE_{}", variant.ident);
163 let discriminant = discriminant.deref();
164
165 quote! {
166 const #name: isize = #discriminant;
167 }
168 },
169 );
170
171 quote! {
172 #(#discriminants)*
173 }
174 });
175
176 if traits.iter().any(|trait_| trait_ == Trait::Copy) {
177 quote! {
178 #validate
179
180 #path::#method(&(*self as isize), &(*__other as isize))
181 }
182 } else if traits.iter().any(|trait_| trait_ == Trait::Clone) {
183 let clone = DeriveTrait::Clone.path();
184 quote! {
185 #validate
186
187 #path::#method(&(#clone::clone(self) as isize), &(#clone::clone(__other) as isize))
188 }
189 } else {
190 let discriminants = discriminants
191 .get_or_insert_with(|| build_discriminants(variants));
192 build_discriminant_comparison(
193 None,
194 validate,
195 item,
196 generics,
197 variants,
198 discriminants,
199 &path,
200 &method,
201 )
202 }
203 }
204 Discriminant::Data => {
205 let discriminants = build_discriminants(variants);
206 build_discriminant_comparison(
207 None,
208 None,
209 item,
210 generics,
211 variants,
212 &discriminants,
213 &path,
214 &method,
215 )
216 }
217 Discriminant::UnitRepr(repr) => {
218 if traits.iter().any(|trait_| trait_ == Trait::Copy) {
219 quote! {
220 #path::#method(&(*self as #repr), &(*__other as #repr))
221 }
222 } else if traits.iter().any(|trait_| trait_ == Trait::Clone) {
223 let clone = DeriveTrait::Clone.path();
224 quote! {
225 #path::#method(&(#clone::clone(self) as #repr), &(#clone::clone(__other) as #repr))
226 }
227 } else {
228 #[cfg(feature = "safe")]
229 let body_else = {
230 let discriminants = build_discriminants(variants);
231 build_discriminant_comparison(
232 Some(*repr),
233 None,
234 item,
235 generics,
236 variants,
237 &discriminants,
238 &path,
239 &method,
240 )
241 };
242 #[cfg(not(feature = "safe"))]
243 let body_else = quote! {
244 #path::#method(
245 &unsafe { *<*const _>::from(self).cast::<#repr>() },
246 &unsafe { *<*const _>::from(__other).cast::<#repr>() },
247 )
248 };
249
250 body_else
251 }
252 }
253 #[cfg(not(feature = "safe"))]
254 Discriminant::DataRepr(repr) => {
255 quote! {
256 #path::#method(
257 &unsafe { *<*const _>::from(self).cast::<#repr>() },
258 &unsafe { *<*const _>::from(__other).cast::<#repr>() },
259 )
260 }
261 }
262 #[cfg(feature = "safe")]
263 Discriminant::DataRepr(repr) => {
264 let discriminants = build_discriminants(variants);
265 build_discriminant_comparison(
266 Some(*repr),
267 None,
268 item,
269 generics,
270 variants,
271 &discriminants,
272 &path,
273 &method,
274 )
275 }
276 };
277
278 if let Some(body_equal) = body_equal {
279 quote! {
280 #incomparable
281
282 let __self_disc = ::core::mem::discriminant(self);
283 let __other_disc = ::core::mem::discriminant(__other);
284
285 if __self_disc == __other_disc {
286 #body_equal
287 } else {
288 #body_else
289 }
290 }
291 } else {
292 quote! {
293 #incomparable
294
295 #body_else
296 }
297 }
298 }
299 }
300 }
301 // If there is only one variant and it's empty or if the struct is empty, simply
302 // return `Equal`.
303 item if item.is_empty(**trait_) => {
304 quote! { #equal }
305 }
306 _ => {
307 quote! {
308 match (self, __other) {
309 #body
310 }
311 }
312 }
313 }
314}
315
316/// Builds list of discriminant values for all variants.
317#[cfg(not(feature = "nightly"))]
318fn build_discriminants<'a>(variants: &'a [Data<'_>]) -> Vec<Cow<'a, Expr>> {
319 let mut discriminants = Vec::<Cow<Expr>>::with_capacity(variants.len());
320 let mut last_expression: Option<(Option<usize>, usize)> = None;
321
322 for variant in variants {
323 let discriminant = if let Some(discriminant) = variant.discriminant {
324 last_expression = Some((Some(discriminants.len()), 0));
325 Cow::Borrowed(discriminant)
326 } else {
327 let discriminant = match &mut last_expression {
328 Some((Some(expr_index), counter)) => {
329 let expr = &discriminants[*expr_index];
330 *counter += 1;
331 let counter = Literal::usize_unsuffixed(*counter);
332 parse_quote! { (#expr) + #counter }
333 }
334 Some((None, counter)) => {
335 *counter += 1;
336
337 ExprLit {
338 attrs: Vec::new(),
339 lit: LitInt::new(&counter.to_string(), Span::call_site()).into(),
340 }
341 .into()
342 }
343 None => {
344 last_expression = Some((None, 0));
345 ExprLit {
346 attrs: Vec::new(),
347 lit: LitInt::new("0", Span::call_site()).into(),
348 }
349 .into()
350 }
351 };
352
353 Cow::Owned(discriminant)
354 };
355
356 discriminants.push(discriminant);
357 }
358
359 discriminants
360}
361
362/// Uses list of discriminant values to compare variants.
363#[cfg(not(feature = "nightly"))]
364#[allow(clippy::too_many_arguments)]
365fn build_discriminant_comparison(
366 repr: Option<Representation>,
367 validate: Option<TokenStream>,
368 item: &Item,
369 generics: &SplitGenerics<'_>,
370 variants: &[Data<'_>],
371 discriminants: &[Cow<'_, Expr>],
372 path: &Path,
373 method: &TokenStream,
374) -> TokenStream {
375 let variants = variants
376 .iter()
377 .zip(discriminants)
378 .map(|(variant, discriminant)| {
379 let pattern = variant.self_pattern();
380
381 if validate.is_some() {
382 let discriminant = format_ident!("__VALIDATE_ISIZE_{}", variant.ident);
383
384 quote! {
385 #pattern => #discriminant
386 }
387 } else {
388 quote! {
389 #pattern => #discriminant
390 }
391 }
392 });
393
394 // `isize` is currently used by Rust as the default representation when none is
395 // defined.
396 let repr = repr.unwrap_or(Representation::ISize).to_token();
397
398 let item = item.ident();
399 let SplitGenerics {
400 imp,
401 ty,
402 where_clause,
403 } = generics;
404
405 quote! {
406 const fn __discriminant #imp(__this: &#item #ty) -> #repr #where_clause {
407 #validate
408
409 match __this {
410 #(#variants),*
411 }
412 }
413
414 #path::#method(&__discriminant(self), &__discriminant(__other))
415 }
416}
417
418/// Build `match` arms for [`PartialOrd`] and [`Ord`].
419pub fn build_ord_body(trait_: &DeriveTrait, data: &Data) -> TokenStream {
420 let path = trait_.path();
421 let mut equal = quote! { ::core::cmp::Ordering::Equal };
422
423 // Add `Option` to `Ordering` if we are implementing `PartialOrd`.
424 let method = match trait_ {
425 DeriveTrait::PartialOrd => {
426 equal = quote! { ::core::option::Option::Some(#equal) };
427 quote! { partial_cmp }
428 }
429 DeriveTrait::Ord => quote! { cmp },
430 _ => unreachable!("unsupported trait in `build_ord`"),
431 };
432
433 // The match arm starts with `Ordering::Equal`. This will become the
434 // whole `match` arm if no fields are present.
435 let mut body = quote! { #equal };
436
437 // Builds `match` arms backwards, using the `match` arm of the field coming
438 // afterwards. `rev` has to be called twice separately because it can't be
439 // called on `zip`
440 for (field_temp, field_other) in data
441 .iter_self_ident(**trait_)
442 .rev()
443 .zip(data.iter_other_ident(**trait_).rev())
444 {
445 body = quote! {
446 match #path::#method(#field_temp, #field_other) {
447 #equal => #body,
448 __cmp => __cmp,
449 }
450 };
451 }
452
453 body
454}
455
456/// Generate a match arm that returns `body` for all incomparable `variants`
457pub fn build_incomparable_pattern(variants: &[Data]) -> Option<TokenStream> {
458 let mut incomparable: impl Iterator = variantsimpl Iterator
459 .iter()
460 .filter(|variant: &&Data<'_>| variant.is_incomparable())
461 .map(|variant: &Data<'_> @ Data { path: &Path, .. }| match variant.simple_type() {
462 SimpleType::Struct(_) => quote!(#path{..}),
463 SimpleType::Tuple(_) => quote!(#path(..)),
464 SimpleType::Union(_) => unreachable!("enum variants cannot be unions"),
465 SimpleType::Unit(_) => quote!(#path),
466 })
467 .peekable();
468 if incomparable.peek().is_some() {
469 Some(quote! {
470 #(#incomparable)|*
471 })
472 } else {
473 None
474 }
475}
476