1use crate::enum_attributes::ErrorTypeAttribute;
2use crate::utils::die;
3use crate::variant_attributes::{NumEnumVariantAttributeItem, NumEnumVariantAttributes};
4use proc_macro2::Span;
5use quote::{format_ident, ToTokens};
6use std::collections::BTreeSet;
7use syn::{
8 parse::{Parse, ParseStream},
9 parse_quote, Attribute, Data, DeriveInput, Expr, ExprLit, ExprUnary, Fields, Ident, Lit,
10 LitInt, Meta, Path, Result, UnOp,
11};
12
13pub(crate) struct EnumInfo {
14 pub(crate) name: Ident,
15 pub(crate) repr: Ident,
16 pub(crate) variants: Vec<VariantInfo>,
17 pub(crate) error_type_info: ErrorType,
18}
19
20impl EnumInfo {
21 /// Returns whether the number of variants (ignoring defaults, catch-alls, etc) is the same as
22 /// the capacity of the repr.
23 pub(crate) fn is_naturally_exhaustive(&self) -> Result<bool> {
24 let repr_str = self.repr.to_string();
25 if !repr_str.is_empty() {
26 let suffix = repr_str
27 .strip_prefix('i')
28 .or_else(|| repr_str.strip_prefix('u'));
29 if let Some(suffix) = suffix {
30 if suffix == "size" {
31 return Ok(false);
32 } else if let Ok(bits) = suffix.parse::<u32>() {
33 let variants = 1usize.checked_shl(bits);
34 return Ok(variants.map_or(false, |v| {
35 v == self
36 .variants
37 .iter()
38 .map(|v| v.alternative_values.len() + 1)
39 .sum()
40 }));
41 }
42 }
43 }
44 die!(self.repr.clone() => "Failed to parse repr into bit size");
45 }
46
47 pub(crate) fn default(&self) -> Option<&Ident> {
48 self.variants
49 .iter()
50 .find(|info| info.is_default)
51 .map(|info| &info.ident)
52 }
53
54 pub(crate) fn catch_all(&self) -> Option<&Ident> {
55 self.variants
56 .iter()
57 .find(|info| info.is_catch_all)
58 .map(|info| &info.ident)
59 }
60
61 pub(crate) fn variant_idents(&self) -> Vec<Ident> {
62 self.variants
63 .iter()
64 .map(|variant| variant.ident.clone())
65 .collect()
66 }
67
68 pub(crate) fn expression_idents(&self) -> Vec<Vec<Ident>> {
69 self.variants
70 .iter()
71 .filter(|variant| !variant.is_catch_all)
72 .map(|info| {
73 let indices = 0..(info.alternative_values.len() + 1);
74 indices
75 .map(|index| format_ident!("{}__num_enum_{}__", info.ident, index))
76 .collect()
77 })
78 .collect()
79 }
80
81 pub(crate) fn variant_expressions(&self) -> Vec<Vec<Expr>> {
82 self.variants
83 .iter()
84 .map(|variant| variant.all_values().cloned().collect())
85 .collect()
86 }
87
88 fn parse_attrs<Attrs: Iterator<Item = Attribute>>(
89 attrs: Attrs,
90 ) -> Result<(Ident, Option<ErrorType>)> {
91 let mut maybe_repr = None;
92 let mut maybe_error_type = None;
93 for attr in attrs {
94 if let Meta::List(meta_list) = &attr.meta {
95 if let Some(ident) = meta_list.path.get_ident() {
96 if ident == "repr" {
97 let mut nested = meta_list.tokens.clone().into_iter();
98 let repr_tree = match (nested.next(), nested.next()) {
99 (Some(repr_tree), None) => repr_tree,
100 _ => die!(attr =>
101 "Expected exactly one `repr` argument"
102 ),
103 };
104 let repr_ident: Ident = parse_quote! {
105 #repr_tree
106 };
107 if repr_ident == "C" {
108 die!(repr_ident =>
109 "repr(C) doesn't have a well defined size"
110 );
111 } else {
112 maybe_repr = Some(repr_ident);
113 }
114 } else if ident == "num_enum" {
115 let attributes =
116 attr.parse_args_with(crate::enum_attributes::Attributes::parse)?;
117 if let Some(error_type) = attributes.error_type {
118 if maybe_error_type.is_some() {
119 die!(attr => "At most one num_enum error_type attribute may be specified");
120 }
121 maybe_error_type = Some(error_type.into());
122 }
123 }
124 }
125 }
126 }
127 if maybe_repr.is_none() {
128 die!("Missing `#[repr({Integer})]` attribute");
129 }
130 Ok((maybe_repr.unwrap(), maybe_error_type))
131 }
132}
133
134impl Parse for EnumInfo {
135 fn parse(input: ParseStream) -> Result<Self> {
136 Ok({
137 let input: DeriveInput = input.parse()?;
138 let name = input.ident;
139 let data = match input.data {
140 Data::Enum(data) => data,
141 Data::Union(data) => die!(data.union_token => "Expected enum but found union"),
142 Data::Struct(data) => die!(data.struct_token => "Expected enum but found struct"),
143 };
144
145 let (repr, maybe_error_type) = Self::parse_attrs(input.attrs.into_iter())?;
146
147 let mut variants: Vec<VariantInfo> = vec![];
148 let mut has_default_variant: bool = false;
149 let mut has_catch_all_variant: bool = false;
150
151 // Vec to keep track of the used discriminants and alt values.
152 let mut discriminant_int_val_set = BTreeSet::new();
153
154 let mut next_discriminant = literal(0);
155 for variant in data.variants.into_iter() {
156 let ident = variant.ident.clone();
157
158 let discriminant = match &variant.discriminant {
159 Some(d) => d.1.clone(),
160 None => next_discriminant.clone(),
161 };
162
163 let mut raw_alternative_values: Vec<Expr> = vec![];
164 // Keep the attribute around for better error reporting.
165 let mut alt_attr_ref: Vec<&Attribute> = vec![];
166
167 // `#[num_enum(default)]` is required by `#[derive(FromPrimitive)]`
168 // and forbidden by `#[derive(UnsafeFromPrimitive)]`, so we need to
169 // keep track of whether we encountered such an attribute:
170 let mut is_default: bool = false;
171 let mut is_catch_all: bool = false;
172
173 for attribute in &variant.attrs {
174 if attribute.path().is_ident("default") {
175 if has_default_variant {
176 die!(attribute =>
177 "Multiple variants marked `#[default]` or `#[num_enum(default)]` found"
178 );
179 } else if has_catch_all_variant {
180 die!(attribute =>
181 "Attribute `default` is mutually exclusive with `catch_all`"
182 );
183 }
184 is_default = true;
185 has_default_variant = true;
186 }
187
188 if attribute.path().is_ident("num_enum") {
189 match attribute.parse_args_with(NumEnumVariantAttributes::parse) {
190 Ok(variant_attributes) => {
191 for variant_attribute in variant_attributes.items {
192 match variant_attribute {
193 NumEnumVariantAttributeItem::Default(default) => {
194 if has_default_variant {
195 die!(default.keyword =>
196 "Multiple variants marked `#[default]` or `#[num_enum(default)]` found"
197 );
198 } else if has_catch_all_variant {
199 die!(default.keyword =>
200 "Attribute `default` is mutually exclusive with `catch_all`"
201 );
202 }
203 is_default = true;
204 has_default_variant = true;
205 }
206 NumEnumVariantAttributeItem::CatchAll(catch_all) => {
207 if has_catch_all_variant {
208 die!(catch_all.keyword =>
209 "Multiple variants marked with `#[num_enum(catch_all)]`"
210 );
211 } else if has_default_variant {
212 die!(catch_all.keyword =>
213 "Attribute `catch_all` is mutually exclusive with `default`"
214 );
215 }
216
217 match variant
218 .fields
219 .iter()
220 .collect::<Vec<_>>()
221 .as_slice()
222 {
223 [syn::Field {
224 ty: syn::Type::Path(syn::TypePath { path, .. }),
225 ..
226 }] if path.is_ident(&repr) => {
227 is_catch_all = true;
228 has_catch_all_variant = true;
229 }
230 _ => {
231 die!(catch_all.keyword =>
232 "Variant with `catch_all` must be a tuple with exactly 1 field matching the repr type"
233 );
234 }
235 }
236 }
237 NumEnumVariantAttributeItem::Alternatives(alternatives) => {
238 raw_alternative_values.extend(alternatives.expressions);
239 alt_attr_ref.push(attribute);
240 }
241 }
242 }
243 }
244 Err(err) => {
245 if cfg!(not(feature = "complex-expressions")) {
246 let tokens = attribute.meta.to_token_stream();
247
248 let attribute_str = format!("{}", tokens);
249 if attribute_str.contains("alternatives")
250 && attribute_str.contains("..")
251 {
252 // Give a nice error message suggesting how to fix the problem.
253 die!(attribute => "Ranges are only supported as num_enum alternate values if the `complex-expressions` feature of the crate `num_enum` is enabled".to_string())
254 }
255 }
256 die!(attribute =>
257 format!("Invalid attribute: {}", err)
258 );
259 }
260 }
261 }
262 }
263
264 if !is_catch_all {
265 match &variant.fields {
266 Fields::Named(_) | Fields::Unnamed(_) => {
267 die!(variant => format!("`{}` only supports unit variants (with no associated data), but `{}::{}` was not a unit variant.", get_crate_name(), name, ident));
268 }
269 Fields::Unit => {}
270 }
271 }
272
273 let discriminant_value = parse_discriminant(&discriminant)?;
274
275 // Check for collision.
276 // We can't do const evaluation, or even compare arbitrary Exprs,
277 // so unfortunately we can't check for duplicates.
278 // That's not the end of the world, just we'll end up with compile errors for
279 // matches with duplicate branches in generated code instead of nice friendly error messages.
280 if let DiscriminantValue::Literal(canonical_value_int) = discriminant_value {
281 if discriminant_int_val_set.contains(&canonical_value_int) {
282 die!(ident => format!("The discriminant '{}' collides with a value attributed to a previous variant", canonical_value_int))
283 }
284 }
285
286 // Deal with the alternative values.
287 let mut flattened_alternative_values = Vec::new();
288 let mut flattened_raw_alternative_values = Vec::new();
289 for raw_alternative_value in raw_alternative_values {
290 let expanded_values = parse_alternative_values(&raw_alternative_value)?;
291 for expanded_value in expanded_values {
292 flattened_alternative_values.push(expanded_value);
293 flattened_raw_alternative_values.push(raw_alternative_value.clone())
294 }
295 }
296
297 if !flattened_alternative_values.is_empty() {
298 let alternate_int_values = flattened_alternative_values
299 .into_iter()
300 .map(|v| {
301 match v {
302 DiscriminantValue::Literal(value) => Ok(value),
303 DiscriminantValue::Expr(expr) => {
304 if let Expr::Range(_) = expr {
305 if cfg!(not(feature = "complex-expressions")) {
306 // Give a nice error message suggesting how to fix the problem.
307 die!(expr => "Ranges are only supported as num_enum alternate values if the `complex-expressions` feature of the crate `num_enum` is enabled".to_string())
308 }
309 }
310 // We can't do uniqueness checking on non-literals, so we don't allow them as alternate values.
311 // We could probably allow them, but there doesn't seem to be much of a use-case,
312 // and it's easier to give good error messages about duplicate values this way,
313 // rather than rustc errors on conflicting match branches.
314 die!(expr => "Only literals are allowed as num_enum alternate values".to_string())
315 },
316 }
317 })
318 .collect::<Result<Vec<i128>>>()?;
319 let mut sorted_alternate_int_values = alternate_int_values.clone();
320 sorted_alternate_int_values.sort_unstable();
321 let sorted_alternate_int_values = sorted_alternate_int_values;
322
323 // Check if the current discriminant is not in the alternative values.
324 if let DiscriminantValue::Literal(canonical_value_int) = discriminant_value {
325 if let Some(index) = alternate_int_values
326 .iter()
327 .position(|&x| x == canonical_value_int)
328 {
329 die!(&flattened_raw_alternative_values[index] => format!("'{}' in the alternative values is already attributed as the discriminant of this variant", canonical_value_int));
330 }
331 }
332
333 // Search for duplicates, the vec is sorted. Warn about them.
334 if (1..sorted_alternate_int_values.len()).any(|i| {
335 sorted_alternate_int_values[i] == sorted_alternate_int_values[i - 1]
336 }) {
337 let attr = *alt_attr_ref.last().unwrap();
338 die!(attr => "There is duplication in the alternative values");
339 }
340 // Search if those discriminant_int_val_set where already attributed.
341 // (discriminant_int_val_set is BTreeSet, and iter().next_back() is the is the maximum in the set.)
342 if let Some(last_upper_val) = discriminant_int_val_set.iter().next_back() {
343 if sorted_alternate_int_values.first().unwrap() <= last_upper_val {
344 for (index, val) in alternate_int_values.iter().enumerate() {
345 if discriminant_int_val_set.contains(val) {
346 die!(&flattened_raw_alternative_values[index] => format!("'{}' in the alternative values is already attributed to a previous variant", val));
347 }
348 }
349 }
350 }
351
352 // Reconstruct the alternative_values vec of Expr but sorted.
353 flattened_raw_alternative_values = sorted_alternate_int_values
354 .iter()
355 .map(|val| literal(val.to_owned()))
356 .collect();
357
358 // Add the alternative values to the the set to keep track.
359 discriminant_int_val_set.extend(sorted_alternate_int_values);
360 }
361
362 // Add the current discriminant to the the set to keep track.
363 if let DiscriminantValue::Literal(canonical_value_int) = discriminant_value {
364 discriminant_int_val_set.insert(canonical_value_int);
365 }
366
367 variants.push(VariantInfo {
368 ident,
369 is_default,
370 is_catch_all,
371 canonical_value: discriminant,
372 alternative_values: flattened_raw_alternative_values,
373 });
374
375 // Get the next value for the discriminant.
376 next_discriminant = match discriminant_value {
377 DiscriminantValue::Literal(int_value) => literal(int_value.wrapping_add(1)),
378 DiscriminantValue::Expr(expr) => {
379 parse_quote! {
380 #repr::wrapping_add(#expr, 1)
381 }
382 }
383 }
384 }
385
386 let error_type_info = maybe_error_type.unwrap_or_else(|| {
387 let crate_name = Ident::new(&get_crate_name(), Span::call_site());
388 ErrorType {
389 name: parse_quote! {
390 ::#crate_name::TryFromPrimitiveError<Self>
391 },
392 constructor: parse_quote! {
393 ::#crate_name::TryFromPrimitiveError::<Self>::new
394 },
395 }
396 });
397
398 EnumInfo {
399 name,
400 repr,
401 variants,
402 error_type_info,
403 }
404 })
405 }
406}
407
408fn literal(i: i128) -> Expr {
409 Expr::Lit(ExprLit {
410 lit: Lit::Int(LitInt::new(&i.to_string(), Span::call_site())),
411 attrs: vec![],
412 })
413}
414
415enum DiscriminantValue {
416 Literal(i128),
417 Expr(Expr),
418}
419
420fn parse_discriminant(val_exp: &Expr) -> Result<DiscriminantValue> {
421 let mut sign: i128 = 1;
422 let mut unsigned_expr: &Expr = val_exp;
423 if let Expr::Unary(ExprUnary {
424 op: UnOp::Neg(..),
425 expr: &Box,
426 ..
427 }) = val_exp
428 {
429 unsigned_expr = expr;
430 sign = -1;
431 }
432 if let Expr::Lit(ExprLit {
433 lit: Lit::Int(ref lit_int: &LitInt),
434 ..
435 }) = unsigned_expr
436 {
437 Ok(DiscriminantValue::Literal(
438 sign * lit_int.base10_parse::<i128>()?,
439 ))
440 } else {
441 Ok(DiscriminantValue::Expr(val_exp.clone()))
442 }
443}
444
/// Expands one alternative-value expression into the values it denotes.
///
/// A range whose bounds are both (optionally negated) integer literals expands to every value
/// in it, honouring `..` (end excluded) vs `..=` (end included). Any other expression yields
/// exactly one value via [`parse_discriminant`].
///
/// # Errors
/// - a range bound is missing or is not an integer literal;
/// - the lower bound is greater than the upper bound.
#[cfg(feature = "complex-expressions")]
fn parse_alternative_values(val_expr: &Expr) -> Result<Vec<DiscriminantValue>> {
    fn range_expr_value_to_number(
        parent_range_expr: &Expr,
        range_bound_value: &Option<Box<Expr>>,
    ) -> Result<i128> {
        // Avoid needing to calculate what the lower and upper bound would be - these are type dependent,
        // and also may not be obvious in context (e.g. an omitted bound could reasonably mean "from the last discriminant" or "from the lower bound of the type").
        if let Some(range_bound_value) = range_bound_value {
            // If non-literals are used, we can't expand to the mapped values, so can't write a nice match statement or do exhaustiveness checking.
            // Require literals instead.
            if let DiscriminantValue::Literal(value) = parse_discriminant(range_bound_value)? {
                return Ok(value);
            }
        }
        die!(parent_range_expr => "When ranges are used for alternate values, both bounds most be explicitly specified numeric literals")
    }

    let syn::ExprRange {
        start, end, limits, ..
    } = match val_expr {
        Expr::Range(range) => range,
        _ => return parse_discriminant(val_expr).map(|v| vec![v]),
    };

    let lower = range_expr_value_to_number(val_expr, start)?;
    let upper = range_expr_value_to_number(val_expr, end)?;
    // While this is technically allowed in Rust, and results in an empty range, it's almost certainly a mistake in this context.
    if lower > upper {
        die!(val_expr => "When using ranges for alternate values, upper bound must not be less than lower bound");
    }
    // Standard ranges handle `..=i128::MAX` correctly, whereas a hand-rolled `next += 1`
    // counter would overflow after yielding the last value.
    let values: Vec<DiscriminantValue> = match limits {
        syn::RangeLimits::HalfOpen(..) => (lower..upper).map(DiscriminantValue::Literal).collect(),
        syn::RangeLimits::Closed(..) => (lower..=upper).map(DiscriminantValue::Literal).collect(),
    };
    Ok(values)
}
496
497#[cfg(not(feature = "complex-expressions"))]
498fn parse_alternative_values(val_expr: &Expr) -> Result<Vec<DiscriminantValue>> {
499 parse_discriminant(val_expr).map(|v: DiscriminantValue| vec![v])
500}
501
502pub(crate) struct VariantInfo {
503 ident: Ident,
504 is_default: bool,
505 is_catch_all: bool,
506 canonical_value: Expr,
507 alternative_values: Vec<Expr>,
508}
509
510impl VariantInfo {
511 fn all_values(&self) -> impl Iterator<Item = &Expr> {
512 ::core::iter::once(&self.canonical_value).chain(self.alternative_values.iter())
513 }
514}
515
516pub(crate) struct ErrorType {
517 pub(crate) name: Path,
518 pub(crate) constructor: Path,
519}
520
521impl From<ErrorTypeAttribute> for ErrorType {
522 fn from(attribute: ErrorTypeAttribute) -> Self {
523 Self {
524 name: attribute.name.path,
525 constructor: attribute.constructor.path,
526 }
527 }
528}
529
/// Returns the name under which the `num_enum` crate is available to the crate being compiled.
///
/// Uses `proc_macro_crate` to honour a renamed dependency. If the lookup fails, a warning is
/// printed to stderr and `num_enum` is assumed.
#[cfg(feature = "proc-macro-crate")]
pub(crate) fn get_crate_name() -> String {
    match proc_macro_crate::crate_name("num_enum") {
        Ok(proc_macro_crate::FoundCrate::Name(name)) => name,
        Ok(proc_macro_crate::FoundCrate::Itself) => String::from("num_enum"),
        Err(err) => {
            eprintln!("Warning: {}\n => defaulting to `num_enum`", err);
            String::from("num_enum")
        }
    }
}
542
543// Don't depend on proc-macro-crate in no_std environments because it causes an awkward dependency
544// on serde with std.
545//
// no_std dependents of num_enum cannot rename the num_enum crate when they depend on it. Sorry.
547//
548// See https://github.com/illicitonion/num_enum/issues/18
/// Returns the crate name used in generated paths; without `proc-macro-crate` this is always
/// the unrenamed `num_enum`.
#[cfg(not(feature = "proc-macro-crate"))]
pub(crate) fn get_crate_name() -> String {
    "num_enum".to_owned()
}
553