1 | use crate::{ |
2 | attributes::{ |
3 | self, get_pyo3_options, take_attributes, take_pyo3_options, CrateAttribute, |
4 | FromPyWithAttribute, NameAttribute, TextSignatureAttribute, |
5 | }, |
6 | deprecations::Deprecations, |
7 | method::{self, CallingConvention, FnArg}, |
8 | pymethod::check_generic, |
9 | utils::{ensure_not_async_fn, get_pyo3_crate}, |
10 | }; |
11 | use proc_macro2::TokenStream; |
12 | use quote::{format_ident, quote}; |
13 | use syn::{ext::IdentExt, spanned::Spanned, Result}; |
14 | use syn::{ |
15 | parse::{Parse, ParseStream}, |
16 | token::Comma, |
17 | }; |
18 | |
19 | mod signature; |
20 | |
21 | pub use self::signature::{FunctionSignature, SignatureAttribute}; |
22 | |
23 | #[derive (Clone, Debug)] |
24 | pub struct PyFunctionArgPyO3Attributes { |
25 | pub from_py_with: Option<FromPyWithAttribute>, |
26 | } |
27 | |
28 | enum PyFunctionArgPyO3Attribute { |
29 | FromPyWith(FromPyWithAttribute), |
30 | } |
31 | |
32 | impl Parse for PyFunctionArgPyO3Attribute { |
33 | fn parse(input: ParseStream<'_>) -> Result<Self> { |
34 | let lookahead: Lookahead1<'_> = input.lookahead1(); |
35 | if lookahead.peek(token:attributes::kw::from_py_with) { |
36 | input.parse().map(op:PyFunctionArgPyO3Attribute::FromPyWith) |
37 | } else { |
38 | Err(lookahead.error()) |
39 | } |
40 | } |
41 | } |
42 | |
43 | impl PyFunctionArgPyO3Attributes { |
44 | /// Parses #[pyo3(from_python_with = "func")] |
45 | pub fn from_attrs(attrs: &mut Vec<syn::Attribute>) -> syn::Result<Self> { |
46 | let mut attributes = PyFunctionArgPyO3Attributes { from_py_with: None }; |
47 | take_attributes(attrs, |attr| { |
48 | if let Some(pyo3_attrs) = get_pyo3_options(attr)? { |
49 | for attr in pyo3_attrs { |
50 | match attr { |
51 | PyFunctionArgPyO3Attribute::FromPyWith(from_py_with) => { |
52 | ensure_spanned!( |
53 | attributes.from_py_with.is_none(), |
54 | from_py_with.span() => "`from_py_with` may only be specified once per argument" |
55 | ); |
56 | attributes.from_py_with = Some(from_py_with); |
57 | } |
58 | } |
59 | } |
60 | Ok(true) |
61 | } else { |
62 | Ok(false) |
63 | } |
64 | })?; |
65 | Ok(attributes) |
66 | } |
67 | } |
68 | |
69 | #[derive (Default)] |
70 | pub struct PyFunctionOptions { |
71 | pub pass_module: Option<attributes::kw::pass_module>, |
72 | pub name: Option<NameAttribute>, |
73 | pub signature: Option<SignatureAttribute>, |
74 | pub text_signature: Option<TextSignatureAttribute>, |
75 | pub krate: Option<CrateAttribute>, |
76 | } |
77 | |
78 | impl Parse for PyFunctionOptions { |
79 | fn parse(input: ParseStream<'_>) -> Result<Self> { |
80 | let mut options = PyFunctionOptions::default(); |
81 | |
82 | while !input.is_empty() { |
83 | let lookahead = input.lookahead1(); |
84 | if lookahead.peek(attributes::kw::name) |
85 | || lookahead.peek(attributes::kw::pass_module) |
86 | || lookahead.peek(attributes::kw::signature) |
87 | || lookahead.peek(attributes::kw::text_signature) |
88 | { |
89 | options.add_attributes(std::iter::once(input.parse()?))?; |
90 | if !input.is_empty() { |
91 | let _: Comma = input.parse()?; |
92 | } |
93 | } else if lookahead.peek(syn::Token![crate]) { |
94 | // TODO needs duplicate check? |
95 | options.krate = Some(input.parse()?); |
96 | } else { |
97 | return Err(lookahead.error()); |
98 | } |
99 | } |
100 | |
101 | Ok(options) |
102 | } |
103 | } |
104 | |
105 | pub enum PyFunctionOption { |
106 | Name(NameAttribute), |
107 | PassModule(attributes::kw::pass_module), |
108 | Signature(SignatureAttribute), |
109 | TextSignature(TextSignatureAttribute), |
110 | Crate(CrateAttribute), |
111 | } |
112 | |
113 | impl Parse for PyFunctionOption { |
114 | fn parse(input: ParseStream<'_>) -> Result<Self> { |
115 | let lookahead: Lookahead1<'_> = input.lookahead1(); |
116 | if lookahead.peek(token:attributes::kw::name) { |
117 | input.parse().map(op:PyFunctionOption::Name) |
118 | } else if lookahead.peek(token:attributes::kw::pass_module) { |
119 | input.parse().map(op:PyFunctionOption::PassModule) |
120 | } else if lookahead.peek(token:attributes::kw::signature) { |
121 | input.parse().map(op:PyFunctionOption::Signature) |
122 | } else if lookahead.peek(token:attributes::kw::text_signature) { |
123 | input.parse().map(op:PyFunctionOption::TextSignature) |
124 | } else if lookahead.peek(syn::Token![crate]) { |
125 | input.parse().map(op:PyFunctionOption::Crate) |
126 | } else { |
127 | Err(lookahead.error()) |
128 | } |
129 | } |
130 | } |
131 | |
132 | impl PyFunctionOptions { |
133 | pub fn from_attrs(attrs: &mut Vec<syn::Attribute>) -> syn::Result<Self> { |
134 | let mut options = PyFunctionOptions::default(); |
135 | options.add_attributes(take_pyo3_options(attrs)?)?; |
136 | Ok(options) |
137 | } |
138 | |
139 | pub fn add_attributes( |
140 | &mut self, |
141 | attrs: impl IntoIterator<Item = PyFunctionOption>, |
142 | ) -> Result<()> { |
143 | macro_rules! set_option { |
144 | ($key:ident) => { |
145 | { |
146 | ensure_spanned!( |
147 | self.$key.is_none(), |
148 | $key.span() => concat!("`" , stringify!($key), "` may only be specified once" ) |
149 | ); |
150 | self.$key = Some($key); |
151 | } |
152 | }; |
153 | } |
154 | for attr in attrs { |
155 | match attr { |
156 | PyFunctionOption::Name(name) => set_option!(name), |
157 | PyFunctionOption::PassModule(pass_module) => set_option!(pass_module), |
158 | PyFunctionOption::Signature(signature) => set_option!(signature), |
159 | PyFunctionOption::TextSignature(text_signature) => set_option!(text_signature), |
160 | PyFunctionOption::Crate(krate) => set_option!(krate), |
161 | } |
162 | } |
163 | Ok(()) |
164 | } |
165 | } |
166 | |
167 | pub fn build_py_function( |
168 | ast: &mut syn::ItemFn, |
169 | mut options: PyFunctionOptions, |
170 | ) -> syn::Result<TokenStream> { |
171 | options.add_attributes(attrs:take_pyo3_options(&mut ast.attrs)?)?; |
172 | impl_wrap_pyfunction(func:ast, options) |
173 | } |
174 | |
175 | /// Generates python wrapper over a function that allows adding it to a python module as a python |
176 | /// function |
177 | pub fn impl_wrap_pyfunction( |
178 | func: &mut syn::ItemFn, |
179 | options: PyFunctionOptions, |
180 | ) -> syn::Result<TokenStream> { |
181 | check_generic(&func.sig)?; |
182 | ensure_not_async_fn(&func.sig)?; |
183 | |
184 | let PyFunctionOptions { |
185 | pass_module, |
186 | name, |
187 | signature, |
188 | text_signature, |
189 | krate, |
190 | } = options; |
191 | |
192 | let python_name = name.map_or_else(|| func.sig.ident.unraw(), |name| name.value.0); |
193 | |
194 | let tp = if pass_module.is_some() { |
195 | let span = match func.sig.inputs.first() { |
196 | Some(syn::FnArg::Typed(first_arg)) => first_arg.ty.span(), |
197 | Some(syn::FnArg::Receiver(_)) | None => bail_spanned!( |
198 | func.sig.paren_token.span.join() => "expected `&PyModule` or `Py<PyModule>` as first argument with `pass_module`" |
199 | ), |
200 | }; |
201 | method::FnType::FnModule(span) |
202 | } else { |
203 | method::FnType::FnStatic |
204 | }; |
205 | |
206 | let arguments = func |
207 | .sig |
208 | .inputs |
209 | .iter_mut() |
210 | .skip(if tp.skip_first_rust_argument_in_python_signature() { |
211 | 1 |
212 | } else { |
213 | 0 |
214 | }) |
215 | .map(FnArg::parse) |
216 | .collect::<syn::Result<Vec<_>>>()?; |
217 | |
218 | let signature = if let Some(signature) = signature { |
219 | FunctionSignature::from_arguments_and_attribute(arguments, signature)? |
220 | } else { |
221 | FunctionSignature::from_arguments(arguments)? |
222 | }; |
223 | |
224 | let ty = method::get_return_info(&func.sig.output); |
225 | |
226 | let spec = method::FnSpec { |
227 | tp, |
228 | name: &func.sig.ident, |
229 | convention: CallingConvention::from_signature(&signature), |
230 | python_name, |
231 | signature, |
232 | output: ty, |
233 | text_signature, |
234 | unsafety: func.sig.unsafety, |
235 | deprecations: Deprecations::new(), |
236 | }; |
237 | |
238 | let krate = get_pyo3_crate(&krate); |
239 | |
240 | let vis = &func.vis; |
241 | let name = &func.sig.ident; |
242 | |
243 | let wrapper_ident = format_ident!("__pyfunction_ {}" , spec.name); |
244 | let wrapper = spec.get_wrapper_function(&wrapper_ident, None)?; |
245 | let methoddef = spec.get_methoddef(wrapper_ident, &spec.get_doc(&func.attrs)); |
246 | |
247 | let wrapped_pyfunction = quote! { |
248 | |
249 | // Create a module with the same name as the `#[pyfunction]` - this way `use <the function>` |
250 | // will actually bring both the module and the function into scope. |
251 | #[doc(hidden)] |
252 | #vis mod #name { |
253 | pub(crate) struct MakeDef; |
254 | pub const DEF: #krate::impl_::pyfunction::PyMethodDef = MakeDef::DEF; |
255 | } |
256 | |
257 | // Generate the definition inside an anonymous function in the same scope as the original function - |
258 | // this avoids complications around the fact that the generated module has a different scope |
259 | // (and `super` doesn't always refer to the outer scope, e.g. if the `#[pyfunction] is |
260 | // inside a function body) |
261 | const _: () = { |
262 | use #krate as _pyo3; |
263 | impl #name::MakeDef { |
264 | const DEF: #krate::impl_::pyfunction::PyMethodDef = #methoddef; |
265 | } |
266 | |
267 | #[allow(non_snake_case)] |
268 | #wrapper |
269 | }; |
270 | }; |
271 | Ok(wrapped_pyfunction) |
272 | } |
273 | |