1 | use proc_macro2::TokenStream as TokenStream2; |
2 | use quote::{quote, ToTokens, TokenStreamExt}; |
3 | use syn::{parse_quote, Error, ItemStruct, Result, Type}; |
4 | |
5 | use super::{extract_documents, parse_pyo3_attrs, util::quote_option, Attr, MemberInfo, StubType}; |
6 | |
7 | pub struct PyClassInfo { |
8 | pyclass_name: String, |
9 | struct_type: Type, |
10 | module: Option<String>, |
11 | members: Vec<MemberInfo>, |
12 | doc: String, |
13 | } |
14 | |
15 | impl From<&PyClassInfo> for StubType { |
16 | fn from(info: &PyClassInfo) -> Self { |
17 | let PyClassInfo { |
18 | pyclass_name: &String, |
19 | module: &Option, |
20 | struct_type: &Type, |
21 | .. |
22 | } = info; |
23 | Self { |
24 | ty: struct_type.clone(), |
25 | name: pyclass_name.clone(), |
26 | module: module.clone(), |
27 | } |
28 | } |
29 | } |
30 | |
31 | impl TryFrom<ItemStruct> for PyClassInfo { |
32 | type Error = Error; |
33 | fn try_from(item: ItemStruct) -> Result<Self> { |
34 | let ItemStruct { |
35 | ident, |
36 | attrs, |
37 | fields, |
38 | .. |
39 | } = item; |
40 | let struct_type: Type = parse_quote!(#ident); |
41 | let mut pyclass_name = None; |
42 | let mut module = None; |
43 | let mut is_get_all = false; |
44 | for attr in parse_pyo3_attrs(&attrs)? { |
45 | match attr { |
46 | Attr::Name(name) => pyclass_name = Some(name), |
47 | Attr::Module(name) => { |
48 | module = Some(name); |
49 | } |
50 | Attr::GetAll => is_get_all = true, |
51 | _ => {} |
52 | } |
53 | } |
54 | let pyclass_name = pyclass_name.unwrap_or_else(|| ident.to_string()); |
55 | let mut members = Vec::new(); |
56 | for field in fields { |
57 | if is_get_all || MemberInfo::is_candidate_field(&field)? { |
58 | members.push(MemberInfo::try_from(field)?) |
59 | } |
60 | } |
61 | let doc = extract_documents(&attrs).join(" \n" ); |
62 | Ok(Self { |
63 | struct_type, |
64 | pyclass_name, |
65 | members, |
66 | module, |
67 | doc, |
68 | }) |
69 | } |
70 | } |
71 | |
72 | impl ToTokens for PyClassInfo { |
73 | fn to_tokens(&self, tokens: &mut TokenStream2) { |
74 | let Self { |
75 | pyclass_name: &String, |
76 | struct_type: &Type, |
77 | members: &Vec, |
78 | doc: &String, |
79 | module: &Option, |
80 | } = self; |
81 | let module: TokenStream = quote_option(module); |
82 | tokens.append_all(iter:quote! { |
83 | ::pyo3_stub_gen::type_info::PyClassInfo { |
84 | pyclass_name: #pyclass_name, |
85 | struct_id: std::any::TypeId::of::<#struct_type>, |
86 | members: &[ #( #members),* ], |
87 | module: #module, |
88 | doc: #doc, |
89 | } |
90 | }) |
91 | } |
92 | } |
93 | |
#[cfg(test)]
mod test {
    use super::*;
    use syn::parse_str;

    /// Parses a `#[pyclass]` struct with a custom name and module, converts it
    /// to [`PyClassInfo`] and snapshots the generated tokens.
    ///
    /// Only the three fields carrying `#[pyo3(get)]` are expected among the
    /// members; `custom_latex` has no such attribute and must be absent.
    #[test]
    fn test_pyclass() -> Result<()> {
        let input: ItemStruct = parse_str(
            r#"
            #[pyclass(mapping, module = "my_module", name = "Placeholder")]
            #[derive(
                Debug, Clone, PyNeg, PyAdd, PySub, PyMul, PyDiv, PyMod, PyPow, PyCmp, PyIndex, PyPrint,
            )]
            pub struct PyPlaceholder {
                #[pyo3(get)]
                pub name: String,
                #[pyo3(get)]
                pub ndim: usize,
                #[pyo3(get)]
                pub description: Option<String>,
                pub custom_latex: Option<String>,
            }
            "#,
        )?;
        let out = PyClassInfo::try_from(input)?.to_token_stream();
        // The nested indentation below is significant: only the common leading
        // indentation of an inline snapshot is stripped before comparison.
        insta::assert_snapshot!(format_as_value(out), @r###"
        ::pyo3_stub_gen::type_info::PyClassInfo {
            pyclass_name: "Placeholder",
            struct_id: std::any::TypeId::of::<PyPlaceholder>,
            members: &[
                ::pyo3_stub_gen::type_info::MemberInfo {
                    name: "name",
                    r#type: <String as ::pyo3_stub_gen::PyStubType>::type_output,
                },
                ::pyo3_stub_gen::type_info::MemberInfo {
                    name: "ndim",
                    r#type: <usize as ::pyo3_stub_gen::PyStubType>::type_output,
                },
                ::pyo3_stub_gen::type_info::MemberInfo {
                    name: "description",
                    r#type: <Option<String> as ::pyo3_stub_gen::PyStubType>::type_output,
                },
            ],
            module: Some("my_module"),
            doc: "",
        }
        "###);
        Ok(())
    }

    /// Pretty-prints a token stream that forms a single expression.
    ///
    /// The tokens are wrapped as `const _: () = <tokens>;` so they can be
    /// parsed as a file and formatted by `prettyplease`; the wrapper is then
    /// stripped again so that only the formatted expression is returned.
    fn format_as_value(tt: TokenStream2) -> String {
        let wrapped = quote! { const _: () = #tt; };
        let file = syn::parse_file(&wrapped.to_string())
            .expect("wrapped tokens should parse as a Rust file");
        let formatted = prettyplease::unparse(&file);
        formatted
            .trim()
            .strip_prefix("const _: () = ")
            .expect("formatted output should start with the const wrapper")
            .strip_suffix(';')
            .expect("formatted output should end with a semicolon")
            .to_string()
    }
}
156 | |