1use proc_macro2::TokenStream as TokenStream2;
2use quote::{quote, ToTokens, TokenStreamExt};
3use syn::{parse_quote, Error, ItemEnum, Result, Type};
4
5use super::{extract_documents, parse_pyo3_attrs, util::quote_option, Attr, StubType};
6
7pub struct PyEnumInfo {
8 pyclass_name: String,
9 enum_type: Type,
10 module: Option<String>,
11 variants: Vec<String>,
12 doc: String,
13}
14
15impl From<&PyEnumInfo> for StubType {
16 fn from(info: &PyEnumInfo) -> Self {
17 let PyEnumInfo {
18 pyclass_name: &String,
19 module: &Option,
20 enum_type: &Type,
21 ..
22 } = info;
23 Self {
24 ty: enum_type.clone(),
25 name: pyclass_name.clone(),
26 module: module.clone(),
27 }
28 }
29}
30
31impl TryFrom<ItemEnum> for PyEnumInfo {
32 type Error = Error;
33 fn try_from(
34 ItemEnum {
35 variants,
36 attrs,
37 ident,
38 ..
39 }: ItemEnum,
40 ) -> Result<Self> {
41 let doc = extract_documents(&attrs).join("\n");
42 let mut pyclass_name = None;
43 let mut module = None;
44 for attr in parse_pyo3_attrs(&attrs)? {
45 match attr {
46 Attr::Name(name) => pyclass_name = Some(name),
47 Attr::Module(name) => module = Some(name),
48 _ => {}
49 }
50 }
51 let struct_type = parse_quote!(#ident);
52 let pyclass_name = pyclass_name.unwrap_or_else(|| ident.to_string());
53 let variants = variants
54 .into_iter()
55 .map(|var| -> Result<String> {
56 let mut var_name = None;
57 for attr in parse_pyo3_attrs(&var.attrs)? {
58 if let Attr::Name(name) = attr {
59 var_name = Some(name);
60 }
61 }
62 Ok(var_name.unwrap_or_else(|| var.ident.to_string()))
63 })
64 .collect::<Result<Vec<String>>>()?;
65 Ok(Self {
66 doc,
67 enum_type: struct_type,
68 pyclass_name,
69 module,
70 variants,
71 })
72 }
73}
74
75impl ToTokens for PyEnumInfo {
76 fn to_tokens(&self, tokens: &mut TokenStream2) {
77 let Self {
78 pyclass_name: &String,
79 enum_type: &Type,
80 variants: &Vec,
81 doc: &String,
82 module: &Option,
83 } = self;
84 let module: TokenStream = quote_option(module);
85 tokens.append_all(iter:quote! {
86 ::pyo3_stub_gen::type_info::PyEnumInfo {
87 pyclass_name: #pyclass_name,
88 enum_id: std::any::TypeId::of::<#enum_type>,
89 variants: &[ #(#variants),* ],
90 module: #module,
91 doc: #doc,
92 }
93 })
94 }
95}
96