1 | use crate::utils::{AttrParams, DeriveType, State}; |
2 | use convert_case::{Case, Casing}; |
3 | use proc_macro2::TokenStream; |
4 | use quote::{format_ident, quote}; |
5 | use syn::{DeriveInput, Fields, Ident, Result}; |
6 | |
7 | pub fn expand(input: &DeriveInput, trait_name: &'static str) -> Result<TokenStream> { |
8 | let state = State::with_attr_params( |
9 | input, |
10 | trait_name, |
11 | quote!(), |
12 | String::from("unwrap" ), |
13 | AttrParams { |
14 | enum_: vec!["ignore" ], |
15 | variant: vec!["ignore" ], |
16 | struct_: vec!["ignore" ], |
17 | field: vec!["ignore" ], |
18 | }, |
19 | )?; |
20 | assert!( |
21 | state.derive_type == DeriveType::Enum, |
22 | "Unwrap can only be derived for enums" |
23 | ); |
24 | |
25 | let enum_name = &input.ident; |
26 | let (imp_generics, type_generics, where_clause) = input.generics.split_for_impl(); |
27 | |
28 | let mut funcs = vec![]; |
29 | for variant_state in state.enabled_variant_data().variant_states { |
30 | let variant = variant_state.variant.unwrap(); |
31 | let fn_name = Ident::new( |
32 | &format_ident!("unwrap_ {}" , variant.ident) |
33 | .to_string() |
34 | .to_case(Case::Snake), |
35 | variant.ident.span(), |
36 | ); |
37 | let variant_ident = &variant.ident; |
38 | |
39 | let (data_pattern, ret_value, ret_type) = match variant.fields { |
40 | Fields::Named(_) => panic!("cannot unwrap anonymous records" ), |
41 | Fields::Unnamed(ref fields) => { |
42 | let data_pattern = |
43 | (0..fields.unnamed.len()).fold(vec![], |mut a, n| { |
44 | a.push(format_ident!("field_ {}" , n)); |
45 | a |
46 | }); |
47 | let ret_type = &fields.unnamed; |
48 | ( |
49 | quote! { (#(#data_pattern),*) }, |
50 | quote! { (#(#data_pattern),*) }, |
51 | quote! { (#ret_type) }, |
52 | ) |
53 | } |
54 | Fields::Unit => (quote! {}, quote! { () }, quote! { () }), |
55 | }; |
56 | |
57 | let other_arms = state.variant_states.iter().map(|variant| { |
58 | variant.variant.unwrap() |
59 | }).filter(|variant| { |
60 | &variant.ident != variant_ident |
61 | }).map(|variant| { |
62 | let data_pattern = match variant.fields { |
63 | Fields::Named(_) => quote! { {..} }, |
64 | Fields::Unnamed(_) => quote! { (..) }, |
65 | Fields::Unit => quote! {}, |
66 | }; |
67 | let variant_ident = &variant.ident; |
68 | quote! { #enum_name :: #variant_ident #data_pattern => |
69 | panic!(concat!("called `" , stringify!(#enum_name), "::" , stringify!(#fn_name), |
70 | "()` on a `" , stringify!(#variant_ident), "` value" )) |
71 | } |
72 | }); |
73 | |
74 | // The `track-caller` feature is set by our build script based |
75 | // on rustc version detection, as `#[track_caller]` was |
76 | // stabilized in a later version (1.46) of Rust than our MSRV (1.36). |
77 | let track_caller = if cfg!(feature = "track-caller" ) { |
78 | quote! { #[track_caller] } |
79 | } else { |
80 | quote! {} |
81 | }; |
82 | let func = quote! { |
83 | #track_caller |
84 | pub fn #fn_name(self) -> #ret_type { |
85 | match self { |
86 | #enum_name ::#variant_ident #data_pattern => #ret_value, |
87 | #(#other_arms),* |
88 | } |
89 | } |
90 | }; |
91 | funcs.push(func); |
92 | } |
93 | |
94 | let imp = quote! { |
95 | impl #imp_generics #enum_name #type_generics #where_clause{ |
96 | #(#funcs)* |
97 | } |
98 | }; |
99 | |
100 | Ok(imp) |
101 | } |
102 | |