1 | use crate::utils::Ctx; |
2 | use crate::{ |
3 | attributes::{ |
4 | self, get_pyo3_options, take_attributes, take_pyo3_options, CrateAttribute, |
5 | FromPyWithAttribute, NameAttribute, TextSignatureAttribute, |
6 | }, |
7 | method::{self, CallingConvention, FnArg}, |
8 | pymethod::check_generic, |
9 | }; |
10 | use proc_macro2::TokenStream; |
11 | use quote::{format_ident, quote}; |
12 | use syn::parse::{Parse, ParseStream}; |
13 | use syn::punctuated::Punctuated; |
14 | use syn::{ext::IdentExt, spanned::Spanned, Result}; |
15 | |
16 | mod signature; |
17 | |
18 | pub use self::signature::{ConstructorAttribute, FunctionSignature, SignatureAttribute}; |
19 | |
20 | #[derive (Clone, Debug)] |
21 | pub struct PyFunctionArgPyO3Attributes { |
22 | pub from_py_with: Option<FromPyWithAttribute>, |
23 | pub cancel_handle: Option<attributes::kw::cancel_handle>, |
24 | } |
25 | |
26 | enum PyFunctionArgPyO3Attribute { |
27 | FromPyWith(FromPyWithAttribute), |
28 | CancelHandle(attributes::kw::cancel_handle), |
29 | } |
30 | |
31 | impl Parse for PyFunctionArgPyO3Attribute { |
32 | fn parse(input: ParseStream<'_>) -> Result<Self> { |
33 | let lookahead: Lookahead1<'_> = input.lookahead1(); |
34 | if lookahead.peek(token:attributes::kw::cancel_handle) { |
35 | input.parse().map(op:PyFunctionArgPyO3Attribute::CancelHandle) |
36 | } else if lookahead.peek(token:attributes::kw::from_py_with) { |
37 | input.parse().map(op:PyFunctionArgPyO3Attribute::FromPyWith) |
38 | } else { |
39 | Err(lookahead.error()) |
40 | } |
41 | } |
42 | } |
43 | |
44 | impl PyFunctionArgPyO3Attributes { |
45 | /// Parses #[pyo3(from_python_with = "func")] |
46 | pub fn from_attrs(attrs: &mut Vec<syn::Attribute>) -> syn::Result<Self> { |
47 | let mut attributes = PyFunctionArgPyO3Attributes { |
48 | from_py_with: None, |
49 | cancel_handle: None, |
50 | }; |
51 | take_attributes(attrs, |attr| { |
52 | if let Some(pyo3_attrs) = get_pyo3_options(attr)? { |
53 | for attr in pyo3_attrs { |
54 | match attr { |
55 | PyFunctionArgPyO3Attribute::FromPyWith(from_py_with) => { |
56 | ensure_spanned!( |
57 | attributes.from_py_with.is_none(), |
58 | from_py_with.span() => "`from_py_with` may only be specified once per argument" |
59 | ); |
60 | attributes.from_py_with = Some(from_py_with); |
61 | } |
62 | PyFunctionArgPyO3Attribute::CancelHandle(cancel_handle) => { |
63 | ensure_spanned!( |
64 | attributes.cancel_handle.is_none(), |
65 | cancel_handle.span() => "`cancel_handle` may only be specified once per argument" |
66 | ); |
67 | attributes.cancel_handle = Some(cancel_handle); |
68 | } |
69 | } |
70 | ensure_spanned!( |
71 | attributes.from_py_with.is_none() || attributes.cancel_handle.is_none(), |
72 | attributes.cancel_handle.unwrap().span() => "`from_py_with` and `cancel_handle` cannot be specified together" |
73 | ); |
74 | } |
75 | Ok(true) |
76 | } else { |
77 | Ok(false) |
78 | } |
79 | })?; |
80 | Ok(attributes) |
81 | } |
82 | } |
83 | |
84 | #[derive (Default)] |
85 | pub struct PyFunctionOptions { |
86 | pub pass_module: Option<attributes::kw::pass_module>, |
87 | pub name: Option<NameAttribute>, |
88 | pub signature: Option<SignatureAttribute>, |
89 | pub text_signature: Option<TextSignatureAttribute>, |
90 | pub krate: Option<CrateAttribute>, |
91 | } |
92 | |
93 | impl Parse for PyFunctionOptions { |
94 | fn parse(input: ParseStream<'_>) -> Result<Self> { |
95 | let mut options: PyFunctionOptions = PyFunctionOptions::default(); |
96 | |
97 | let attrs: Punctuated = Punctuated::<PyFunctionOption, syn::Token![,]>::parse_terminated(input)?; |
98 | options.add_attributes(attrs)?; |
99 | |
100 | Ok(options) |
101 | } |
102 | } |
103 | |
104 | pub enum PyFunctionOption { |
105 | Name(NameAttribute), |
106 | PassModule(attributes::kw::pass_module), |
107 | Signature(SignatureAttribute), |
108 | TextSignature(TextSignatureAttribute), |
109 | Crate(CrateAttribute), |
110 | } |
111 | |
112 | impl Parse for PyFunctionOption { |
113 | fn parse(input: ParseStream<'_>) -> Result<Self> { |
114 | let lookahead: Lookahead1<'_> = input.lookahead1(); |
115 | if lookahead.peek(token:attributes::kw::name) { |
116 | input.parse().map(op:PyFunctionOption::Name) |
117 | } else if lookahead.peek(token:attributes::kw::pass_module) { |
118 | input.parse().map(op:PyFunctionOption::PassModule) |
119 | } else if lookahead.peek(token:attributes::kw::signature) { |
120 | input.parse().map(op:PyFunctionOption::Signature) |
121 | } else if lookahead.peek(token:attributes::kw::text_signature) { |
122 | input.parse().map(op:PyFunctionOption::TextSignature) |
123 | } else if lookahead.peek(syn::Token![crate]) { |
124 | input.parse().map(op:PyFunctionOption::Crate) |
125 | } else { |
126 | Err(lookahead.error()) |
127 | } |
128 | } |
129 | } |
130 | |
131 | impl PyFunctionOptions { |
132 | pub fn from_attrs(attrs: &mut Vec<syn::Attribute>) -> syn::Result<Self> { |
133 | let mut options = PyFunctionOptions::default(); |
134 | options.add_attributes(take_pyo3_options(attrs)?)?; |
135 | Ok(options) |
136 | } |
137 | |
138 | pub fn add_attributes( |
139 | &mut self, |
140 | attrs: impl IntoIterator<Item = PyFunctionOption>, |
141 | ) -> Result<()> { |
142 | macro_rules! set_option { |
143 | ($key:ident) => { |
144 | { |
145 | ensure_spanned!( |
146 | self.$key.is_none(), |
147 | $key.span() => concat!("`" , stringify!($key), "` may only be specified once" ) |
148 | ); |
149 | self.$key = Some($key); |
150 | } |
151 | }; |
152 | } |
153 | for attr in attrs { |
154 | match attr { |
155 | PyFunctionOption::Name(name) => set_option!(name), |
156 | PyFunctionOption::PassModule(pass_module) => set_option!(pass_module), |
157 | PyFunctionOption::Signature(signature) => set_option!(signature), |
158 | PyFunctionOption::TextSignature(text_signature) => set_option!(text_signature), |
159 | PyFunctionOption::Crate(krate) => set_option!(krate), |
160 | } |
161 | } |
162 | Ok(()) |
163 | } |
164 | } |
165 | |
166 | pub fn build_py_function( |
167 | ast: &mut syn::ItemFn, |
168 | mut options: PyFunctionOptions, |
169 | ) -> syn::Result<TokenStream> { |
170 | options.add_attributes(attrs:take_pyo3_options(&mut ast.attrs)?)?; |
171 | impl_wrap_pyfunction(func:ast, options) |
172 | } |
173 | |
174 | /// Generates python wrapper over a function that allows adding it to a python module as a python |
175 | /// function |
176 | pub fn impl_wrap_pyfunction( |
177 | func: &mut syn::ItemFn, |
178 | options: PyFunctionOptions, |
179 | ) -> syn::Result<TokenStream> { |
180 | check_generic(&func.sig)?; |
181 | let PyFunctionOptions { |
182 | pass_module, |
183 | name, |
184 | signature, |
185 | text_signature, |
186 | krate, |
187 | } = options; |
188 | |
189 | let ctx = &Ctx::new(&krate, Some(&func.sig)); |
190 | let Ctx { pyo3_path, .. } = &ctx; |
191 | |
192 | let python_name = name |
193 | .as_ref() |
194 | .map_or_else(|| &func.sig.ident, |name| &name.value.0) |
195 | .unraw(); |
196 | |
197 | let tp = if pass_module.is_some() { |
198 | let span = match func.sig.inputs.first() { |
199 | Some(syn::FnArg::Typed(first_arg)) => first_arg.ty.span(), |
200 | Some(syn::FnArg::Receiver(_)) | None => bail_spanned!( |
201 | func.sig.paren_token.span.join() => "expected `&PyModule` or `Py<PyModule>` as first argument with `pass_module`" |
202 | ), |
203 | }; |
204 | method::FnType::FnModule(span) |
205 | } else { |
206 | method::FnType::FnStatic |
207 | }; |
208 | |
209 | let arguments = func |
210 | .sig |
211 | .inputs |
212 | .iter_mut() |
213 | .skip(if tp.skip_first_rust_argument_in_python_signature() { |
214 | 1 |
215 | } else { |
216 | 0 |
217 | }) |
218 | .map(FnArg::parse) |
219 | .collect::<syn::Result<Vec<_>>>()?; |
220 | |
221 | let signature = if let Some(signature) = signature { |
222 | FunctionSignature::from_arguments_and_attribute(arguments, signature)? |
223 | } else { |
224 | FunctionSignature::from_arguments(arguments) |
225 | }; |
226 | |
227 | let spec = method::FnSpec { |
228 | tp, |
229 | name: &func.sig.ident, |
230 | convention: CallingConvention::from_signature(&signature), |
231 | python_name, |
232 | signature, |
233 | text_signature, |
234 | asyncness: func.sig.asyncness, |
235 | unsafety: func.sig.unsafety, |
236 | }; |
237 | |
238 | let vis = &func.vis; |
239 | let name = &func.sig.ident; |
240 | |
241 | let wrapper_ident = format_ident!("__pyfunction_ {}" , spec.name); |
242 | let wrapper = spec.get_wrapper_function(&wrapper_ident, None, ctx)?; |
243 | let methoddef = spec.get_methoddef(wrapper_ident, &spec.get_doc(&func.attrs, ctx), ctx); |
244 | |
245 | let wrapped_pyfunction = quote! { |
246 | |
247 | // Create a module with the same name as the `#[pyfunction]` - this way `use <the function>` |
248 | // will actually bring both the module and the function into scope. |
249 | #[doc(hidden)] |
250 | #vis mod #name { |
251 | pub(crate) struct MakeDef; |
252 | pub const _PYO3_DEF: #pyo3_path::impl_::pymethods::PyMethodDef = MakeDef::_PYO3_DEF; |
253 | } |
254 | |
255 | // Generate the definition inside an anonymous function in the same scope as the original function - |
256 | // this avoids complications around the fact that the generated module has a different scope |
257 | // (and `super` doesn't always refer to the outer scope, e.g. if the `#[pyfunction] is |
258 | // inside a function body) |
259 | #[allow(unknown_lints, non_local_definitions)] |
260 | impl #name::MakeDef { |
261 | const _PYO3_DEF: #pyo3_path::impl_::pymethods::PyMethodDef = #methoddef; |
262 | } |
263 | |
264 | #[allow(non_snake_case)] |
265 | #wrapper |
266 | }; |
267 | Ok(wrapped_pyfunction) |
268 | } |
269 | |