//! Code generation for the function that initializes a Python module and adds classes and functions.
2 | |
3 | use crate::{ |
4 | attributes::{self, take_attributes, take_pyo3_options, CrateAttribute, NameAttribute}, |
5 | pyfunction::{impl_wrap_pyfunction, PyFunctionOptions}, |
6 | utils::{get_pyo3_crate, PythonDoc}, |
7 | }; |
8 | use proc_macro2::TokenStream; |
9 | use quote::quote; |
10 | use syn::{ |
11 | ext::IdentExt, |
12 | parse::{Parse, ParseStream}, |
13 | spanned::Spanned, |
14 | token::Comma, |
15 | Ident, Path, Result, Visibility, |
16 | }; |
17 | |
18 | #[derive (Default)] |
19 | pub struct PyModuleOptions { |
20 | krate: Option<CrateAttribute>, |
21 | name: Option<syn::Ident>, |
22 | } |
23 | |
24 | impl PyModuleOptions { |
25 | pub fn from_attrs(attrs: &mut Vec<syn::Attribute>) -> Result<Self> { |
26 | let mut options: PyModuleOptions = Default::default(); |
27 | |
28 | for option in take_pyo3_options(attrs)? { |
29 | match option { |
30 | PyModulePyO3Option::Name(name) => options.set_name(name.value.0)?, |
31 | PyModulePyO3Option::Crate(path) => options.set_crate(path)?, |
32 | } |
33 | } |
34 | |
35 | Ok(options) |
36 | } |
37 | |
38 | fn set_name(&mut self, name: syn::Ident) -> Result<()> { |
39 | ensure_spanned!( |
40 | self.name.is_none(), |
41 | name.span() => "`name` may only be specified once" |
42 | ); |
43 | |
44 | self.name = Some(name); |
45 | Ok(()) |
46 | } |
47 | |
48 | fn set_crate(&mut self, path: CrateAttribute) -> Result<()> { |
49 | ensure_spanned!( |
50 | self.krate.is_none(), |
51 | path.span() => "`crate` may only be specified once" |
52 | ); |
53 | |
54 | self.krate = Some(path); |
55 | Ok(()) |
56 | } |
57 | } |
58 | |
59 | /// Generates the function that is called by the python interpreter to initialize the native |
60 | /// module |
61 | pub fn pymodule_impl( |
62 | fnname: &Ident, |
63 | options: PyModuleOptions, |
64 | doc: PythonDoc, |
65 | visibility: &Visibility, |
66 | ) -> TokenStream { |
67 | let name = options.name.unwrap_or_else(|| fnname.unraw()); |
68 | let krate = get_pyo3_crate(&options.krate); |
69 | let pyinit_symbol = format!("PyInit_ {}" , name); |
70 | |
71 | quote! { |
72 | // Create a module with the same name as the `#[pymodule]` - this way `use <the module>` |
73 | // will actually bring both the module and the function into scope. |
74 | #[doc(hidden)] |
75 | #visibility mod #fnname { |
76 | pub(crate) struct MakeDef; |
77 | pub static DEF: #krate::impl_::pymodule::ModuleDef = MakeDef::make_def(); |
78 | pub const NAME: &'static str = concat!(stringify!(#name), " \0" ); |
79 | |
80 | /// This autogenerated function is called by the python interpreter when importing |
81 | /// the module. |
82 | #[export_name = #pyinit_symbol] |
83 | pub unsafe extern "C" fn init() -> *mut #krate::ffi::PyObject { |
84 | #krate::impl_::trampoline::module_init(|py| DEF.make_module(py)) |
85 | } |
86 | } |
87 | |
88 | // Generate the definition inside an anonymous function in the same scope as the original function - |
89 | // this avoids complications around the fact that the generated module has a different scope |
90 | // (and `super` doesn't always refer to the outer scope, e.g. if the `#[pymodule] is |
91 | // inside a function body) |
92 | const _: () = { |
93 | use #krate::impl_::pymodule as impl_; |
94 | impl #fnname::MakeDef { |
95 | const fn make_def() -> impl_::ModuleDef { |
96 | const INITIALIZER: impl_::ModuleInitializer = impl_::ModuleInitializer(#fnname); |
97 | unsafe { |
98 | impl_::ModuleDef::new(#fnname::NAME, #doc, INITIALIZER) |
99 | } |
100 | } |
101 | } |
102 | }; |
103 | } |
104 | } |
105 | |
106 | /// Finds and takes care of the #[pyfn(...)] in `#[pymodule]` |
107 | pub fn process_functions_in_module( |
108 | options: &PyModuleOptions, |
109 | func: &mut syn::ItemFn, |
110 | ) -> syn::Result<()> { |
111 | let mut stmts: Vec<syn::Stmt> = Vec::new(); |
112 | let krate: Path = get_pyo3_crate(&options.krate); |
113 | |
114 | for mut stmt: Stmt in func.block.stmts.drain(..) { |
115 | if let syn::Stmt::Item(syn::Item::Fn(func: &mut ItemFn)) = &mut stmt { |
116 | if let Some(pyfn_args: PyFnArgs) = get_pyfn_attr(&mut func.attrs)? { |
117 | let module_name: Path = pyfn_args.modname; |
118 | let wrapped_function: TokenStream = impl_wrap_pyfunction(func, pyfn_args.options)?; |
119 | let name: &Ident = &func.sig.ident; |
120 | let statements: Vec<syn::Stmt> = syn::parse_quote! { |
121 | #wrapped_function |
122 | #module_name.add_function(#krate::impl_::pyfunction::_wrap_pyfunction(&#name::DEF, #module_name)?)?; |
123 | }; |
124 | stmts.extend(iter:statements); |
125 | } |
126 | }; |
127 | stmts.push(stmt); |
128 | } |
129 | |
130 | func.block.stmts = stmts; |
131 | Ok(()) |
132 | } |
133 | |
134 | pub struct PyFnArgs { |
135 | modname: Path, |
136 | options: PyFunctionOptions, |
137 | } |
138 | |
139 | impl Parse for PyFnArgs { |
140 | fn parse(input: syn::parse::ParseStream<'_>) -> syn::Result<Self> { |
141 | let modname: Path = input.parse().map_err( |
142 | |e: Error| err_spanned!(e.span() => "expected module as first argument to #[pyfn()]" ), |
143 | )?; |
144 | |
145 | if input.is_empty() { |
146 | return Ok(Self { |
147 | modname, |
148 | options: Default::default(), |
149 | }); |
150 | } |
151 | |
152 | let _: Comma = input.parse()?; |
153 | |
154 | Ok(Self { |
155 | modname, |
156 | options: input.parse()?, |
157 | }) |
158 | } |
159 | } |
160 | |
161 | /// Extracts the data from the #[pyfn(...)] attribute of a function |
162 | fn get_pyfn_attr(attrs: &mut Vec<syn::Attribute>) -> syn::Result<Option<PyFnArgs>> { |
163 | let mut pyfn_args: Option<PyFnArgs> = None; |
164 | |
165 | take_attributes(attrs, |attr: &Attribute| { |
166 | if attr.path().is_ident("pyfn" ) { |
167 | ensure_spanned!( |
168 | pyfn_args.is_none(), |
169 | attr.span() => "`#[pyfn] may only be specified once" |
170 | ); |
171 | pyfn_args = Some(attr.parse_args()?); |
172 | Ok(true) |
173 | } else { |
174 | Ok(false) |
175 | } |
176 | })?; |
177 | |
178 | if let Some(pyfn_args: &mut PyFnArgs) = &mut pyfn_args { |
179 | pyfn_args |
180 | .options |
181 | .add_attributes(attrs:take_pyo3_options(attrs)?)?; |
182 | } |
183 | |
184 | Ok(pyfn_args) |
185 | } |
186 | |
187 | enum PyModulePyO3Option { |
188 | Crate(CrateAttribute), |
189 | Name(NameAttribute), |
190 | } |
191 | |
192 | impl Parse for PyModulePyO3Option { |
193 | fn parse(input: ParseStream<'_>) -> Result<Self> { |
194 | let lookahead: Lookahead1<'_> = input.lookahead1(); |
195 | if lookahead.peek(token:attributes::kw::name) { |
196 | input.parse().map(op:PyModulePyO3Option::Name) |
197 | } else if lookahead.peek(syn::Token![crate]) { |
198 | input.parse().map(op:PyModulePyO3Option::Crate) |
199 | } else { |
200 | Err(lookahead.error()) |
201 | } |
202 | } |
203 | } |
204 | |