1 | use crate::utils::{AttrParams, DeriveType, State}; |
2 | use convert_case::{Case, Casing}; |
3 | use proc_macro2::TokenStream; |
4 | use quote::{format_ident, quote}; |
5 | use syn::{DeriveInput, Fields, Ident, Result}; |
6 | |
7 | pub fn expand(input: &DeriveInput, trait_name: &'static str) -> Result<TokenStream> { |
8 | let state = State::with_attr_params( |
9 | input, |
10 | trait_name, |
11 | quote!(), |
12 | String::from("is_variant" ), |
13 | AttrParams { |
14 | enum_: vec!["ignore" ], |
15 | variant: vec!["ignore" ], |
16 | struct_: vec!["ignore" ], |
17 | field: vec!["ignore" ], |
18 | }, |
19 | )?; |
20 | assert!( |
21 | state.derive_type == DeriveType::Enum, |
22 | "IsVariant can only be derived for enums" |
23 | ); |
24 | |
25 | let enum_name = &input.ident; |
26 | let (imp_generics, type_generics, where_clause) = input.generics.split_for_impl(); |
27 | |
28 | let mut funcs = vec![]; |
29 | for variant_state in state.enabled_variant_data().variant_states { |
30 | let variant = variant_state.variant.unwrap(); |
31 | let fn_name = Ident::new( |
32 | &format_ident!("is_ {}" , variant.ident) |
33 | .to_string() |
34 | .to_case(Case::Snake), |
35 | variant.ident.span(), |
36 | ); |
37 | let variant_ident = &variant.ident; |
38 | |
39 | let data_pattern = match variant.fields { |
40 | Fields::Named(_) => quote! { {..} }, |
41 | Fields::Unnamed(_) => quote! { (..) }, |
42 | Fields::Unit => quote! {}, |
43 | }; |
44 | let func = quote! { |
45 | pub fn #fn_name(&self) -> bool { |
46 | match self { |
47 | #enum_name ::#variant_ident #data_pattern => true, |
48 | _ => false |
49 | } |
50 | } |
51 | }; |
52 | funcs.push(func); |
53 | } |
54 | |
55 | let imp = quote! { |
56 | impl #imp_generics #enum_name #type_generics #where_clause{ |
57 | #(#funcs)* |
58 | } |
59 | }; |
60 | |
61 | Ok(imp) |
62 | } |
63 | |