1 | use crate::enum_attributes::ErrorTypeAttribute; |
2 | use crate::utils::die; |
3 | use crate::variant_attributes::{NumEnumVariantAttributeItem, NumEnumVariantAttributes}; |
4 | use proc_macro2::Span; |
5 | use quote::{format_ident, ToTokens}; |
6 | use std::collections::BTreeSet; |
7 | use syn::{ |
8 | parse::{Parse, ParseStream}, |
9 | parse_quote, Attribute, Data, DeriveInput, Expr, ExprLit, ExprUnary, Fields, Ident, Lit, |
10 | LitInt, Meta, Path, Result, UnOp, |
11 | }; |
12 | |
/// Everything the derive macros need to know about the enum being processed,
/// as produced by the `Parse` implementation for this type.
pub(crate) struct EnumInfo {
    /// Identifier of the enum itself.
    pub(crate) name: Ident,
    /// Identifier found inside the enum's `#[repr(...)]` attribute.
    pub(crate) repr: Ident,
    /// One entry per variant, in declaration order.
    pub(crate) variants: Vec<VariantInfo>,
    /// Error type (and its constructor) taken from a `#[num_enum(...)]` enum
    /// attribute, or the crate's `TryFromPrimitiveError<Self>` when none is given.
    pub(crate) error_type_info: ErrorType,
}
19 | |
20 | impl EnumInfo { |
21 | /// Returns whether the number of variants (ignoring defaults, catch-alls, etc) is the same as |
22 | /// the capacity of the repr. |
23 | pub(crate) fn is_naturally_exhaustive(&self) -> Result<bool> { |
24 | let repr_str = self.repr.to_string(); |
25 | if !repr_str.is_empty() { |
26 | let suffix = repr_str |
27 | .strip_prefix('i' ) |
28 | .or_else(|| repr_str.strip_prefix('u' )); |
29 | if let Some(suffix) = suffix { |
30 | if suffix == "size" { |
31 | return Ok(false); |
32 | } else if let Ok(bits) = suffix.parse::<u32>() { |
33 | let variants = 1usize.checked_shl(bits); |
34 | return Ok(variants.map_or(false, |v| { |
35 | v == self |
36 | .variants |
37 | .iter() |
38 | .map(|v| v.alternative_values.len() + 1) |
39 | .sum() |
40 | })); |
41 | } |
42 | } |
43 | } |
44 | die!(self.repr.clone() => "Failed to parse repr into bit size" ); |
45 | } |
46 | |
47 | pub(crate) fn default(&self) -> Option<&Ident> { |
48 | self.variants |
49 | .iter() |
50 | .find(|info| info.is_default) |
51 | .map(|info| &info.ident) |
52 | } |
53 | |
54 | pub(crate) fn catch_all(&self) -> Option<&Ident> { |
55 | self.variants |
56 | .iter() |
57 | .find(|info| info.is_catch_all) |
58 | .map(|info| &info.ident) |
59 | } |
60 | |
61 | pub(crate) fn variant_idents(&self) -> Vec<Ident> { |
62 | self.variants |
63 | .iter() |
64 | .filter(|variant| !variant.is_catch_all) |
65 | .map(|variant| variant.ident.clone()) |
66 | .collect() |
67 | } |
68 | |
69 | pub(crate) fn expression_idents(&self) -> Vec<Vec<Ident>> { |
70 | self.variants |
71 | .iter() |
72 | .filter(|variant| !variant.is_catch_all) |
73 | .map(|info| { |
74 | let indices = 0..(info.alternative_values.len() + 1); |
75 | indices |
76 | .map(|index| format_ident!(" {}__num_enum_ {}__" , info.ident, index)) |
77 | .collect() |
78 | }) |
79 | .collect() |
80 | } |
81 | |
82 | pub(crate) fn variant_expressions(&self) -> Vec<Vec<Expr>> { |
83 | self.variants |
84 | .iter() |
85 | .filter(|variant| !variant.is_catch_all) |
86 | .map(|variant| variant.all_values().cloned().collect()) |
87 | .collect() |
88 | } |
89 | |
90 | fn parse_attrs<Attrs: Iterator<Item = Attribute>>( |
91 | attrs: Attrs, |
92 | ) -> Result<(Ident, Option<ErrorType>)> { |
93 | let mut maybe_repr = None; |
94 | let mut maybe_error_type = None; |
95 | for attr in attrs { |
96 | if let Meta::List(meta_list) = &attr.meta { |
97 | if let Some(ident) = meta_list.path.get_ident() { |
98 | if ident == "repr" { |
99 | let mut nested = meta_list.tokens.clone().into_iter(); |
100 | let repr_tree = match (nested.next(), nested.next()) { |
101 | (Some(repr_tree), None) => repr_tree, |
102 | _ => die!(attr => |
103 | "Expected exactly one `repr` argument" |
104 | ), |
105 | }; |
106 | let repr_ident: Ident = parse_quote! { |
107 | #repr_tree |
108 | }; |
109 | if repr_ident == "C" { |
110 | die!(repr_ident => |
111 | "repr(C) doesn't have a well defined size" |
112 | ); |
113 | } else { |
114 | maybe_repr = Some(repr_ident); |
115 | } |
116 | } else if ident == "num_enum" { |
117 | let attributes = |
118 | attr.parse_args_with(crate::enum_attributes::Attributes::parse)?; |
119 | if let Some(error_type) = attributes.error_type { |
120 | if maybe_error_type.is_some() { |
121 | die!(attr => "At most one num_enum error_type attribute may be specified" ); |
122 | } |
123 | maybe_error_type = Some(error_type.into()); |
124 | } |
125 | } |
126 | } |
127 | } |
128 | } |
129 | if maybe_repr.is_none() { |
130 | die!("Missing `#[repr({Integer})]` attribute" ); |
131 | } |
132 | Ok((maybe_repr.unwrap(), maybe_error_type)) |
133 | } |
134 | } |
135 | |
/// Parses a `DeriveInput` into an [`EnumInfo`], validating the enum-level and
/// variant-level attributes along the way.
impl Parse for EnumInfo {
    /// Walks every variant of the enum, recording its discriminant, alternative
    /// values and `default` / `catch_all` flags, and bails out through `die!` on
    /// the first inconsistency found (non-enum input, conflicting attributes,
    /// non-unit variants, colliding literal values, ...).
    fn parse(input: ParseStream) -> Result<Self> {
        Ok({
            let input: DeriveInput = input.parse()?;
            let name = input.ident;
            // Only enums are supported; unions and structs are rejected up front.
            let data = match input.data {
                Data::Enum(data) => data,
                Data::Union(data) => die!(data.union_token => "Expected enum but found union"),
                Data::Struct(data) => die!(data.struct_token => "Expected enum but found struct"),
            };

            let (repr, maybe_error_type) = Self::parse_attrs(input.attrs.into_iter())?;

            let mut variants: Vec<VariantInfo> = vec![];
            let mut has_default_variant: bool = false;
            let mut has_catch_all_variant: bool = false;

            // Ordered set keeping track of the literal discriminants and alternative
            // values already claimed by earlier variants.
            let mut discriminant_int_val_set = BTreeSet::new();

            // Discriminant used by a variant that does not spell one out explicitly.
            let mut next_discriminant = literal(0);
            for variant in data.variants.into_iter() {
                let ident = variant.ident.clone();

                let discriminant = match &variant.discriminant {
                    Some(d) => d.1.clone(),
                    None => next_discriminant.clone(),
                };

                let mut raw_alternative_values: Vec<Expr> = vec![];
                // Keep the attribute around for better error reporting.
                let mut alt_attr_ref: Vec<&Attribute> = vec![];

                // `#[num_enum(default)]` is required by `#[derive(FromPrimitive)]`
                // and forbidden by `#[derive(UnsafeFromPrimitive)]`, so we need to
                // keep track of whether we encountered such an attribute:
                let mut is_default: bool = false;
                let mut is_catch_all: bool = false;

                for attribute in &variant.attrs {
                    // Plain `#[default]` is treated the same as `#[num_enum(default)]`.
                    if attribute.path().is_ident("default") {
                        if has_default_variant {
                            die!(attribute =>
                                "Multiple variants marked `#[default]` or `#[num_enum(default)]` found"
                            );
                        } else if has_catch_all_variant {
                            die!(attribute =>
                                "Attribute `default` is mutually exclusive with `catch_all`"
                            );
                        }
                        is_default = true;
                        has_default_variant = true;
                    }

                    if attribute.path().is_ident("num_enum") {
                        match attribute.parse_args_with(NumEnumVariantAttributes::parse) {
                            Ok(variant_attributes) => {
                                for variant_attribute in variant_attributes.items {
                                    match variant_attribute {
                                        NumEnumVariantAttributeItem::Default(default) => {
                                            if has_default_variant {
                                                die!(default.keyword =>
                                                    "Multiple variants marked `#[default]` or `#[num_enum(default)]` found"
                                                );
                                            } else if has_catch_all_variant {
                                                die!(default.keyword =>
                                                    "Attribute `default` is mutually exclusive with `catch_all`"
                                                );
                                            }
                                            is_default = true;
                                            has_default_variant = true;
                                        }
                                        NumEnumVariantAttributeItem::CatchAll(catch_all) => {
                                            if has_catch_all_variant {
                                                die!(catch_all.keyword =>
                                                    "Multiple variants marked with `#[num_enum(catch_all)]`"
                                                );
                                            } else if has_default_variant {
                                                die!(catch_all.keyword =>
                                                    "Attribute `catch_all` is mutually exclusive with `default`"
                                                );
                                            }

                                            // A catch-all variant must carry exactly one field whose
                                            // type path is the repr identifier.
                                            match variant
                                                .fields
                                                .iter()
                                                .collect::<Vec<_>>()
                                                .as_slice()
                                            {
                                                [syn::Field {
                                                    ty: syn::Type::Path(syn::TypePath { path, .. }),
                                                    ..
                                                }] if path.is_ident(&repr) => {
                                                    is_catch_all = true;
                                                    has_catch_all_variant = true;
                                                }
                                                _ => {
                                                    die!(catch_all.keyword =>
                                                        "Variant with `catch_all` must be a tuple with exactly 1 field matching the repr type"
                                                    );
                                                }
                                            }
                                        }
                                        NumEnumVariantAttributeItem::Alternatives(alternatives) => {
                                            raw_alternative_values.extend(alternatives.expressions);
                                            alt_attr_ref.push(attribute);
                                        }
                                    }
                                }
                            }
                            Err(err) => {
                                if cfg!(not(feature = "complex-expressions")) {
                                    let tokens = attribute.meta.to_token_stream();

                                    // Textual heuristic: an `alternatives` attribute containing
                                    // `..` most likely tried to use a range.
                                    let attribute_str = format!(" {}", tokens);
                                    if attribute_str.contains("alternatives")
                                        && attribute_str.contains("..")
                                    {
                                        // Give a nice error message suggesting how to fix the problem.
                                        die!(attribute => "Ranges are only supported as num_enum alternate values if the `complex-expressions` feature of the crate `num_enum` is enabled".to_string())
                                    }
                                }
                                die!(attribute =>
                                    format!("Invalid attribute: {}", err)
                                );
                            }
                        }
                    }
                }

                // Only the catch-all variant is allowed to carry data.
                if !is_catch_all {
                    match &variant.fields {
                        Fields::Named(_) | Fields::Unnamed(_) => {
                            die!(variant => format!("` {}` only supports unit variants (with no associated data), but ` {}:: {}` was not a unit variant.", get_crate_name(), name, ident));
                        }
                        Fields::Unit => {}
                    }
                }

                let discriminant_value = parse_discriminant(&discriminant)?;

                // Check for collision.
                // We can't do const evaluation, or even compare arbitrary Exprs,
                // so unfortunately we can't check non-literal discriminants for duplicates.
                // That's not the end of the world, just we'll end up with compile errors for
                // matches with duplicate branches in generated code instead of nice friendly error messages.
                if let DiscriminantValue::Literal(canonical_value_int) = discriminant_value {
                    if discriminant_int_val_set.contains(&canonical_value_int) {
                        die!(ident => format!("The discriminant ' {}' collides with a value attributed to a previous variant", canonical_value_int))
                    }
                }

                // Deal with the alternative values.
                // Each raw expression may expand to several values; the two vectors below
                // stay index-aligned so an expanded value can be traced back to the
                // expression it came from when reporting errors.
                let mut flattened_alternative_values = Vec::new();
                let mut flattened_raw_alternative_values = Vec::new();
                for raw_alternative_value in raw_alternative_values {
                    let expanded_values = parse_alternative_values(&raw_alternative_value)?;
                    for expanded_value in expanded_values {
                        flattened_alternative_values.push(expanded_value);
                        flattened_raw_alternative_values.push(raw_alternative_value.clone())
                    }
                }

                if !flattened_alternative_values.is_empty() {
                    let alternate_int_values = flattened_alternative_values
                        .into_iter()
                        .map(|v| {
                            match v {
                                DiscriminantValue::Literal(value) => Ok(value),
                                DiscriminantValue::Expr(expr) => {
                                    if let Expr::Range(_) = expr {
                                        if cfg!(not(feature = "complex-expressions")) {
                                            // Give a nice error message suggesting how to fix the problem.
                                            die!(expr => "Ranges are only supported as num_enum alternate values if the `complex-expressions` feature of the crate `num_enum` is enabled".to_string())
                                        }
                                    }
                                    // We can't do uniqueness checking on non-literals, so we don't allow them as alternate values.
                                    // We could probably allow them, but there doesn't seem to be much of a use-case,
                                    // and it's easier to give good error messages about duplicate values this way,
                                    // rather than rustc errors on conflicting match branches.
                                    die!(expr => "Only literals are allowed as num_enum alternate values".to_string())
                                },
                            }
                        })
                        .collect::<Result<Vec<i128>>>()?;
                    let mut sorted_alternate_int_values = alternate_int_values.clone();
                    sorted_alternate_int_values.sort_unstable();
                    // Re-bind immutably now that sorting is done.
                    let sorted_alternate_int_values = sorted_alternate_int_values;

                    // Check that the current discriminant is not among the alternative values.
                    if let DiscriminantValue::Literal(canonical_value_int) = discriminant_value {
                        if let Some(index) = alternate_int_values
                            .iter()
                            .position(|&x| x == canonical_value_int)
                        {
                            die!(&flattened_raw_alternative_values[index] => format!("' {}' in the alternative values is already attributed as the discriminant of this variant", canonical_value_int));
                        }
                    }

                    // Search for duplicates (adjacent equal entries, since the vec is sorted)
                    // and report them as an error.
                    if (1..sorted_alternate_int_values.len()).any(|i| {
                        sorted_alternate_int_values[i] == sorted_alternate_int_values[i - 1]
                    }) {
                        let attr = *alt_attr_ref.last().unwrap();
                        die!(attr => "There is duplication in the alternative values");
                    }
                    // Check whether any of these values were already attributed to a previous variant.
                    // (discriminant_int_val_set is a BTreeSet, so iter().next_back() is the maximum in the set;
                    // the per-value lookup is skipped when every alternative is above that maximum.)
                    if let Some(last_upper_val) = discriminant_int_val_set.iter().next_back() {
                        if sorted_alternate_int_values.first().unwrap() <= last_upper_val {
                            for (index, val) in alternate_int_values.iter().enumerate() {
                                if discriminant_int_val_set.contains(val) {
                                    die!(&flattened_raw_alternative_values[index] => format!("' {}' in the alternative values is already attributed to a previous variant", val));
                                }
                            }
                        }
                    }

                    // Reconstruct the alternative_values vec of Expr but sorted.
                    flattened_raw_alternative_values = sorted_alternate_int_values
                        .iter()
                        .map(|val| literal(val.to_owned()))
                        .collect();

                    // Add the alternative values to the set to keep track.
                    discriminant_int_val_set.extend(sorted_alternate_int_values);
                }

                // Add the current discriminant to the set to keep track.
                if let DiscriminantValue::Literal(canonical_value_int) = discriminant_value {
                    discriminant_int_val_set.insert(canonical_value_int);
                }

                variants.push(VariantInfo {
                    ident,
                    is_default,
                    is_catch_all,
                    canonical_value: discriminant,
                    alternative_values: flattened_raw_alternative_values,
                });

                // Get the next value for the discriminant: a literal is incremented here,
                // anything else becomes a `wrapping_add` expression on the repr type.
                next_discriminant = match discriminant_value {
                    DiscriminantValue::Literal(int_value) => literal(int_value.wrapping_add(1)),
                    DiscriminantValue::Expr(expr) => {
                        parse_quote! {
                            #repr::wrapping_add(#expr, 1)
                        }
                    }
                }
            }

            // Fall back to the crate's own error type when the enum did not declare one.
            let error_type_info = maybe_error_type.unwrap_or_else(|| {
                let crate_name = Ident::new(&get_crate_name(), Span::call_site());
                ErrorType {
                    name: parse_quote! {
                        ::#crate_name::TryFromPrimitiveError<Self>
                    },
                    constructor: parse_quote! {
                        ::#crate_name::TryFromPrimitiveError::<Self>::new
                    },
                }
            });

            EnumInfo {
                name,
                repr,
                variants,
                error_type_info,
            }
        })
    }
}
409 | |
410 | fn literal(i: i128) -> Expr { |
411 | Expr::Lit(ExprLit { |
412 | lit: Lit::Int(LitInt::new(&i.to_string(), Span::call_site())), |
413 | attrs: vec![], |
414 | }) |
415 | } |
416 | |
/// A discriminant or alternative value as understood by this module.
enum DiscriminantValue {
    /// An integer literal (optionally negated) whose value could be computed.
    Literal(i128),
    /// Any other expression, kept as-is because it cannot be evaluated here.
    Expr(Expr),
}
421 | |
422 | fn parse_discriminant(val_exp: &Expr) -> Result<DiscriminantValue> { |
423 | let mut sign: i128 = 1; |
424 | let mut unsigned_expr: &Expr = val_exp; |
425 | if let Expr::Unary(ExprUnary { |
426 | op: UnOp::Neg(..), |
427 | expr: &Box, |
428 | .. |
429 | }) = val_exp |
430 | { |
431 | unsigned_expr = expr; |
432 | sign = -1; |
433 | } |
434 | if let Expr::Lit(ExprLit { |
435 | lit: Lit::Int(ref lit_int: &LitInt), |
436 | .. |
437 | }) = unsigned_expr |
438 | { |
439 | Ok(DiscriminantValue::Literal( |
440 | sign * lit_int.base10_parse::<i128>()?, |
441 | )) |
442 | } else { |
443 | Ok(DiscriminantValue::Expr(val_exp.clone())) |
444 | } |
445 | } |
446 | |
/// Expands one `alternatives` expression into the values it stands for.
///
/// A range expression (`a..b` or `a..=b`) whose bounds are both integer
/// literals expands to one [`DiscriminantValue::Literal`] per integer in the
/// range. Anything else is handed to [`parse_discriminant`] and returned as a
/// single-element vector.
///
/// Bails out through `die!` when a range bound is missing or not a literal, or
/// when the lower bound is greater than the upper bound.
#[cfg(feature = "complex-expressions")]
fn parse_alternative_values(val_expr: &Expr) -> Result<Vec<DiscriminantValue>> {
    /// Extracts the literal value of one range bound, erroring (spanned on the
    /// whole range) when the bound is absent or not a literal.
    fn range_expr_value_to_number(
        parent_range_expr: &Expr,
        range_bound_value: &Option<Box<Expr>>,
    ) -> Result<i128> {
        // Avoid needing to calculate what the lower and upper bound would be - these are type dependent,
        // and also may not be obvious in context (e.g. an omitted bound could reasonably mean "from the last discriminant" or "from the lower bound of the type").
        if let Some(range_bound_value) = range_bound_value {
            let range_bound_value = parse_discriminant(range_bound_value.as_ref())?;
            // If non-literals are used, we can't expand to the mapped values, so can't write a nice match statement or do exhaustiveness checking.
            // Require literals instead.
            if let DiscriminantValue::Literal(value) = range_bound_value {
                return Ok(value);
            }
        }
        die!(parent_range_expr => "When ranges are used for alternate values, both bounds most be explicitly specified numeric literals")
    }

    if let Expr::Range(syn::ExprRange {
        start, end, limits, ..
    }) = val_expr
    {
        let lower = range_expr_value_to_number(val_expr, start)?;
        let upper = range_expr_value_to_number(val_expr, end)?;
        // While this is technically allowed in Rust, and results in an empty range, it's almost certainly a mistake in this context.
        if lower > upper {
            die!(val_expr => "When using ranges for alternate values, upper bound must not be less than lower bound");
        }
        // Let the standard range iterators do the stepping: no manual `+= 1`
        // that could overflow at `i128::MAX`, and no lossy `as usize` cast to
        // pre-size the vector.
        let values: Vec<DiscriminantValue> = match limits {
            syn::RangeLimits::HalfOpen(..) => {
                (lower..upper).map(DiscriminantValue::Literal).collect()
            }
            syn::RangeLimits::Closed(..) => {
                (lower..=upper).map(DiscriminantValue::Literal).collect()
            }
        };
        return Ok(values);
    }
    parse_discriminant(val_expr).map(|v| vec![v])
}
498 | |
499 | #[cfg (not(feature = "complex-expressions" ))] |
500 | fn parse_alternative_values(val_expr: &Expr) -> Result<Vec<DiscriminantValue>> { |
501 | parse_discriminant(val_expr).map(|v: DiscriminantValue| vec![v]) |
502 | } |
503 | |
/// What was learned about a single enum variant while parsing.
pub(crate) struct VariantInfo {
    /// Identifier of the variant.
    ident: Ident,
    /// Set when the variant carries `#[default]` or `#[num_enum(default)]`.
    is_default: bool,
    /// Set when the variant carries `#[num_enum(catch_all)]`.
    is_catch_all: bool,
    /// The variant's discriminant expression (explicit, or derived from the
    /// preceding variant when omitted).
    canonical_value: Expr,
    /// Additional values mapping to this variant, stored as sorted integer
    /// literals.
    alternative_values: Vec<Expr>,
}
511 | |
512 | impl VariantInfo { |
513 | fn all_values(&self) -> impl Iterator<Item = &Expr> { |
514 | ::core::iter::once(&self.canonical_value).chain(self.alternative_values.iter()) |
515 | } |
516 | } |
517 | |
/// The error type associated with the enum, and how to construct it.
pub(crate) struct ErrorType {
    /// Path naming the error type.
    pub(crate) name: Path,
    /// Path to the function used to build an error value.
    pub(crate) constructor: Path,
}
522 | |
523 | impl From<ErrorTypeAttribute> for ErrorType { |
524 | fn from(attribute: ErrorTypeAttribute) -> Self { |
525 | Self { |
526 | name: attribute.name.path, |
527 | constructor: attribute.constructor.path, |
528 | } |
529 | } |
530 | } |
531 | |
/// Resolves the name under which the `num_enum` crate is visible to the crate
/// being compiled, using `proc_macro_crate`.
///
/// If the lookup fails, a warning is printed to stderr and the name falls back
/// to `num_enum`.
#[cfg(feature = "proc-macro-crate")]
pub(crate) fn get_crate_name() -> String {
    let found_crate = proc_macro_crate::crate_name("num_enum").unwrap_or_else(|err| {
        eprintln!("Warning: {}\n => defaulting to `num_enum`", err,);
        proc_macro_crate::FoundCrate::Itself
    });

    match found_crate {
        proc_macro_crate::FoundCrate::Itself => String::from("num_enum"),
        proc_macro_crate::FoundCrate::Name(name) => name,
    }
}
544 | |
// Without the `proc-macro-crate` feature the crate name is fixed.
//
// proc-macro-crate is avoided in no_std environments because it causes an
// awkward dependency on serde with std. As a consequence, no_std dependees on
// num_enum cannot rename the num_enum crate when they depend on it. Sorry.
//
// See https://github.com/illicitonion/num_enum/issues/18
#[cfg(not(feature = "proc-macro-crate"))]
pub(crate) fn get_crate_name() -> String {
    "num_enum".to_owned()
}
555 | |