1use crate::utils::{AttrParams, DeriveType, State};
2use convert_case::{Case, Casing};
3use proc_macro2::TokenStream;
4use quote::{format_ident, quote};
5use syn::{DeriveInput, Fields, Ident, Result};
6
7pub fn expand(input: &DeriveInput, trait_name: &'static str) -> Result<TokenStream> {
8 let state = State::with_attr_params(
9 input,
10 trait_name,
11 quote!(),
12 String::from("unwrap"),
13 AttrParams {
14 enum_: vec!["ignore"],
15 variant: vec!["ignore"],
16 struct_: vec!["ignore"],
17 field: vec!["ignore"],
18 },
19 )?;
20 assert!(
21 state.derive_type == DeriveType::Enum,
22 "Unwrap can only be derived for enums"
23 );
24
25 let enum_name = &input.ident;
26 let (imp_generics, type_generics, where_clause) = input.generics.split_for_impl();
27
28 let mut funcs = vec![];
29 for variant_state in state.enabled_variant_data().variant_states {
30 let variant = variant_state.variant.unwrap();
31 let fn_name = Ident::new(
32 &format_ident!("unwrap_{}", variant.ident)
33 .to_string()
34 .to_case(Case::Snake),
35 variant.ident.span(),
36 );
37 let variant_ident = &variant.ident;
38
39 let (data_pattern, ret_value, ret_type) = match variant.fields {
40 Fields::Named(_) => panic!("cannot unwrap anonymous records"),
41 Fields::Unnamed(ref fields) => {
42 let data_pattern =
43 (0..fields.unnamed.len()).fold(vec![], |mut a, n| {
44 a.push(format_ident!("field_{}", n));
45 a
46 });
47 let ret_type = &fields.unnamed;
48 (
49 quote! { (#(#data_pattern),*) },
50 quote! { (#(#data_pattern),*) },
51 quote! { (#ret_type) },
52 )
53 }
54 Fields::Unit => (quote! {}, quote! { () }, quote! { () }),
55 };
56
57 let other_arms = state.variant_states.iter().map(|variant| {
58 variant.variant.unwrap()
59 }).filter(|variant| {
60 &variant.ident != variant_ident
61 }).map(|variant| {
62 let data_pattern = match variant.fields {
63 Fields::Named(_) => quote! { {..} },
64 Fields::Unnamed(_) => quote! { (..) },
65 Fields::Unit => quote! {},
66 };
67 let variant_ident = &variant.ident;
68 quote! { #enum_name :: #variant_ident #data_pattern =>
69 panic!(concat!("called `", stringify!(#enum_name), "::", stringify!(#fn_name),
70 "()` on a `", stringify!(#variant_ident), "` value"))
71 }
72 });
73
74 // The `track-caller` feature is set by our build script based
75 // on rustc version detection, as `#[track_caller]` was
76 // stabilized in a later version (1.46) of Rust than our MSRV (1.36).
77 let track_caller = if cfg!(feature = "track-caller") {
78 quote! { #[track_caller] }
79 } else {
80 quote! {}
81 };
82 let func = quote! {
83 #track_caller
84 pub fn #fn_name(self) -> #ret_type {
85 match self {
86 #enum_name ::#variant_ident #data_pattern => #ret_value,
87 #(#other_arms),*
88 }
89 }
90 };
91 funcs.push(func);
92 }
93
94 let imp = quote! {
95 impl #imp_generics #enum_name #type_generics #where_clause{
96 #(#funcs)*
97 }
98 };
99
100 Ok(imp)
101}
102