1 | use crate::enum_attributes::ErrorTypeAttribute; |
2 | use crate::utils::die; |
3 | use crate::variant_attributes::{NumEnumVariantAttributeItem, NumEnumVariantAttributes}; |
4 | use proc_macro2::Span; |
5 | use quote::{format_ident, ToTokens}; |
6 | use std::collections::BTreeSet; |
7 | use syn::{ |
8 | parse::{Parse, ParseStream}, |
9 | parse_quote, Attribute, Data, DeriveInput, Expr, ExprLit, ExprUnary, Fields, Ident, Lit, |
10 | LitInt, Meta, Path, Result, UnOp, |
11 | }; |
12 | |
13 | pub(crate) struct EnumInfo { |
14 | pub(crate) name: Ident, |
15 | pub(crate) repr: Ident, |
16 | pub(crate) variants: Vec<VariantInfo>, |
17 | pub(crate) error_type_info: ErrorType, |
18 | } |
19 | |
20 | impl EnumInfo { |
21 | /// Returns whether the number of variants (ignoring defaults, catch-alls, etc) is the same as |
22 | /// the capacity of the repr. |
23 | pub(crate) fn is_naturally_exhaustive(&self) -> Result<bool> { |
24 | let repr_str = self.repr.to_string(); |
25 | if !repr_str.is_empty() { |
26 | let suffix = repr_str |
27 | .strip_prefix('i' ) |
28 | .or_else(|| repr_str.strip_prefix('u' )); |
29 | if let Some(suffix) = suffix { |
30 | if suffix == "size" { |
31 | return Ok(false); |
32 | } else if let Ok(bits) = suffix.parse::<u32>() { |
33 | let variants = 1usize.checked_shl(bits); |
34 | return Ok(variants.map_or(false, |v| { |
35 | v == self |
36 | .variants |
37 | .iter() |
38 | .map(|v| v.alternative_values.len() + 1) |
39 | .sum() |
40 | })); |
41 | } |
42 | } |
43 | } |
44 | die!(self.repr.clone() => "Failed to parse repr into bit size" ); |
45 | } |
46 | |
47 | pub(crate) fn default(&self) -> Option<&Ident> { |
48 | self.variants |
49 | .iter() |
50 | .find(|info| info.is_default) |
51 | .map(|info| &info.ident) |
52 | } |
53 | |
54 | pub(crate) fn catch_all(&self) -> Option<&Ident> { |
55 | self.variants |
56 | .iter() |
57 | .find(|info| info.is_catch_all) |
58 | .map(|info| &info.ident) |
59 | } |
60 | |
61 | pub(crate) fn variant_idents(&self) -> Vec<Ident> { |
62 | self.variants |
63 | .iter() |
64 | .map(|variant| variant.ident.clone()) |
65 | .collect() |
66 | } |
67 | |
68 | pub(crate) fn expression_idents(&self) -> Vec<Vec<Ident>> { |
69 | self.variants |
70 | .iter() |
71 | .filter(|variant| !variant.is_catch_all) |
72 | .map(|info| { |
73 | let indices = 0..(info.alternative_values.len() + 1); |
74 | indices |
75 | .map(|index| format_ident!(" {}__num_enum_ {}__" , info.ident, index)) |
76 | .collect() |
77 | }) |
78 | .collect() |
79 | } |
80 | |
81 | pub(crate) fn variant_expressions(&self) -> Vec<Vec<Expr>> { |
82 | self.variants |
83 | .iter() |
84 | .map(|variant| variant.all_values().cloned().collect()) |
85 | .collect() |
86 | } |
87 | |
88 | fn parse_attrs<Attrs: Iterator<Item = Attribute>>( |
89 | attrs: Attrs, |
90 | ) -> Result<(Ident, Option<ErrorType>)> { |
91 | let mut maybe_repr = None; |
92 | let mut maybe_error_type = None; |
93 | for attr in attrs { |
94 | if let Meta::List(meta_list) = &attr.meta { |
95 | if let Some(ident) = meta_list.path.get_ident() { |
96 | if ident == "repr" { |
97 | let mut nested = meta_list.tokens.clone().into_iter(); |
98 | let repr_tree = match (nested.next(), nested.next()) { |
99 | (Some(repr_tree), None) => repr_tree, |
100 | _ => die!(attr => |
101 | "Expected exactly one `repr` argument" |
102 | ), |
103 | }; |
104 | let repr_ident: Ident = parse_quote! { |
105 | #repr_tree |
106 | }; |
107 | if repr_ident == "C" { |
108 | die!(repr_ident => |
109 | "repr(C) doesn't have a well defined size" |
110 | ); |
111 | } else { |
112 | maybe_repr = Some(repr_ident); |
113 | } |
114 | } else if ident == "num_enum" { |
115 | let attributes = |
116 | attr.parse_args_with(crate::enum_attributes::Attributes::parse)?; |
117 | if let Some(error_type) = attributes.error_type { |
118 | if maybe_error_type.is_some() { |
119 | die!(attr => "At most one num_enum error_type attribute may be specified" ); |
120 | } |
121 | maybe_error_type = Some(error_type.into()); |
122 | } |
123 | } |
124 | } |
125 | } |
126 | } |
127 | if maybe_repr.is_none() { |
128 | die!("Missing `#[repr({Integer})]` attribute" ); |
129 | } |
130 | Ok((maybe_repr.unwrap(), maybe_error_type)) |
131 | } |
132 | } |
133 | |
134 | impl Parse for EnumInfo { |
135 | fn parse(input: ParseStream) -> Result<Self> { |
136 | Ok({ |
137 | let input: DeriveInput = input.parse()?; |
138 | let name = input.ident; |
139 | let data = match input.data { |
140 | Data::Enum(data) => data, |
141 | Data::Union(data) => die!(data.union_token => "Expected enum but found union" ), |
142 | Data::Struct(data) => die!(data.struct_token => "Expected enum but found struct" ), |
143 | }; |
144 | |
145 | let (repr, maybe_error_type) = Self::parse_attrs(input.attrs.into_iter())?; |
146 | |
147 | let mut variants: Vec<VariantInfo> = vec![]; |
148 | let mut has_default_variant: bool = false; |
149 | let mut has_catch_all_variant: bool = false; |
150 | |
151 | // Vec to keep track of the used discriminants and alt values. |
152 | let mut discriminant_int_val_set = BTreeSet::new(); |
153 | |
154 | let mut next_discriminant = literal(0); |
155 | for variant in data.variants.into_iter() { |
156 | let ident = variant.ident.clone(); |
157 | |
158 | let discriminant = match &variant.discriminant { |
159 | Some(d) => d.1.clone(), |
160 | None => next_discriminant.clone(), |
161 | }; |
162 | |
163 | let mut raw_alternative_values: Vec<Expr> = vec![]; |
164 | // Keep the attribute around for better error reporting. |
165 | let mut alt_attr_ref: Vec<&Attribute> = vec![]; |
166 | |
167 | // `#[num_enum(default)]` is required by `#[derive(FromPrimitive)]` |
168 | // and forbidden by `#[derive(UnsafeFromPrimitive)]`, so we need to |
169 | // keep track of whether we encountered such an attribute: |
170 | let mut is_default: bool = false; |
171 | let mut is_catch_all: bool = false; |
172 | |
173 | for attribute in &variant.attrs { |
174 | if attribute.path().is_ident("default" ) { |
175 | if has_default_variant { |
176 | die!(attribute => |
177 | "Multiple variants marked `#[default]` or `#[num_enum(default)]` found" |
178 | ); |
179 | } else if has_catch_all_variant { |
180 | die!(attribute => |
181 | "Attribute `default` is mutually exclusive with `catch_all`" |
182 | ); |
183 | } |
184 | is_default = true; |
185 | has_default_variant = true; |
186 | } |
187 | |
188 | if attribute.path().is_ident("num_enum" ) { |
189 | match attribute.parse_args_with(NumEnumVariantAttributes::parse) { |
190 | Ok(variant_attributes) => { |
191 | for variant_attribute in variant_attributes.items { |
192 | match variant_attribute { |
193 | NumEnumVariantAttributeItem::Default(default) => { |
194 | if has_default_variant { |
195 | die!(default.keyword => |
196 | "Multiple variants marked `#[default]` or `#[num_enum(default)]` found" |
197 | ); |
198 | } else if has_catch_all_variant { |
199 | die!(default.keyword => |
200 | "Attribute `default` is mutually exclusive with `catch_all`" |
201 | ); |
202 | } |
203 | is_default = true; |
204 | has_default_variant = true; |
205 | } |
206 | NumEnumVariantAttributeItem::CatchAll(catch_all) => { |
207 | if has_catch_all_variant { |
208 | die!(catch_all.keyword => |
209 | "Multiple variants marked with `#[num_enum(catch_all)]`" |
210 | ); |
211 | } else if has_default_variant { |
212 | die!(catch_all.keyword => |
213 | "Attribute `catch_all` is mutually exclusive with `default`" |
214 | ); |
215 | } |
216 | |
217 | match variant |
218 | .fields |
219 | .iter() |
220 | .collect::<Vec<_>>() |
221 | .as_slice() |
222 | { |
223 | [syn::Field { |
224 | ty: syn::Type::Path(syn::TypePath { path, .. }), |
225 | .. |
226 | }] if path.is_ident(&repr) => { |
227 | is_catch_all = true; |
228 | has_catch_all_variant = true; |
229 | } |
230 | _ => { |
231 | die!(catch_all.keyword => |
232 | "Variant with `catch_all` must be a tuple with exactly 1 field matching the repr type" |
233 | ); |
234 | } |
235 | } |
236 | } |
237 | NumEnumVariantAttributeItem::Alternatives(alternatives) => { |
238 | raw_alternative_values.extend(alternatives.expressions); |
239 | alt_attr_ref.push(attribute); |
240 | } |
241 | } |
242 | } |
243 | } |
244 | Err(err) => { |
245 | if cfg!(not(feature = "complex-expressions" )) { |
246 | let tokens = attribute.meta.to_token_stream(); |
247 | |
248 | let attribute_str = format!(" {}" , tokens); |
249 | if attribute_str.contains("alternatives" ) |
250 | && attribute_str.contains(".." ) |
251 | { |
252 | // Give a nice error message suggesting how to fix the problem. |
253 | die!(attribute => "Ranges are only supported as num_enum alternate values if the `complex-expressions` feature of the crate `num_enum` is enabled" .to_string()) |
254 | } |
255 | } |
256 | die!(attribute => |
257 | format!("Invalid attribute: {}" , err) |
258 | ); |
259 | } |
260 | } |
261 | } |
262 | } |
263 | |
264 | if !is_catch_all { |
265 | match &variant.fields { |
266 | Fields::Named(_) | Fields::Unnamed(_) => { |
267 | die!(variant => format!("` {}` only supports unit variants (with no associated data), but ` {}:: {}` was not a unit variant." , get_crate_name(), name, ident)); |
268 | } |
269 | Fields::Unit => {} |
270 | } |
271 | } |
272 | |
273 | let discriminant_value = parse_discriminant(&discriminant)?; |
274 | |
275 | // Check for collision. |
276 | // We can't do const evaluation, or even compare arbitrary Exprs, |
277 | // so unfortunately we can't check for duplicates. |
278 | // That's not the end of the world, just we'll end up with compile errors for |
279 | // matches with duplicate branches in generated code instead of nice friendly error messages. |
280 | if let DiscriminantValue::Literal(canonical_value_int) = discriminant_value { |
281 | if discriminant_int_val_set.contains(&canonical_value_int) { |
282 | die!(ident => format!("The discriminant ' {}' collides with a value attributed to a previous variant" , canonical_value_int)) |
283 | } |
284 | } |
285 | |
286 | // Deal with the alternative values. |
287 | let mut flattened_alternative_values = Vec::new(); |
288 | let mut flattened_raw_alternative_values = Vec::new(); |
289 | for raw_alternative_value in raw_alternative_values { |
290 | let expanded_values = parse_alternative_values(&raw_alternative_value)?; |
291 | for expanded_value in expanded_values { |
292 | flattened_alternative_values.push(expanded_value); |
293 | flattened_raw_alternative_values.push(raw_alternative_value.clone()) |
294 | } |
295 | } |
296 | |
297 | if !flattened_alternative_values.is_empty() { |
298 | let alternate_int_values = flattened_alternative_values |
299 | .into_iter() |
300 | .map(|v| { |
301 | match v { |
302 | DiscriminantValue::Literal(value) => Ok(value), |
303 | DiscriminantValue::Expr(expr) => { |
304 | if let Expr::Range(_) = expr { |
305 | if cfg!(not(feature = "complex-expressions" )) { |
306 | // Give a nice error message suggesting how to fix the problem. |
307 | die!(expr => "Ranges are only supported as num_enum alternate values if the `complex-expressions` feature of the crate `num_enum` is enabled" .to_string()) |
308 | } |
309 | } |
310 | // We can't do uniqueness checking on non-literals, so we don't allow them as alternate values. |
311 | // We could probably allow them, but there doesn't seem to be much of a use-case, |
312 | // and it's easier to give good error messages about duplicate values this way, |
313 | // rather than rustc errors on conflicting match branches. |
314 | die!(expr => "Only literals are allowed as num_enum alternate values" .to_string()) |
315 | }, |
316 | } |
317 | }) |
318 | .collect::<Result<Vec<i128>>>()?; |
319 | let mut sorted_alternate_int_values = alternate_int_values.clone(); |
320 | sorted_alternate_int_values.sort_unstable(); |
321 | let sorted_alternate_int_values = sorted_alternate_int_values; |
322 | |
323 | // Check if the current discriminant is not in the alternative values. |
324 | if let DiscriminantValue::Literal(canonical_value_int) = discriminant_value { |
325 | if let Some(index) = alternate_int_values |
326 | .iter() |
327 | .position(|&x| x == canonical_value_int) |
328 | { |
329 | die!(&flattened_raw_alternative_values[index] => format!("' {}' in the alternative values is already attributed as the discriminant of this variant" , canonical_value_int)); |
330 | } |
331 | } |
332 | |
333 | // Search for duplicates, the vec is sorted. Warn about them. |
334 | if (1..sorted_alternate_int_values.len()).any(|i| { |
335 | sorted_alternate_int_values[i] == sorted_alternate_int_values[i - 1] |
336 | }) { |
337 | let attr = *alt_attr_ref.last().unwrap(); |
338 | die!(attr => "There is duplication in the alternative values" ); |
339 | } |
340 | // Search if those discriminant_int_val_set where already attributed. |
341 | // (discriminant_int_val_set is BTreeSet, and iter().next_back() is the is the maximum in the set.) |
342 | if let Some(last_upper_val) = discriminant_int_val_set.iter().next_back() { |
343 | if sorted_alternate_int_values.first().unwrap() <= last_upper_val { |
344 | for (index, val) in alternate_int_values.iter().enumerate() { |
345 | if discriminant_int_val_set.contains(val) { |
346 | die!(&flattened_raw_alternative_values[index] => format!("' {}' in the alternative values is already attributed to a previous variant" , val)); |
347 | } |
348 | } |
349 | } |
350 | } |
351 | |
352 | // Reconstruct the alternative_values vec of Expr but sorted. |
353 | flattened_raw_alternative_values = sorted_alternate_int_values |
354 | .iter() |
355 | .map(|val| literal(val.to_owned())) |
356 | .collect(); |
357 | |
358 | // Add the alternative values to the the set to keep track. |
359 | discriminant_int_val_set.extend(sorted_alternate_int_values); |
360 | } |
361 | |
362 | // Add the current discriminant to the the set to keep track. |
363 | if let DiscriminantValue::Literal(canonical_value_int) = discriminant_value { |
364 | discriminant_int_val_set.insert(canonical_value_int); |
365 | } |
366 | |
367 | variants.push(VariantInfo { |
368 | ident, |
369 | is_default, |
370 | is_catch_all, |
371 | canonical_value: discriminant, |
372 | alternative_values: flattened_raw_alternative_values, |
373 | }); |
374 | |
375 | // Get the next value for the discriminant. |
376 | next_discriminant = match discriminant_value { |
377 | DiscriminantValue::Literal(int_value) => literal(int_value.wrapping_add(1)), |
378 | DiscriminantValue::Expr(expr) => { |
379 | parse_quote! { |
380 | #repr::wrapping_add(#expr, 1) |
381 | } |
382 | } |
383 | } |
384 | } |
385 | |
386 | let error_type_info = maybe_error_type.unwrap_or_else(|| { |
387 | let crate_name = Ident::new(&get_crate_name(), Span::call_site()); |
388 | ErrorType { |
389 | name: parse_quote! { |
390 | ::#crate_name::TryFromPrimitiveError<Self> |
391 | }, |
392 | constructor: parse_quote! { |
393 | ::#crate_name::TryFromPrimitiveError::<Self>::new |
394 | }, |
395 | } |
396 | }); |
397 | |
398 | EnumInfo { |
399 | name, |
400 | repr, |
401 | variants, |
402 | error_type_info, |
403 | } |
404 | }) |
405 | } |
406 | } |
407 | |
408 | fn literal(i: i128) -> Expr { |
409 | Expr::Lit(ExprLit { |
410 | lit: Lit::Int(LitInt::new(&i.to_string(), Span::call_site())), |
411 | attrs: vec![], |
412 | }) |
413 | } |
414 | |
415 | enum DiscriminantValue { |
416 | Literal(i128), |
417 | Expr(Expr), |
418 | } |
419 | |
420 | fn parse_discriminant(val_exp: &Expr) -> Result<DiscriminantValue> { |
421 | let mut sign: i128 = 1; |
422 | let mut unsigned_expr: &Expr = val_exp; |
423 | if let Expr::Unary(ExprUnary { |
424 | op: UnOp::Neg(..), |
425 | expr: &Box, |
426 | .. |
427 | }) = val_exp |
428 | { |
429 | unsigned_expr = expr; |
430 | sign = -1; |
431 | } |
432 | if let Expr::Lit(ExprLit { |
433 | lit: Lit::Int(ref lit_int: &LitInt), |
434 | .. |
435 | }) = unsigned_expr |
436 | { |
437 | Ok(DiscriminantValue::Literal( |
438 | sign * lit_int.base10_parse::<i128>()?, |
439 | )) |
440 | } else { |
441 | Ok(DiscriminantValue::Expr(val_exp.clone())) |
442 | } |
443 | } |
444 | |
/// Expands one alternative-value expression into the discriminant values it denotes.
///
/// A range with literal bounds (`a..b` or `a..=b`) expands to every integer it contains. Any
/// other expression is classified by [`parse_discriminant`] and yields a single value.
///
/// # Errors
///
/// Fails when either range bound is missing or is not an (optionally negated) integer literal,
/// or when the lower bound exceeds the upper bound.
#[cfg(feature = "complex-expressions")]
fn parse_alternative_values(val_expr: &Expr) -> Result<Vec<DiscriminantValue>> {
    /// Resolves one range bound to its literal value. Errors are spanned on the whole range.
    fn range_expr_value_to_number(
        parent_range_expr: &Expr,
        range_bound_value: &Option<Box<Expr>>,
    ) -> Result<i128> {
        // Avoid needing to calculate what the lower and upper bound would be: these are type
        // dependent, and may not be obvious in context (an omitted bound could reasonably mean
        // "from the last discriminant" or "from the lower bound of the type").
        if let Some(range_bound_value) = range_bound_value {
            // Non-literal bounds can't be expanded into concrete values, so they can't be
            // used for match arms or exhaustiveness checking. Require literals instead.
            if let DiscriminantValue::Literal(value) = parse_discriminant(range_bound_value)? {
                return Ok(value);
            }
        }
        die!(parent_range_expr => "When ranges are used for alternate values, both bounds most be explicitly specified numeric literals")
    }

    let (start, end, limits) = match val_expr {
        Expr::Range(syn::ExprRange {
            start, end, limits, ..
        }) => (start, end, limits),
        _ => return parse_discriminant(val_expr).map(|value| vec![value]),
    };

    let lower = range_expr_value_to_number(val_expr, start)?;
    let upper = range_expr_value_to_number(val_expr, end)?;
    // Rust allows this (it gives an empty range), but here it is almost certainly a mistake.
    if lower > upper {
        die!(val_expr => "When using ranges for alternate values, upper bound must not be less than lower bound");
    }
    // The std range types handle the end condition themselves. Unlike a hand-rolled counter,
    // they cannot overflow at `i128::MAX`, and no `upper - lower` capacity arithmetic is needed.
    let values = match limits {
        syn::RangeLimits::HalfOpen(..) => {
            (lower..upper).map(DiscriminantValue::Literal).collect()
        }
        syn::RangeLimits::Closed(..) => {
            (lower..=upper).map(DiscriminantValue::Literal).collect()
        }
    };
    Ok(values)
}
496 | |
497 | #[cfg (not(feature = "complex-expressions" ))] |
498 | fn parse_alternative_values(val_expr: &Expr) -> Result<Vec<DiscriminantValue>> { |
499 | parse_discriminant(val_expr).map(|v: DiscriminantValue| vec![v]) |
500 | } |
501 | |
502 | pub(crate) struct VariantInfo { |
503 | ident: Ident, |
504 | is_default: bool, |
505 | is_catch_all: bool, |
506 | canonical_value: Expr, |
507 | alternative_values: Vec<Expr>, |
508 | } |
509 | |
510 | impl VariantInfo { |
511 | fn all_values(&self) -> impl Iterator<Item = &Expr> { |
512 | ::core::iter::once(&self.canonical_value).chain(self.alternative_values.iter()) |
513 | } |
514 | } |
515 | |
516 | pub(crate) struct ErrorType { |
517 | pub(crate) name: Path, |
518 | pub(crate) constructor: Path, |
519 | } |
520 | |
521 | impl From<ErrorTypeAttribute> for ErrorType { |
522 | fn from(attribute: ErrorTypeAttribute) -> Self { |
523 | Self { |
524 | name: attribute.name.path, |
525 | constructor: attribute.constructor.path, |
526 | } |
527 | } |
528 | } |
529 | |
/// Returns the name under which the `num_enum` crate is visible to the crate being compiled.
///
/// The name is looked up through `proc_macro_crate`. If the lookup fails, a warning is printed
/// to stderr and `num_enum` is assumed. When the derive is used from within `num_enum` itself,
/// the name is also `num_enum`.
#[cfg(feature = "proc-macro-crate")]
pub(crate) fn get_crate_name() -> String {
    let found_crate = proc_macro_crate::crate_name("num_enum").unwrap_or_else(|err| {
        eprintln!("Warning: {}\n => defaulting to `num_enum`", err);
        proc_macro_crate::FoundCrate::Itself
    });

    match found_crate {
        proc_macro_crate::FoundCrate::Itself => String::from("num_enum"),
        proc_macro_crate::FoundCrate::Name(name) => name,
    }
}
542 | |
// Without the `proc-macro-crate` feature we avoid depending on proc-macro-crate, which pulls in
// serde with std and is awkward for no_std users. The trade-off: no_std dependees on num_enum
// cannot rename the num_enum crate when they depend on it. Sorry.
//
// See https://github.com/illicitonion/num_enum/issues/18
/// Returns the fixed crate name `num_enum`; renamed dependencies are not detected.
#[cfg(not(feature = "proc-macro-crate"))]
pub(crate) fn get_crate_name() -> String {
    "num_enum".to_owned()
}
553 | |