| 1 | use crate::utils::{AttrParams, DeriveType, State}; |
| 2 | use convert_case::{Case, Casing}; |
| 3 | use proc_macro2::TokenStream; |
| 4 | use quote::{format_ident, quote}; |
| 5 | use syn::{DeriveInput, Fields, Result}; |
| 6 | |
| 7 | pub fn expand(input: &DeriveInput, trait_name: &'static str) -> Result<TokenStream> { |
| 8 | let state = State::with_attr_params( |
| 9 | input, |
| 10 | trait_name, |
| 11 | "is_variant" .into(), |
| 12 | AttrParams { |
| 13 | enum_: vec!["ignore" ], |
| 14 | variant: vec!["ignore" ], |
| 15 | struct_: vec!["ignore" ], |
| 16 | field: vec!["ignore" ], |
| 17 | }, |
| 18 | )?; |
| 19 | assert!( |
| 20 | state.derive_type == DeriveType::Enum, |
| 21 | "IsVariant can only be derived for enums" , |
| 22 | ); |
| 23 | |
| 24 | let enum_name = &input.ident; |
| 25 | let (imp_generics, type_generics, where_clause) = input.generics.split_for_impl(); |
| 26 | |
| 27 | let mut funcs = vec![]; |
| 28 | for variant_state in state.enabled_variant_data().variant_states { |
| 29 | let variant = variant_state.variant.unwrap(); |
| 30 | let fn_name = format_ident!( |
| 31 | "is_ {}" , |
| 32 | variant.ident.to_string().to_case(Case::Snake), |
| 33 | span = variant.ident.span(), |
| 34 | ); |
| 35 | let variant_ident = &variant.ident; |
| 36 | |
| 37 | let data_pattern = match variant.fields { |
| 38 | Fields::Named(_) => quote! { {..} }, |
| 39 | Fields::Unnamed(_) => quote! { (..) }, |
| 40 | Fields::Unit => quote! {}, |
| 41 | }; |
| 42 | let func = quote! { |
| 43 | #[doc = "Returns `true` if this value is of type `" ] |
| 44 | #[doc = stringify!(#variant_ident)] |
| 45 | #[doc = "`. Returns `false` otherwise" ] |
| 46 | #[inline] |
| 47 | #[must_use] |
| 48 | pub const fn #fn_name(&self) -> bool { |
| 49 | derive_more::core::matches!(self, #enum_name ::#variant_ident #data_pattern) |
| 50 | } |
| 51 | }; |
| 52 | funcs.push(func); |
| 53 | } |
| 54 | |
| 55 | let imp = quote! { |
| 56 | #[allow(unreachable_code)] // omit warnings for `!` and other unreachable types |
| 57 | #[automatically_derived] |
| 58 | impl #imp_generics #enum_name #type_generics #where_clause { |
| 59 | #(#funcs)* |
| 60 | } |
| 61 | }; |
| 62 | |
| 63 | Ok(imp) |
| 64 | } |
| 65 | |