//! Code generation for the function that initializes a Python module and adds classes and functions.
2
3use crate::{
4 attributes::{self, take_attributes, take_pyo3_options, CrateAttribute, NameAttribute},
5 pyfunction::{impl_wrap_pyfunction, PyFunctionOptions},
6 utils::{get_pyo3_crate, PythonDoc},
7};
8use proc_macro2::TokenStream;
9use quote::quote;
10use syn::{
11 ext::IdentExt,
12 parse::{Parse, ParseStream},
13 spanned::Spanned,
14 token::Comma,
15 Ident, Path, Result, Visibility,
16};
17
18#[derive(Default)]
19pub struct PyModuleOptions {
20 krate: Option<CrateAttribute>,
21 name: Option<syn::Ident>,
22}
23
24impl PyModuleOptions {
25 pub fn from_attrs(attrs: &mut Vec<syn::Attribute>) -> Result<Self> {
26 let mut options: PyModuleOptions = Default::default();
27
28 for option in take_pyo3_options(attrs)? {
29 match option {
30 PyModulePyO3Option::Name(name) => options.set_name(name.value.0)?,
31 PyModulePyO3Option::Crate(path) => options.set_crate(path)?,
32 }
33 }
34
35 Ok(options)
36 }
37
38 fn set_name(&mut self, name: syn::Ident) -> Result<()> {
39 ensure_spanned!(
40 self.name.is_none(),
41 name.span() => "`name` may only be specified once"
42 );
43
44 self.name = Some(name);
45 Ok(())
46 }
47
48 fn set_crate(&mut self, path: CrateAttribute) -> Result<()> {
49 ensure_spanned!(
50 self.krate.is_none(),
51 path.span() => "`crate` may only be specified once"
52 );
53
54 self.krate = Some(path);
55 Ok(())
56 }
57}
58
59/// Generates the function that is called by the python interpreter to initialize the native
60/// module
61pub fn pymodule_impl(
62 fnname: &Ident,
63 options: PyModuleOptions,
64 doc: PythonDoc,
65 visibility: &Visibility,
66) -> TokenStream {
67 let name = options.name.unwrap_or_else(|| fnname.unraw());
68 let krate = get_pyo3_crate(&options.krate);
69 let pyinit_symbol = format!("PyInit_{}", name);
70
71 quote! {
72 // Create a module with the same name as the `#[pymodule]` - this way `use <the module>`
73 // will actually bring both the module and the function into scope.
74 #[doc(hidden)]
75 #visibility mod #fnname {
76 pub(crate) struct MakeDef;
77 pub static DEF: #krate::impl_::pymodule::ModuleDef = MakeDef::make_def();
78 pub const NAME: &'static str = concat!(stringify!(#name), "\0");
79
80 /// This autogenerated function is called by the python interpreter when importing
81 /// the module.
82 #[export_name = #pyinit_symbol]
83 pub unsafe extern "C" fn init() -> *mut #krate::ffi::PyObject {
84 #krate::impl_::trampoline::module_init(|py| DEF.make_module(py))
85 }
86 }
87
88 // Generate the definition inside an anonymous function in the same scope as the original function -
89 // this avoids complications around the fact that the generated module has a different scope
90 // (and `super` doesn't always refer to the outer scope, e.g. if the `#[pymodule] is
91 // inside a function body)
92 const _: () = {
93 use #krate::impl_::pymodule as impl_;
94 impl #fnname::MakeDef {
95 const fn make_def() -> impl_::ModuleDef {
96 const INITIALIZER: impl_::ModuleInitializer = impl_::ModuleInitializer(#fnname);
97 unsafe {
98 impl_::ModuleDef::new(#fnname::NAME, #doc, INITIALIZER)
99 }
100 }
101 }
102 };
103 }
104}
105
106/// Finds and takes care of the #[pyfn(...)] in `#[pymodule]`
107pub fn process_functions_in_module(
108 options: &PyModuleOptions,
109 func: &mut syn::ItemFn,
110) -> syn::Result<()> {
111 let mut stmts: Vec<syn::Stmt> = Vec::new();
112 let krate: Path = get_pyo3_crate(&options.krate);
113
114 for mut stmt: Stmt in func.block.stmts.drain(..) {
115 if let syn::Stmt::Item(syn::Item::Fn(func: &mut ItemFn)) = &mut stmt {
116 if let Some(pyfn_args: PyFnArgs) = get_pyfn_attr(&mut func.attrs)? {
117 let module_name: Path = pyfn_args.modname;
118 let wrapped_function: TokenStream = impl_wrap_pyfunction(func, pyfn_args.options)?;
119 let name: &Ident = &func.sig.ident;
120 let statements: Vec<syn::Stmt> = syn::parse_quote! {
121 #wrapped_function
122 #module_name.add_function(#krate::impl_::pyfunction::_wrap_pyfunction(&#name::DEF, #module_name)?)?;
123 };
124 stmts.extend(iter:statements);
125 }
126 };
127 stmts.push(stmt);
128 }
129
130 func.block.stmts = stmts;
131 Ok(())
132}
133
134pub struct PyFnArgs {
135 modname: Path,
136 options: PyFunctionOptions,
137}
138
139impl Parse for PyFnArgs {
140 fn parse(input: syn::parse::ParseStream<'_>) -> syn::Result<Self> {
141 let modname: Path = input.parse().map_err(
142 |e: Error| err_spanned!(e.span() => "expected module as first argument to #[pyfn()]"),
143 )?;
144
145 if input.is_empty() {
146 return Ok(Self {
147 modname,
148 options: Default::default(),
149 });
150 }
151
152 let _: Comma = input.parse()?;
153
154 Ok(Self {
155 modname,
156 options: input.parse()?,
157 })
158 }
159}
160
161/// Extracts the data from the #[pyfn(...)] attribute of a function
162fn get_pyfn_attr(attrs: &mut Vec<syn::Attribute>) -> syn::Result<Option<PyFnArgs>> {
163 let mut pyfn_args: Option<PyFnArgs> = None;
164
165 take_attributes(attrs, |attr: &Attribute| {
166 if attr.path().is_ident("pyfn") {
167 ensure_spanned!(
168 pyfn_args.is_none(),
169 attr.span() => "`#[pyfn] may only be specified once"
170 );
171 pyfn_args = Some(attr.parse_args()?);
172 Ok(true)
173 } else {
174 Ok(false)
175 }
176 })?;
177
178 if let Some(pyfn_args: &mut PyFnArgs) = &mut pyfn_args {
179 pyfn_args
180 .options
181 .add_attributes(attrs:take_pyo3_options(attrs)?)?;
182 }
183
184 Ok(pyfn_args)
185}
186
187enum PyModulePyO3Option {
188 Crate(CrateAttribute),
189 Name(NameAttribute),
190}
191
192impl Parse for PyModulePyO3Option {
193 fn parse(input: ParseStream<'_>) -> Result<Self> {
194 let lookahead: Lookahead1<'_> = input.lookahead1();
195 if lookahead.peek(token:attributes::kw::name) {
196 input.parse().map(op:PyModulePyO3Option::Name)
197 } else if lookahead.peek(syn::Token![crate]) {
198 input.parse().map(op:PyModulePyO3Option::Crate)
199 } else {
200 Err(lookahead.error())
201 }
202 }
203}
204