1 | use std::borrow::Cow; |
2 | |
3 | use crate::attributes::{NameAttribute, RenamingRule}; |
4 | use crate::method::{CallingConvention, ExtractErrorMode}; |
5 | use crate::utils::{ensure_not_async_fn, PythonDoc}; |
6 | use crate::{ |
7 | method::{FnArg, FnSpec, FnType, SelfType}, |
8 | pyfunction::PyFunctionOptions, |
9 | }; |
10 | use crate::{quotes, utils}; |
11 | use proc_macro2::{Span, TokenStream}; |
12 | use quote::{format_ident, quote, ToTokens}; |
13 | use syn::{ext::IdentExt, spanned::Spanned, Result}; |
14 | |
/// Generated code for a single pymethod item.
///
/// Used for items registered through a `PyMethodDefType` entry (plain methods,
/// class/static methods, getters, setters and class attributes in this file).
pub struct MethodAndMethodDef {
    /// The implementation of the Python wrapper for the pymethod
    pub associated_method: TokenStream,
    /// The method def which will be used to register this pymethod
    pub method_def: TokenStream,
}
22 | |
/// Generated code for a single pymethod item which is registered by a slot.
///
/// Used for protocol methods that fill a `PyType_Slot` (e.g. `__new__`, `__call__`,
/// `__traverse__`, and the magic methods described by a `SlotDef`).
pub struct MethodAndSlotDef {
    /// The implementation of the Python wrapper for the pymethod
    pub associated_method: TokenStream,
    /// The slot def which will be used to register this pymethod
    pub slot_def: TokenStream,
}
30 | |
/// The code generated for one item of a `#[pymethods]` block by `gen_py_method`.
pub enum GeneratedPyMethod {
    /// Registered through a `PyMethodDefType` entry (methods, getters, setters,
    /// class attributes).
    Method(MethodAndMethodDef),
    /// Registered through a `PyType_Slot` (`__new__`, `__call__`, `__traverse__`,
    /// slot-based magic methods).
    Proto(MethodAndSlotDef),
    /// A slot-fragment protocol method: its Python name and the tokens produced by
    /// `SlotFragmentDef::generate_pyproto_fragment`.
    SlotTraitImpl(String, TokenStream),
}
36 | |
/// A parsed `#[pymethods]` item together with its classification by Python name.
pub struct PyMethod<'a> {
    /// Whether the name is a recognised protocol (magic) method or not.
    kind: PyMethodKind,
    /// The Python-visible name (`spec.python_name` rendered as a string).
    method_name: String,
    /// The parsed function specification.
    spec: FnSpec<'a>,
}
42 | |
/// Classification of a method by its Python name (see `PyMethodKind::from_name`).
enum PyMethodKind {
    /// Not a recognised protocol name; code generation is chosen by the method's `FnType`.
    Fn,
    /// A recognised magic method name, implemented as described by `PyMethodProtoKind`.
    Proto(PyMethodProtoKind),
}
47 | |
impl PyMethodKind {
    /// Classifies a method by its Python name.
    ///
    /// Names handled directly by one type slot map to `PyMethodProtoKind::Slot`, names
    /// handled through slot fragments map to `PyMethodProtoKind::SlotFragment`, and
    /// `__call__` / `__traverse__` have dedicated kinds. Every other name is
    /// `PyMethodKind::Fn` (an ordinary method).
    fn from_name(name: &str) -> Self {
        match name {
            // Protocol implemented through slots
            "__str__" => PyMethodKind::Proto(PyMethodProtoKind::Slot(&__STR__)),
            "__repr__" => PyMethodKind::Proto(PyMethodProtoKind::Slot(&__REPR__)),
            "__hash__" => PyMethodKind::Proto(PyMethodProtoKind::Slot(&__HASH__)),
            "__richcmp__" => PyMethodKind::Proto(PyMethodProtoKind::Slot(&__RICHCMP__)),
            "__get__" => PyMethodKind::Proto(PyMethodProtoKind::Slot(&__GET__)),
            "__iter__" => PyMethodKind::Proto(PyMethodProtoKind::Slot(&__ITER__)),
            "__next__" => PyMethodKind::Proto(PyMethodProtoKind::Slot(&__NEXT__)),
            "__await__" => PyMethodKind::Proto(PyMethodProtoKind::Slot(&__AWAIT__)),
            "__aiter__" => PyMethodKind::Proto(PyMethodProtoKind::Slot(&__AITER__)),
            "__anext__" => PyMethodKind::Proto(PyMethodProtoKind::Slot(&__ANEXT__)),
            "__len__" => PyMethodKind::Proto(PyMethodProtoKind::Slot(&__LEN__)),
            "__contains__" => PyMethodKind::Proto(PyMethodProtoKind::Slot(&__CONTAINS__)),
            "__concat__" => PyMethodKind::Proto(PyMethodProtoKind::Slot(&__CONCAT__)),
            "__repeat__" => PyMethodKind::Proto(PyMethodProtoKind::Slot(&__REPEAT__)),
            "__inplace_concat__" => {
                PyMethodKind::Proto(PyMethodProtoKind::Slot(&__INPLACE_CONCAT__))
            }
            "__inplace_repeat__" => {
                PyMethodKind::Proto(PyMethodProtoKind::Slot(&__INPLACE_REPEAT__))
            }
            "__getitem__" => PyMethodKind::Proto(PyMethodProtoKind::Slot(&__GETITEM__)),
            "__pos__" => PyMethodKind::Proto(PyMethodProtoKind::Slot(&__POS__)),
            "__neg__" => PyMethodKind::Proto(PyMethodProtoKind::Slot(&__NEG__)),
            "__abs__" => PyMethodKind::Proto(PyMethodProtoKind::Slot(&__ABS__)),
            "__invert__" => PyMethodKind::Proto(PyMethodProtoKind::Slot(&__INVERT__)),
            "__index__" => PyMethodKind::Proto(PyMethodProtoKind::Slot(&__INDEX__)),
            "__int__" => PyMethodKind::Proto(PyMethodProtoKind::Slot(&__INT__)),
            "__float__" => PyMethodKind::Proto(PyMethodProtoKind::Slot(&__FLOAT__)),
            "__bool__" => PyMethodKind::Proto(PyMethodProtoKind::Slot(&__BOOL__)),
            "__iadd__" => PyMethodKind::Proto(PyMethodProtoKind::Slot(&__IADD__)),
            "__isub__" => PyMethodKind::Proto(PyMethodProtoKind::Slot(&__ISUB__)),
            "__imul__" => PyMethodKind::Proto(PyMethodProtoKind::Slot(&__IMUL__)),
            "__imatmul__" => PyMethodKind::Proto(PyMethodProtoKind::Slot(&__IMATMUL__)),
            "__itruediv__" => PyMethodKind::Proto(PyMethodProtoKind::Slot(&__ITRUEDIV__)),
            "__ifloordiv__" => PyMethodKind::Proto(PyMethodProtoKind::Slot(&__IFLOORDIV__)),
            "__imod__" => PyMethodKind::Proto(PyMethodProtoKind::Slot(&__IMOD__)),
            "__ipow__" => PyMethodKind::Proto(PyMethodProtoKind::Slot(&__IPOW__)),
            "__ilshift__" => PyMethodKind::Proto(PyMethodProtoKind::Slot(&__ILSHIFT__)),
            "__irshift__" => PyMethodKind::Proto(PyMethodProtoKind::Slot(&__IRSHIFT__)),
            "__iand__" => PyMethodKind::Proto(PyMethodProtoKind::Slot(&__IAND__)),
            "__ixor__" => PyMethodKind::Proto(PyMethodProtoKind::Slot(&__IXOR__)),
            "__ior__" => PyMethodKind::Proto(PyMethodProtoKind::Slot(&__IOR__)),
            "__getbuffer__" => PyMethodKind::Proto(PyMethodProtoKind::Slot(&__GETBUFFER__)),
            "__releasebuffer__" => PyMethodKind::Proto(PyMethodProtoKind::Slot(&__RELEASEBUFFER__)),
            "__clear__" => PyMethodKind::Proto(PyMethodProtoKind::Slot(&__CLEAR__)),
            // Protocols implemented through traits (slot fragments)
            "__getattribute__" => {
                PyMethodKind::Proto(PyMethodProtoKind::SlotFragment(&__GETATTRIBUTE__))
            }
            "__getattr__" => PyMethodKind::Proto(PyMethodProtoKind::SlotFragment(&__GETATTR__)),
            "__setattr__" => PyMethodKind::Proto(PyMethodProtoKind::SlotFragment(&__SETATTR__)),
            "__delattr__" => PyMethodKind::Proto(PyMethodProtoKind::SlotFragment(&__DELATTR__)),
            "__set__" => PyMethodKind::Proto(PyMethodProtoKind::SlotFragment(&__SET__)),
            "__delete__" => PyMethodKind::Proto(PyMethodProtoKind::SlotFragment(&__DELETE__)),
            "__setitem__" => PyMethodKind::Proto(PyMethodProtoKind::SlotFragment(&__SETITEM__)),
            "__delitem__" => PyMethodKind::Proto(PyMethodProtoKind::SlotFragment(&__DELITEM__)),
            "__add__" => PyMethodKind::Proto(PyMethodProtoKind::SlotFragment(&__ADD__)),
            "__radd__" => PyMethodKind::Proto(PyMethodProtoKind::SlotFragment(&__RADD__)),
            "__sub__" => PyMethodKind::Proto(PyMethodProtoKind::SlotFragment(&__SUB__)),
            "__rsub__" => PyMethodKind::Proto(PyMethodProtoKind::SlotFragment(&__RSUB__)),
            "__mul__" => PyMethodKind::Proto(PyMethodProtoKind::SlotFragment(&__MUL__)),
            "__rmul__" => PyMethodKind::Proto(PyMethodProtoKind::SlotFragment(&__RMUL__)),
            "__matmul__" => PyMethodKind::Proto(PyMethodProtoKind::SlotFragment(&__MATMUL__)),
            "__rmatmul__" => PyMethodKind::Proto(PyMethodProtoKind::SlotFragment(&__RMATMUL__)),
            "__floordiv__" => PyMethodKind::Proto(PyMethodProtoKind::SlotFragment(&__FLOORDIV__)),
            "__rfloordiv__" => PyMethodKind::Proto(PyMethodProtoKind::SlotFragment(&__RFLOORDIV__)),
            "__truediv__" => PyMethodKind::Proto(PyMethodProtoKind::SlotFragment(&__TRUEDIV__)),
            "__rtruediv__" => PyMethodKind::Proto(PyMethodProtoKind::SlotFragment(&__RTRUEDIV__)),
            "__divmod__" => PyMethodKind::Proto(PyMethodProtoKind::SlotFragment(&__DIVMOD__)),
            "__rdivmod__" => PyMethodKind::Proto(PyMethodProtoKind::SlotFragment(&__RDIVMOD__)),
            "__mod__" => PyMethodKind::Proto(PyMethodProtoKind::SlotFragment(&__MOD__)),
            "__rmod__" => PyMethodKind::Proto(PyMethodProtoKind::SlotFragment(&__RMOD__)),
            "__lshift__" => PyMethodKind::Proto(PyMethodProtoKind::SlotFragment(&__LSHIFT__)),
            "__rlshift__" => PyMethodKind::Proto(PyMethodProtoKind::SlotFragment(&__RLSHIFT__)),
            "__rshift__" => PyMethodKind::Proto(PyMethodProtoKind::SlotFragment(&__RSHIFT__)),
            "__rrshift__" => PyMethodKind::Proto(PyMethodProtoKind::SlotFragment(&__RRSHIFT__)),
            "__and__" => PyMethodKind::Proto(PyMethodProtoKind::SlotFragment(&__AND__)),
            "__rand__" => PyMethodKind::Proto(PyMethodProtoKind::SlotFragment(&__RAND__)),
            "__xor__" => PyMethodKind::Proto(PyMethodProtoKind::SlotFragment(&__XOR__)),
            "__rxor__" => PyMethodKind::Proto(PyMethodProtoKind::SlotFragment(&__RXOR__)),
            "__or__" => PyMethodKind::Proto(PyMethodProtoKind::SlotFragment(&__OR__)),
            "__ror__" => PyMethodKind::Proto(PyMethodProtoKind::SlotFragment(&__ROR__)),
            "__pow__" => PyMethodKind::Proto(PyMethodProtoKind::SlotFragment(&__POW__)),
            "__rpow__" => PyMethodKind::Proto(PyMethodProtoKind::SlotFragment(&__RPOW__)),
            "__lt__" => PyMethodKind::Proto(PyMethodProtoKind::SlotFragment(&__LT__)),
            "__le__" => PyMethodKind::Proto(PyMethodProtoKind::SlotFragment(&__LE__)),
            "__eq__" => PyMethodKind::Proto(PyMethodProtoKind::SlotFragment(&__EQ__)),
            "__ne__" => PyMethodKind::Proto(PyMethodProtoKind::SlotFragment(&__NE__)),
            "__gt__" => PyMethodKind::Proto(PyMethodProtoKind::SlotFragment(&__GT__)),
            "__ge__" => PyMethodKind::Proto(PyMethodProtoKind::SlotFragment(&__GE__)),
            // Some tricky protocols which don't fit the pattern of the rest
            // (generated by `impl_call_slot` / `impl_traverse_slot`).
            "__call__" => PyMethodKind::Proto(PyMethodProtoKind::Call),
            "__traverse__" => PyMethodKind::Proto(PyMethodProtoKind::Traverse),
            // Not a proto
            _ => PyMethodKind::Fn,
        }
    }
}
150 | |
/// How a recognised protocol (magic) method is implemented.
enum PyMethodProtoKind {
    /// Fills one type slot, described by a `SlotDef`.
    Slot(&'static SlotDef),
    /// `__call__`: generated by `impl_call_slot` as a `Py_tp_call` slot.
    Call,
    /// `__traverse__`: generated by `impl_traverse_slot` as a `Py_tp_traverse` slot.
    Traverse,
    /// Generated as a fragment described by a `SlotFragmentDef` and returned as
    /// `GeneratedPyMethod::SlotTraitImpl`.
    SlotFragment(&'static SlotFragmentDef),
}
157 | |
158 | impl<'a> PyMethod<'a> { |
159 | fn parse( |
160 | sig: &'a mut syn::Signature, |
161 | meth_attrs: &mut Vec<syn::Attribute>, |
162 | options: PyFunctionOptions, |
163 | ) -> Result<Self> { |
164 | let spec: FnSpec<'_> = FnSpec::parse(sig, meth_attrs, options)?; |
165 | |
166 | let method_name: String = spec.python_name.to_string(); |
167 | let kind: PyMethodKind = PyMethodKind::from_name(&method_name); |
168 | |
169 | Ok(Self { |
170 | kind, |
171 | method_name, |
172 | spec, |
173 | }) |
174 | } |
175 | } |
176 | |
177 | pub fn is_proto_method(name: &str) -> bool { |
178 | match PyMethodKind::from_name(name) { |
179 | PyMethodKind::Fn => false, |
180 | PyMethodKind::Proto(_) => true, |
181 | } |
182 | } |
183 | |
/// Expands a single item of a `#[pymethods]` block into generated code.
///
/// First validates the signature (`check_generic`, `ensure_not_async_fn`,
/// `ensure_function_options_valid`) and parses it into a `FnSpec`. Then dispatches:
/// - class attributes (`FnType::ClassAttribute`) first, regardless of name;
/// - recognised magic method names to slots / slot fragments;
/// - everything else by `FnType`: methods, class methods (`METH_CLASS`), static methods
///   (`METH_STATIC`), `__new__`, getters and setters.
///
/// # Errors
/// Returns any error from validation, parsing, or the individual code generators.
pub fn gen_py_method(
    cls: &syn::Type,
    sig: &mut syn::Signature,
    meth_attrs: &mut Vec<syn::Attribute>,
    options: PyFunctionOptions,
) -> Result<GeneratedPyMethod> {
    check_generic(sig)?;
    ensure_not_async_fn(sig)?;
    ensure_function_options_valid(&options)?;
    let method = PyMethod::parse(sig, meth_attrs, options)?;
    let spec = &method.spec;

    Ok(match (method.kind, &spec.tp) {
        // Class attributes go before protos so that class attributes can be used to set proto
        // method to None.
        (_, FnType::ClassAttribute) => {
            GeneratedPyMethod::Method(impl_py_class_attribute(cls, spec)?)
        }
        (PyMethodKind::Proto(proto_kind), _) => {
            ensure_no_forbidden_protocol_attributes(&proto_kind, spec, &method.method_name)?;
            match proto_kind {
                PyMethodProtoKind::Slot(slot_def) => {
                    let slot = slot_def.generate_type_slot(cls, spec, &method.method_name)?;
                    GeneratedPyMethod::Proto(slot)
                }
                // `impl_call_slot` takes the spec by value because it overrides the
                // calling convention.
                PyMethodProtoKind::Call => {
                    GeneratedPyMethod::Proto(impl_call_slot(cls, method.spec)?)
                }
                PyMethodProtoKind::Traverse => {
                    GeneratedPyMethod::Proto(impl_traverse_slot(cls, spec)?)
                }
                PyMethodProtoKind::SlotFragment(slot_fragment_def) => {
                    let proto = slot_fragment_def.generate_pyproto_fragment(cls, spec)?;
                    GeneratedPyMethod::SlotTraitImpl(method.method_name, proto)
                }
            }
        }
        // ordinary functions (with some specialties)
        (_, FnType::Fn(_)) => GeneratedPyMethod::Method(impl_py_method_def(
            cls,
            spec,
            &spec.get_doc(meth_attrs),
            None,
        )?),
        (_, FnType::FnClass(_)) => GeneratedPyMethod::Method(impl_py_method_def(
            cls,
            spec,
            &spec.get_doc(meth_attrs),
            Some(quote!(_pyo3::ffi::METH_CLASS)),
        )?),
        (_, FnType::FnStatic) => GeneratedPyMethod::Method(impl_py_method_def(
            cls,
            spec,
            &spec.get_doc(meth_attrs),
            Some(quote!(_pyo3::ffi::METH_STATIC)),
        )?),
        // special prototypes
        (_, FnType::FnNew) | (_, FnType::FnNewClass(_)) => {
            GeneratedPyMethod::Proto(impl_py_method_def_new(cls, spec)?)
        }

        (_, FnType::Getter(self_type)) => GeneratedPyMethod::Method(impl_py_getter_def(
            cls,
            PropertyType::Function {
                self_type,
                spec,
                doc: spec.get_doc(meth_attrs),
            },
        )?),
        (_, FnType::Setter(self_type)) => GeneratedPyMethod::Method(impl_py_setter_def(
            cls,
            PropertyType::Function {
                self_type,
                spec,
                doc: spec.get_doc(meth_attrs),
            },
        )?),
        // NOTE(review): presumably FnModule only arises from `pass_module`, which
        // `ensure_function_options_valid` rejects above — confirm against FnSpec::parse.
        (_, FnType::FnModule(_)) => {
            unreachable!("methods cannot be FnModule")
        }
    })
}
266 | |
267 | pub fn check_generic(sig: &syn::Signature) -> syn::Result<()> { |
268 | let err_msg: impl Fn(&str) -> String = |typ: &str| format!("Python functions cannot have generic {} parameters" , typ); |
269 | for param: &GenericParam in &sig.generics.params { |
270 | match param { |
271 | syn::GenericParam::Lifetime(_) => {} |
272 | syn::GenericParam::Type(_) => bail_spanned!(param.span() => err_msg("type" )), |
273 | syn::GenericParam::Const(_) => bail_spanned!(param.span() => err_msg("const" )), |
274 | } |
275 | } |
276 | Ok(()) |
277 | } |
278 | |
279 | fn ensure_function_options_valid(options: &PyFunctionOptions) -> syn::Result<()> { |
280 | if let Some(pass_module: &pass_module) = &options.pass_module { |
281 | bail_spanned!(pass_module.span() => "`pass_module` cannot be used on Python methods" ); |
282 | } |
283 | Ok(()) |
284 | } |
285 | |
286 | fn ensure_no_forbidden_protocol_attributes( |
287 | proto_kind: &PyMethodProtoKind, |
288 | spec: &FnSpec<'_>, |
289 | method_name: &str, |
290 | ) -> syn::Result<()> { |
291 | if let Some(signature: &KeywordAttribute) = &spec.signature.attribute { |
292 | // __call__ is allowed to have a signature, but nothing else is. |
293 | if !matches!(proto_kind, PyMethodProtoKind::Call) { |
294 | bail_spanned!(signature.kw.span() => format!("`signature` cannot be used with magic method ` {}`" , method_name)); |
295 | } |
296 | } |
297 | if let Some(text_signature: &KeywordAttribute) = &spec.text_signature { |
298 | bail_spanned!(text_signature.kw.span() => format!("`text_signature` cannot be used with magic method ` {}`" , method_name)); |
299 | } |
300 | Ok(()) |
301 | } |
302 | |
303 | /// Also used by pyfunction. |
304 | pub fn impl_py_method_def( |
305 | cls: &syn::Type, |
306 | spec: &FnSpec<'_>, |
307 | doc: &PythonDoc, |
308 | flags: Option<TokenStream>, |
309 | ) -> Result<MethodAndMethodDef> { |
310 | let wrapper_ident: Ident = format_ident!("__pymethod_ {}__" , spec.python_name); |
311 | let associated_method: TokenStream = spec.get_wrapper_function(&wrapper_ident, cls:Some(cls))?; |
312 | let add_flags: Option = flags.map(|flags: TokenStream| quote!(.flags(#flags))); |
313 | let methoddef_type: TokenStream = match spec.tp { |
314 | FnType::FnStatic => quote!(Static), |
315 | FnType::FnClass(_) => quote!(Class), |
316 | _ => quote!(Method), |
317 | }; |
318 | let methoddef: TokenStream = spec.get_methoddef(wrapper:quote! { #cls::#wrapper_ident }, doc); |
319 | let method_def: TokenStream = quote! { |
320 | _pyo3::class::PyMethodDefType::#methoddef_type(#methoddef #add_flags) |
321 | }; |
322 | Ok(MethodAndMethodDef { |
323 | associated_method, |
324 | method_def, |
325 | }) |
326 | } |
327 | |
328 | fn impl_py_method_def_new(cls: &syn::Type, spec: &FnSpec<'_>) -> Result<MethodAndSlotDef> { |
329 | let wrapper_ident = syn::Ident::new("__pymethod___new____" , Span::call_site()); |
330 | let associated_method = spec.get_wrapper_function(&wrapper_ident, Some(cls))?; |
331 | // Use just the text_signature_call_signature() because the class' Python name |
332 | // isn't known to `#[pymethods]` - that has to be attached at runtime from the PyClassImpl |
333 | // trait implementation created by `#[pyclass]`. |
334 | let text_signature_body = spec.text_signature_call_signature().map_or_else( |
335 | || quote!(::std::option::Option::None), |
336 | |text_signature| quote!(::std::option::Option::Some(#text_signature)), |
337 | ); |
338 | let deprecations = &spec.deprecations; |
339 | let slot_def = quote! { |
340 | _pyo3::ffi::PyType_Slot { |
341 | slot: _pyo3::ffi::Py_tp_new, |
342 | pfunc: { |
343 | unsafe extern "C" fn trampoline( |
344 | subtype: *mut _pyo3::ffi::PyTypeObject, |
345 | args: *mut _pyo3::ffi::PyObject, |
346 | kwargs: *mut _pyo3::ffi::PyObject, |
347 | ) -> *mut _pyo3::ffi::PyObject |
348 | { |
349 | #deprecations |
350 | |
351 | use _pyo3::impl_::pyclass::*; |
352 | impl PyClassNewTextSignature<#cls> for PyClassImplCollector<#cls> { |
353 | #[inline] |
354 | fn new_text_signature(self) -> ::std::option::Option<&'static str> { |
355 | #text_signature_body |
356 | } |
357 | } |
358 | |
359 | _pyo3::impl_::trampoline::newfunc( |
360 | subtype, |
361 | args, |
362 | kwargs, |
363 | #cls::#wrapper_ident |
364 | ) |
365 | } |
366 | trampoline |
367 | } as _pyo3::ffi::newfunc as _ |
368 | } |
369 | }; |
370 | Ok(MethodAndSlotDef { |
371 | associated_method, |
372 | slot_def, |
373 | }) |
374 | } |
375 | |
376 | fn impl_call_slot(cls: &syn::Type, mut spec: FnSpec<'_>) -> Result<MethodAndSlotDef> { |
377 | // HACK: __call__ proto slot must always use varargs calling convention, so change the spec. |
378 | // Probably indicates there's a refactoring opportunity somewhere. |
379 | spec.convention = CallingConvention::Varargs; |
380 | |
381 | let wrapper_ident = syn::Ident::new("__pymethod___call____" , Span::call_site()); |
382 | let associated_method = spec.get_wrapper_function(&wrapper_ident, Some(cls))?; |
383 | let slot_def = quote! { |
384 | _pyo3::ffi::PyType_Slot { |
385 | slot: _pyo3::ffi::Py_tp_call, |
386 | pfunc: { |
387 | unsafe extern "C" fn trampoline( |
388 | slf: *mut _pyo3::ffi::PyObject, |
389 | args: *mut _pyo3::ffi::PyObject, |
390 | kwargs: *mut _pyo3::ffi::PyObject, |
391 | ) -> *mut _pyo3::ffi::PyObject |
392 | { |
393 | _pyo3::impl_::trampoline::ternaryfunc( |
394 | slf, |
395 | args, |
396 | kwargs, |
397 | #cls::#wrapper_ident |
398 | ) |
399 | } |
400 | trampoline |
401 | } as _pyo3::ffi::ternaryfunc as _ |
402 | } |
403 | }; |
404 | Ok(MethodAndSlotDef { |
405 | associated_method, |
406 | slot_def, |
407 | }) |
408 | } |
409 | |
410 | fn impl_traverse_slot(cls: &syn::Type, spec: &FnSpec<'_>) -> syn::Result<MethodAndSlotDef> { |
411 | if let (Some(py_arg), _) = split_off_python_arg(&spec.signature.arguments) { |
412 | return Err(syn::Error::new_spanned(py_arg.ty, "__traverse__ may not take `Python`. \ |
413 | Usually, an implementation of `__traverse__` should do nothing but calls to `visit.call`. \ |
414 | Most importantly, safe access to the GIL is prohibited inside implementations of `__traverse__`, \ |
415 | i.e. `Python::with_gil` will panic." )); |
416 | } |
417 | |
418 | let rust_fn_ident = spec.name; |
419 | |
420 | let associated_method = quote! { |
421 | pub unsafe extern "C" fn __pymethod_traverse__( |
422 | slf: *mut _pyo3::ffi::PyObject, |
423 | visit: _pyo3::ffi::visitproc, |
424 | arg: *mut ::std::os::raw::c_void, |
425 | ) -> ::std::os::raw::c_int { |
426 | _pyo3::impl_::pymethods::_call_traverse::<#cls>(slf, #cls::#rust_fn_ident, visit, arg) |
427 | } |
428 | }; |
429 | let slot_def = quote! { |
430 | _pyo3::ffi::PyType_Slot { |
431 | slot: _pyo3::ffi::Py_tp_traverse, |
432 | pfunc: #cls::__pymethod_traverse__ as _pyo3::ffi::traverseproc as _ |
433 | } |
434 | }; |
435 | Ok(MethodAndSlotDef { |
436 | associated_method, |
437 | slot_def, |
438 | }) |
439 | } |
440 | |
441 | fn impl_py_class_attribute(cls: &syn::Type, spec: &FnSpec<'_>) -> syn::Result<MethodAndMethodDef> { |
442 | let (py_arg, args) = split_off_python_arg(&spec.signature.arguments); |
443 | ensure_spanned!( |
444 | args.is_empty(), |
445 | args[0].ty.span() => "#[classattr] can only have one argument (of type pyo3::Python)" |
446 | ); |
447 | |
448 | let name = &spec.name; |
449 | let fncall = if py_arg.is_some() { |
450 | quote!(function(py)) |
451 | } else { |
452 | quote!(function()) |
453 | }; |
454 | |
455 | let wrapper_ident = format_ident!("__pymethod_ {}__" , name); |
456 | let python_name = spec.null_terminated_python_name(); |
457 | let body = quotes::ok_wrap(fncall); |
458 | |
459 | let associated_method = quote! { |
460 | fn #wrapper_ident(py: _pyo3::Python<'_>) -> _pyo3::PyResult<_pyo3::PyObject> { |
461 | let function = #cls::#name; // Shadow the method name to avoid #3017 |
462 | #body |
463 | } |
464 | }; |
465 | |
466 | let method_def = quote! { |
467 | _pyo3::class::PyMethodDefType::ClassAttribute({ |
468 | _pyo3::class::PyClassAttributeDef::new( |
469 | #python_name, |
470 | _pyo3::impl_::pymethods::PyClassAttributeFactory(#cls::#wrapper_ident) |
471 | ) |
472 | }) |
473 | }; |
474 | |
475 | Ok(MethodAndMethodDef { |
476 | associated_method, |
477 | method_def, |
478 | }) |
479 | } |
480 | |
481 | fn impl_call_setter( |
482 | cls: &syn::Type, |
483 | spec: &FnSpec<'_>, |
484 | self_type: &SelfType, |
485 | ) -> syn::Result<TokenStream> { |
486 | let (py_arg: Option<&FnArg<'_>>, args: &[FnArg<'_>]) = split_off_python_arg(&spec.signature.arguments); |
487 | let slf: TokenStream = self_type.receiver(cls, error_mode:ExtractErrorMode::Raise); |
488 | |
489 | if args.is_empty() { |
490 | bail_spanned!(spec.name.span() => "setter function expected to have one argument" ); |
491 | } else if args.len() > 1 { |
492 | bail_spanned!( |
493 | args[1].ty.span() => |
494 | "setter function can have at most two arguments ([pyo3::Python,] and value)" |
495 | ); |
496 | } |
497 | |
498 | let name: &&Ident = &spec.name; |
499 | let fncall: TokenStream = if py_arg.is_some() { |
500 | quote!(#cls::#name(#slf, py, _val)) |
501 | } else { |
502 | quote!(#cls::#name(#slf, _val)) |
503 | }; |
504 | |
505 | Ok(fncall) |
506 | } |
507 | |
508 | // Used here for PropertyType::Function, used in pyclass for descriptors. |
509 | pub fn impl_py_setter_def( |
510 | cls: &syn::Type, |
511 | property_type: PropertyType<'_>, |
512 | ) -> Result<MethodAndMethodDef> { |
513 | let python_name = property_type.null_terminated_python_name()?; |
514 | let doc = property_type.doc(); |
515 | let setter_impl = match property_type { |
516 | PropertyType::Descriptor { |
517 | field_index, field, .. |
518 | } => { |
519 | let slf = SelfType::Receiver { |
520 | mutable: true, |
521 | span: Span::call_site(), |
522 | } |
523 | .receiver(cls, ExtractErrorMode::Raise); |
524 | if let Some(ident) = &field.ident { |
525 | // named struct field |
526 | quote!({ #slf.#ident = _val; }) |
527 | } else { |
528 | // tuple struct field |
529 | let index = syn::Index::from(field_index); |
530 | quote!({ #slf.#index = _val; }) |
531 | } |
532 | } |
533 | PropertyType::Function { |
534 | spec, self_type, .. |
535 | } => impl_call_setter(cls, spec, self_type)?, |
536 | }; |
537 | |
538 | let wrapper_ident = match property_type { |
539 | PropertyType::Descriptor { |
540 | field: syn::Field { |
541 | ident: Some(ident), .. |
542 | }, |
543 | .. |
544 | } => { |
545 | format_ident!("__pymethod_set_ {}__" , ident) |
546 | } |
547 | PropertyType::Descriptor { field_index, .. } => { |
548 | format_ident!("__pymethod_set_field_ {}__" , field_index) |
549 | } |
550 | PropertyType::Function { spec, .. } => { |
551 | format_ident!("__pymethod_set_ {}__" , spec.name) |
552 | } |
553 | }; |
554 | |
555 | let mut cfg_attrs = TokenStream::new(); |
556 | if let PropertyType::Descriptor { field, .. } = &property_type { |
557 | for attr in field |
558 | .attrs |
559 | .iter() |
560 | .filter(|attr| attr.path().is_ident("cfg" )) |
561 | { |
562 | attr.to_tokens(&mut cfg_attrs); |
563 | } |
564 | } |
565 | |
566 | let associated_method = quote! { |
567 | #cfg_attrs |
568 | unsafe fn #wrapper_ident( |
569 | py: _pyo3::Python<'_>, |
570 | _slf: *mut _pyo3::ffi::PyObject, |
571 | _value: *mut _pyo3::ffi::PyObject, |
572 | ) -> _pyo3::PyResult<::std::os::raw::c_int> { |
573 | let _value = py |
574 | .from_borrowed_ptr_or_opt(_value) |
575 | .ok_or_else(|| { |
576 | _pyo3::exceptions::PyAttributeError::new_err("can't delete attribute" ) |
577 | })?; |
578 | let _val = _pyo3::FromPyObject::extract(_value)?; |
579 | |
580 | _pyo3::callback::convert(py, #setter_impl) |
581 | } |
582 | }; |
583 | |
584 | let method_def = quote! { |
585 | #cfg_attrs |
586 | _pyo3::class::PyMethodDefType::Setter( |
587 | _pyo3::class::PySetterDef::new( |
588 | #python_name, |
589 | _pyo3::impl_::pymethods::PySetter(#cls::#wrapper_ident), |
590 | #doc |
591 | ) |
592 | ) |
593 | }; |
594 | |
595 | Ok(MethodAndMethodDef { |
596 | associated_method, |
597 | method_def, |
598 | }) |
599 | } |
600 | |
601 | fn impl_call_getter( |
602 | cls: &syn::Type, |
603 | spec: &FnSpec<'_>, |
604 | self_type: &SelfType, |
605 | ) -> syn::Result<TokenStream> { |
606 | let (py_arg: Option<&FnArg<'_>>, args: &[FnArg<'_>]) = split_off_python_arg(&spec.signature.arguments); |
607 | let slf: TokenStream = self_type.receiver(cls, error_mode:ExtractErrorMode::Raise); |
608 | ensure_spanned!( |
609 | args.is_empty(), |
610 | args[0].ty.span() => "getter function can only have one argument (of type pyo3::Python)" |
611 | ); |
612 | |
613 | let name: &&Ident = &spec.name; |
614 | let fncall: TokenStream = if py_arg.is_some() { |
615 | quote!(#cls::#name(#slf, py)) |
616 | } else { |
617 | quote!(#cls::#name(#slf)) |
618 | }; |
619 | |
620 | Ok(fncall) |
621 | } |
622 | |
623 | // Used here for PropertyType::Function, used in pyclass for descriptors. |
624 | pub fn impl_py_getter_def( |
625 | cls: &syn::Type, |
626 | property_type: PropertyType<'_>, |
627 | ) -> Result<MethodAndMethodDef> { |
628 | let python_name = property_type.null_terminated_python_name()?; |
629 | let doc = property_type.doc(); |
630 | |
631 | let body = match property_type { |
632 | PropertyType::Descriptor { |
633 | field_index, field, .. |
634 | } => { |
635 | let slf = SelfType::Receiver { |
636 | mutable: false, |
637 | span: Span::call_site(), |
638 | } |
639 | .receiver(cls, ExtractErrorMode::Raise); |
640 | let field_token = if let Some(ident) = &field.ident { |
641 | // named struct field |
642 | ident.to_token_stream() |
643 | } else { |
644 | // tuple struct field |
645 | syn::Index::from(field_index).to_token_stream() |
646 | }; |
647 | quotes::map_result_into_ptr(quotes::ok_wrap(quote! { |
648 | ::std::clone::Clone::clone(&(#slf.#field_token)) |
649 | })) |
650 | } |
651 | // Forward to `IntoPyCallbackOutput`, to handle `#[getter]`s returning results. |
652 | PropertyType::Function { |
653 | spec, self_type, .. |
654 | } => { |
655 | let call = impl_call_getter(cls, spec, self_type)?; |
656 | quote! { |
657 | _pyo3::callback::convert(py, #call) |
658 | } |
659 | } |
660 | }; |
661 | |
662 | let wrapper_ident = match property_type { |
663 | PropertyType::Descriptor { |
664 | field: syn::Field { |
665 | ident: Some(ident), .. |
666 | }, |
667 | .. |
668 | } => { |
669 | format_ident!("__pymethod_get_ {}__" , ident) |
670 | } |
671 | PropertyType::Descriptor { field_index, .. } => { |
672 | format_ident!("__pymethod_get_field_ {}__" , field_index) |
673 | } |
674 | PropertyType::Function { spec, .. } => { |
675 | format_ident!("__pymethod_get_ {}__" , spec.name) |
676 | } |
677 | }; |
678 | |
679 | let mut cfg_attrs = TokenStream::new(); |
680 | if let PropertyType::Descriptor { field, .. } = &property_type { |
681 | for attr in field |
682 | .attrs |
683 | .iter() |
684 | .filter(|attr| attr.path().is_ident("cfg" )) |
685 | { |
686 | attr.to_tokens(&mut cfg_attrs); |
687 | } |
688 | } |
689 | |
690 | let associated_method = quote! { |
691 | #cfg_attrs |
692 | unsafe fn #wrapper_ident( |
693 | py: _pyo3::Python<'_>, |
694 | _slf: *mut _pyo3::ffi::PyObject |
695 | ) -> _pyo3::PyResult<*mut _pyo3::ffi::PyObject> { |
696 | #body |
697 | } |
698 | }; |
699 | |
700 | let method_def = quote! { |
701 | #cfg_attrs |
702 | _pyo3::class::PyMethodDefType::Getter( |
703 | _pyo3::class::PyGetterDef::new( |
704 | #python_name, |
705 | _pyo3::impl_::pymethods::PyGetter(#cls::#wrapper_ident), |
706 | #doc |
707 | ) |
708 | ) |
709 | }; |
710 | |
711 | Ok(MethodAndMethodDef { |
712 | associated_method, |
713 | method_def, |
714 | }) |
715 | } |
716 | |
717 | /// Split an argument of pyo3::Python from the front of the arg list, if present |
718 | fn split_off_python_arg<'a>(args: &'a [FnArg<'a>]) -> (Option<&FnArg<'_>>, &[FnArg<'_>]) { |
719 | match args { |
720 | [py: &FnArg<'_>, args: &[FnArg<'_>] @ ..] if utils::is_python(py.ty) => (Some(py), args), |
721 | args: &[FnArg<'_>] => (None, args), |
722 | } |
723 | } |
724 | |
/// The source of a generated property getter or setter.
pub enum PropertyType<'a> {
    /// A struct field exposed directly as a property (generated from `#[pyclass]`).
    Descriptor {
        /// Position of the field in the struct; used for unnamed (tuple-struct) fields.
        field_index: usize,
        /// The field itself; its ident, `#[cfg]` attributes and doc attributes are read.
        field: &'a syn::Field,
        /// Explicit Python name, if one was given.
        python_name: Option<&'a NameAttribute>,
        /// Renaming rule applied to the field name when no explicit name is given.
        renaming_rule: Option<RenamingRule>,
    },
    /// A getter/setter function from a `#[pymethods]` block.
    Function {
        /// How the receiver is extracted.
        self_type: &'a SelfType,
        /// The parsed function specification.
        spec: &'a FnSpec<'a>,
        /// Docstring for the property.
        doc: PythonDoc,
    },
}
738 | |
739 | impl PropertyType<'_> { |
740 | fn null_terminated_python_name(&self) -> Result<syn::LitStr> { |
741 | match self { |
742 | PropertyType::Descriptor { |
743 | field, |
744 | python_name, |
745 | renaming_rule, |
746 | .. |
747 | } => { |
748 | let name = match (python_name, &field.ident) { |
749 | (Some(name), _) => name.value.0.to_string(), |
750 | (None, Some(field_name)) => { |
751 | let mut name = field_name.unraw().to_string(); |
752 | if let Some(rule) = renaming_rule { |
753 | name = utils::apply_renaming_rule(*rule, &name); |
754 | } |
755 | name.push(' \0' ); |
756 | name |
757 | } |
758 | (None, None) => { |
759 | bail_spanned!(field.span() => "`get` and `set` with tuple struct fields require `name`" ); |
760 | } |
761 | }; |
762 | Ok(syn::LitStr::new(&name, field.span())) |
763 | } |
764 | PropertyType::Function { spec, .. } => Ok(spec.null_terminated_python_name()), |
765 | } |
766 | } |
767 | |
768 | fn doc(&self) -> Cow<'_, PythonDoc> { |
769 | match self { |
770 | PropertyType::Descriptor { field, .. } => { |
771 | Cow::Owned(utils::get_doc(&field.attrs, None)) |
772 | } |
773 | PropertyType::Function { doc, .. } => Cow::Borrowed(doc), |
774 | } |
775 | } |
776 | } |
777 | |
778 | const __STR__: SlotDef = SlotDef::new(slot:"Py_tp_str" , func_ty:"reprfunc" ); |
779 | pub const __REPR__: SlotDef = SlotDef::new(slot:"Py_tp_repr" , func_ty:"reprfunc" ); |
780 | const __HASH__: SlotDef = SlotDefSlotDef::new(slot:"Py_tp_hash" , func_ty:"hashfunc" ) |
781 | .ret_ty(Ty::PyHashT) |
782 | .return_conversion(TokenGenerator( |
783 | || quote! { _pyo3::callback::HashCallbackOutput }, |
784 | )); |
785 | pub const __RICHCMP__: SlotDef = SlotDefSlotDef::new(slot:"Py_tp_richcompare" , func_ty:"richcmpfunc" ) |
786 | .extract_error_mode(ExtractErrorMode::NotImplemented) |
787 | .arguments(&[Ty::Object, Ty::CompareOp]); |
788 | const __GET__: SlotDef = SlotDefSlotDef::new(slot:"Py_tp_descr_get" , func_ty:"descrgetfunc" ) |
789 | .arguments(&[Ty::MaybeNullObject, Ty::MaybeNullObject]); |
790 | const __ITER__: SlotDef = SlotDef::new(slot:"Py_tp_iter" , func_ty:"getiterfunc" ); |
791 | const __NEXT__: SlotDef = SlotDef::new(slot:"Py_tp_iternext" , func_ty:"iternextfunc" ).return_conversion( |
792 | TokenGenerator(|| quote! { _pyo3::class::iter::IterNextOutput::<_, _> }), |
793 | ); |
794 | const __AWAIT__: SlotDef = SlotDef::new(slot:"Py_am_await" , func_ty:"unaryfunc" ); |
795 | const __AITER__: SlotDef = SlotDef::new(slot:"Py_am_aiter" , func_ty:"unaryfunc" ); |
796 | const __ANEXT__: SlotDef = SlotDef::new(slot:"Py_am_anext" , func_ty:"unaryfunc" ).return_conversion( |
797 | TokenGenerator(|| quote! { _pyo3::class::pyasync::IterANextOutput::<_, _> }), |
798 | ); |
799 | const __LEN__: SlotDef = SlotDef::new(slot:"Py_mp_length" , func_ty:"lenfunc" ).ret_ty(Ty::PySsizeT); |
800 | const __CONTAINS__: SlotDef = SlotDefSlotDef::new(slot:"Py_sq_contains" , func_ty:"objobjproc" ) |
801 | .arguments(&[Ty::Object]) |
802 | .ret_ty(Ty::Int); |
803 | const __CONCAT__: SlotDef = SlotDef::new(slot:"Py_sq_concat" , func_ty:"binaryfunc" ).arguments(&[Ty::Object]); |
804 | const __REPEAT__: SlotDef = SlotDef::new(slot:"Py_sq_repeat" , func_ty:"ssizeargfunc" ).arguments(&[Ty::PySsizeT]); |
805 | const __INPLACE_CONCAT__: SlotDef = |
806 | SlotDef::new(slot:"Py_sq_concat" , func_ty:"binaryfunc" ).arguments(&[Ty::Object]); |
807 | const __INPLACE_REPEAT__: SlotDef = |
808 | SlotDef::new(slot:"Py_sq_repeat" , func_ty:"ssizeargfunc" ).arguments(&[Ty::PySsizeT]); |
809 | const __GETITEM__: SlotDef = SlotDef::new(slot:"Py_mp_subscript" , func_ty:"binaryfunc" ).arguments(&[Ty::Object]); |
810 | |
811 | const __POS__: SlotDef = SlotDef::new(slot:"Py_nb_positive" , func_ty:"unaryfunc" ); |
812 | const __NEG__: SlotDef = SlotDef::new(slot:"Py_nb_negative" , func_ty:"unaryfunc" ); |
813 | const __ABS__: SlotDef = SlotDef::new(slot:"Py_nb_absolute" , func_ty:"unaryfunc" ); |
814 | const __INVERT__: SlotDef = SlotDef::new(slot:"Py_nb_invert" , func_ty:"unaryfunc" ); |
815 | const __INDEX__: SlotDef = SlotDef::new(slot:"Py_nb_index" , func_ty:"unaryfunc" ); |
816 | pub const __INT__: SlotDef = SlotDef::new(slot:"Py_nb_int" , func_ty:"unaryfunc" ); |
817 | const __FLOAT__: SlotDef = SlotDef::new(slot:"Py_nb_float" , func_ty:"unaryfunc" ); |
818 | const __BOOL__: SlotDef = SlotDef::new(slot:"Py_nb_bool" , func_ty:"inquiry" ).ret_ty(Ty::Int); |
819 | |
820 | const __IADD__: SlotDef = SlotDefSlotDef::new(slot:"Py_nb_inplace_add" , func_ty:"binaryfunc" ) |
821 | .arguments(&[Ty::Object]) |
822 | .extract_error_mode(ExtractErrorMode::NotImplemented) |
823 | .return_self(); |
824 | const __ISUB__: SlotDef = SlotDefSlotDef::new(slot:"Py_nb_inplace_subtract" , func_ty:"binaryfunc" ) |
825 | .arguments(&[Ty::Object]) |
826 | .extract_error_mode(ExtractErrorMode::NotImplemented) |
827 | .return_self(); |
828 | const __IMUL__: SlotDef = SlotDefSlotDef::new(slot:"Py_nb_inplace_multiply" , func_ty:"binaryfunc" ) |
829 | .arguments(&[Ty::Object]) |
830 | .extract_error_mode(ExtractErrorMode::NotImplemented) |
831 | .return_self(); |
832 | const __IMATMUL__: SlotDef = SlotDefSlotDef::new(slot:"Py_nb_inplace_matrix_multiply" , func_ty:"binaryfunc" ) |
833 | .arguments(&[Ty::Object]) |
834 | .extract_error_mode(ExtractErrorMode::NotImplemented) |
835 | .return_self(); |
836 | const __ITRUEDIV__: SlotDef = SlotDefSlotDef::new(slot:"Py_nb_inplace_true_divide" , func_ty:"binaryfunc" ) |
837 | .arguments(&[Ty::Object]) |
838 | .extract_error_mode(ExtractErrorMode::NotImplemented) |
839 | .return_self(); |
840 | const __IFLOORDIV__: SlotDef = SlotDefSlotDef::new(slot:"Py_nb_inplace_floor_divide" , func_ty:"binaryfunc" ) |
841 | .arguments(&[Ty::Object]) |
842 | .extract_error_mode(ExtractErrorMode::NotImplemented) |
843 | .return_self(); |
844 | const __IMOD__: SlotDef = SlotDefSlotDef::new(slot:"Py_nb_inplace_remainder" , func_ty:"binaryfunc" ) |
845 | .arguments(&[Ty::Object]) |
846 | .extract_error_mode(ExtractErrorMode::NotImplemented) |
847 | .return_self(); |
848 | const __IPOW__: SlotDef = SlotDefSlotDef::new(slot:"Py_nb_inplace_power" , func_ty:"ipowfunc" ) |
849 | .arguments(&[Ty::Object, Ty::IPowModulo]) |
850 | .extract_error_mode(ExtractErrorMode::NotImplemented) |
851 | .return_self(); |
852 | const __ILSHIFT__: SlotDef = SlotDefSlotDef::new(slot:"Py_nb_inplace_lshift" , func_ty:"binaryfunc" ) |
853 | .arguments(&[Ty::Object]) |
854 | .extract_error_mode(ExtractErrorMode::NotImplemented) |
855 | .return_self(); |
856 | const __IRSHIFT__: SlotDef = SlotDefSlotDef::new(slot:"Py_nb_inplace_rshift" , func_ty:"binaryfunc" ) |
857 | .arguments(&[Ty::Object]) |
858 | .extract_error_mode(ExtractErrorMode::NotImplemented) |
859 | .return_self(); |
860 | const __IAND__: SlotDef = SlotDefSlotDef::new(slot:"Py_nb_inplace_and" , func_ty:"binaryfunc" ) |
861 | .arguments(&[Ty::Object]) |
862 | .extract_error_mode(ExtractErrorMode::NotImplemented) |
863 | .return_self(); |
864 | const __IXOR__: SlotDef = SlotDefSlotDef::new(slot:"Py_nb_inplace_xor" , func_ty:"binaryfunc" ) |
865 | .arguments(&[Ty::Object]) |
866 | .extract_error_mode(ExtractErrorMode::NotImplemented) |
867 | .return_self(); |
868 | const __IOR__: SlotDef = SlotDefSlotDef::new(slot:"Py_nb_inplace_or" , func_ty:"binaryfunc" ) |
869 | .arguments(&[Ty::Object]) |
870 | .extract_error_mode(ExtractErrorMode::NotImplemented) |
871 | .return_self(); |
872 | const __GETBUFFER__: SlotDef = SlotDefSlotDef::new(slot:"Py_bf_getbuffer" , func_ty:"getbufferproc" ) |
873 | .arguments(&[Ty::PyBuffer, Ty::Int]) |
874 | .ret_ty(Ty::Int) |
875 | .require_unsafe(); |
876 | const __RELEASEBUFFER__: SlotDef = SlotDefSlotDef::new(slot:"Py_bf_releasebuffer" , func_ty:"releasebufferproc" ) |
877 | .arguments(&[Ty::PyBuffer]) |
878 | .ret_ty(Ty::Void) |
879 | .require_unsafe(); |
880 | const __CLEAR__: SlotDef = SlotDefSlotDef::new(slot:"Py_tp_clear" , func_ty:"inquiry" ) |
881 | .arguments(&[]) |
882 | .ret_ty(Ty::Int); |
883 | |
/// FFI-level kind of a slot argument or slot return value.
///
/// `Ty::ffi_type` maps each kind to the Rust type used in the generated C-ABI
/// signatures, and `Ty::extract` generates the conversion of an argument of this
/// kind into the user's Rust parameter.
#[derive(Copy, Clone)]
enum Ty {
    /// A `*mut PyObject`, converted with `extract_argument`.
    Object,
    /// A `*mut PyObject` that may be null; null is replaced by `Py_None` before extraction.
    MaybeNullObject,
    /// A `NonNull<PyObject>`.
    NonNullObject,
    /// The modulo operand of `__ipow__` (`IPowModulo`).
    IPowModulo,
    /// A raw `c_int` converted into `CompareOp`.
    CompareOp,
    /// A C `int`, passed through unchanged.
    Int,
    /// `Py_hash_t`.
    PyHashT,
    /// `Py_ssize_t`, converted into the user's parameter type with `TryInto`.
    PySsizeT,
    /// No value; maps to `()`.
    Void,
    /// A `*mut Py_buffer`, passed through unchanged.
    PyBuffer,
}
897 | |
898 | impl Ty { |
899 | fn ffi_type(self) -> TokenStream { |
900 | match self { |
901 | Ty::Object | Ty::MaybeNullObject => quote! { *mut _pyo3::ffi::PyObject }, |
902 | Ty::NonNullObject => quote! { ::std::ptr::NonNull<_pyo3::ffi::PyObject> }, |
903 | Ty::IPowModulo => quote! { _pyo3::impl_::pymethods::IPowModulo }, |
904 | Ty::Int | Ty::CompareOp => quote! { ::std::os::raw::c_int }, |
905 | Ty::PyHashT => quote! { _pyo3::ffi::Py_hash_t }, |
906 | Ty::PySsizeT => quote! { _pyo3::ffi::Py_ssize_t }, |
907 | Ty::Void => quote! { () }, |
908 | Ty::PyBuffer => quote! { *mut _pyo3::ffi::Py_buffer }, |
909 | } |
910 | } |
911 | |
912 | fn extract( |
913 | self, |
914 | ident: &syn::Ident, |
915 | arg: &FnArg<'_>, |
916 | extract_error_mode: ExtractErrorMode, |
917 | ) -> TokenStream { |
918 | let name_str = arg.name.unraw().to_string(); |
919 | match self { |
920 | Ty::Object => extract_object( |
921 | extract_error_mode, |
922 | &name_str, |
923 | quote! { |
924 | py.from_borrowed_ptr::<_pyo3::PyAny>(#ident) |
925 | }, |
926 | ), |
927 | Ty::MaybeNullObject => extract_object( |
928 | extract_error_mode, |
929 | &name_str, |
930 | quote! { |
931 | py.from_borrowed_ptr::<_pyo3::PyAny>( |
932 | if #ident.is_null() { |
933 | _pyo3::ffi::Py_None() |
934 | } else { |
935 | #ident |
936 | } |
937 | ) |
938 | }, |
939 | ), |
940 | Ty::NonNullObject => extract_object( |
941 | extract_error_mode, |
942 | &name_str, |
943 | quote! { |
944 | py.from_borrowed_ptr::<_pyo3::PyAny>(#ident.as_ptr()) |
945 | }, |
946 | ), |
947 | Ty::IPowModulo => extract_object( |
948 | extract_error_mode, |
949 | &name_str, |
950 | quote! { |
951 | #ident.to_borrowed_any(py) |
952 | }, |
953 | ), |
954 | Ty::CompareOp => extract_error_mode.handle_error( |
955 | quote! { |
956 | _pyo3::class::basic::CompareOp::from_raw(#ident) |
957 | .ok_or_else(|| _pyo3::exceptions::PyValueError::new_err("invalid comparison operator" )) |
958 | }, |
959 | ), |
960 | Ty::PySsizeT => { |
961 | let ty = arg.ty; |
962 | extract_error_mode.handle_error( |
963 | quote! { |
964 | ::std::convert::TryInto::<#ty>::try_into(#ident).map_err(|e| _pyo3::exceptions::PyValueError::new_err(e.to_string())) |
965 | }, |
966 | ) |
967 | } |
968 | // Just pass other types through unmodified |
969 | Ty::PyBuffer | Ty::Int | Ty::PyHashT | Ty::Void => quote! { #ident }, |
970 | } |
971 | } |
972 | } |
973 | |
974 | fn extract_object( |
975 | extract_error_mode: ExtractErrorMode, |
976 | name: &str, |
977 | source: TokenStream, |
978 | ) -> TokenStream { |
979 | extract_error_mode.handle_error(extract:quote! { |
980 | _pyo3::impl_::extract_argument::extract_argument( |
981 | #source, |
982 | &mut { _pyo3::impl_::extract_argument::FunctionArgumentHolder::INIT }, |
983 | #name |
984 | ) |
985 | }) |
986 | } |
987 | |
988 | enum ReturnMode { |
989 | ReturnSelf, |
990 | Conversion(TokenGenerator), |
991 | } |
992 | |
993 | impl ReturnMode { |
994 | fn return_call_output(&self, call: TokenStream) -> TokenStream { |
995 | match self { |
996 | ReturnMode::Conversion(conversion: &TokenGenerator) => quote! { |
997 | let _result: _pyo3::PyResult<#conversion> = #call; |
998 | _pyo3::callback::convert(py, _result) |
999 | }, |
1000 | ReturnMode::ReturnSelf => quote! { |
1001 | let _result: _pyo3::PyResult<()> = #call; |
1002 | _result?; |
1003 | _pyo3::ffi::Py_XINCREF(_raw_slf); |
1004 | ::std::result::Result::Ok(_raw_slf) |
1005 | }, |
1006 | } |
1007 | } |
1008 | } |
1009 | |
1010 | pub struct SlotDef { |
1011 | slot: StaticIdent, |
1012 | func_ty: StaticIdent, |
1013 | arguments: &'static [Ty], |
1014 | ret_ty: Ty, |
1015 | extract_error_mode: ExtractErrorMode, |
1016 | return_mode: Option<ReturnMode>, |
1017 | require_unsafe: bool, |
1018 | } |
1019 | |
1020 | const NO_ARGUMENTS: &[Ty] = &[]; |
1021 | |
1022 | impl SlotDef { |
1023 | const fn new(slot: &'static str, func_ty: &'static str) -> Self { |
1024 | SlotDef { |
1025 | slot: StaticIdent(slot), |
1026 | func_ty: StaticIdent(func_ty), |
1027 | arguments: NO_ARGUMENTS, |
1028 | ret_ty: Ty::Object, |
1029 | extract_error_mode: ExtractErrorMode::Raise, |
1030 | return_mode: None, |
1031 | require_unsafe: false, |
1032 | } |
1033 | } |
1034 | |
1035 | const fn arguments(mut self, arguments: &'static [Ty]) -> Self { |
1036 | self.arguments = arguments; |
1037 | self |
1038 | } |
1039 | |
1040 | const fn ret_ty(mut self, ret_ty: Ty) -> Self { |
1041 | self.ret_ty = ret_ty; |
1042 | self |
1043 | } |
1044 | |
1045 | const fn return_conversion(mut self, return_conversion: TokenGenerator) -> Self { |
1046 | self.return_mode = Some(ReturnMode::Conversion(return_conversion)); |
1047 | self |
1048 | } |
1049 | |
1050 | const fn extract_error_mode(mut self, extract_error_mode: ExtractErrorMode) -> Self { |
1051 | self.extract_error_mode = extract_error_mode; |
1052 | self |
1053 | } |
1054 | |
1055 | const fn return_self(mut self) -> Self { |
1056 | self.return_mode = Some(ReturnMode::ReturnSelf); |
1057 | self |
1058 | } |
1059 | |
1060 | const fn require_unsafe(mut self) -> Self { |
1061 | self.require_unsafe = true; |
1062 | self |
1063 | } |
1064 | |
1065 | pub fn generate_type_slot( |
1066 | &self, |
1067 | cls: &syn::Type, |
1068 | spec: &FnSpec<'_>, |
1069 | method_name: &str, |
1070 | ) -> Result<MethodAndSlotDef> { |
1071 | let SlotDef { |
1072 | slot, |
1073 | func_ty, |
1074 | arguments, |
1075 | extract_error_mode, |
1076 | ret_ty, |
1077 | return_mode, |
1078 | require_unsafe, |
1079 | } = self; |
1080 | if *require_unsafe { |
1081 | ensure_spanned!( |
1082 | spec.unsafety.is_some(), |
1083 | spec.name.span() => format!("` {}` must be `unsafe fn`" , method_name) |
1084 | ); |
1085 | } |
1086 | let arg_types: &Vec<_> = &arguments.iter().map(|arg| arg.ffi_type()).collect(); |
1087 | let arg_idents: &Vec<_> = &(0..arguments.len()) |
1088 | .map(|i| format_ident!("arg {}" , i)) |
1089 | .collect(); |
1090 | let wrapper_ident = format_ident!("__pymethod_ {}__" , method_name); |
1091 | let ret_ty = ret_ty.ffi_type(); |
1092 | let body = generate_method_body( |
1093 | cls, |
1094 | spec, |
1095 | arguments, |
1096 | *extract_error_mode, |
1097 | return_mode.as_ref(), |
1098 | )?; |
1099 | let name = spec.name; |
1100 | let associated_method = quote! { |
1101 | unsafe fn #wrapper_ident( |
1102 | py: _pyo3::Python<'_>, |
1103 | _raw_slf: *mut _pyo3::ffi::PyObject, |
1104 | #(#arg_idents: #arg_types),* |
1105 | ) -> _pyo3::PyResult<#ret_ty> { |
1106 | let function = #cls::#name; // Shadow the method name to avoid #3017 |
1107 | let _slf = _raw_slf; |
1108 | #body |
1109 | } |
1110 | }; |
1111 | let slot_def = quote! {{ |
1112 | unsafe extern "C" fn trampoline( |
1113 | _slf: *mut _pyo3::ffi::PyObject, |
1114 | #(#arg_idents: #arg_types),* |
1115 | ) -> #ret_ty |
1116 | { |
1117 | _pyo3::impl_::trampoline:: #func_ty ( |
1118 | _slf, |
1119 | #(#arg_idents,)* |
1120 | #cls::#wrapper_ident |
1121 | ) |
1122 | } |
1123 | |
1124 | _pyo3::ffi::PyType_Slot { |
1125 | slot: _pyo3::ffi::#slot, |
1126 | pfunc: trampoline as _pyo3::ffi::#func_ty as _ |
1127 | } |
1128 | }}; |
1129 | Ok(MethodAndSlotDef { |
1130 | associated_method, |
1131 | slot_def, |
1132 | }) |
1133 | } |
1134 | } |
1135 | |
1136 | fn generate_method_body( |
1137 | cls: &syn::Type, |
1138 | spec: &FnSpec<'_>, |
1139 | arguments: &[Ty], |
1140 | extract_error_mode: ExtractErrorMode, |
1141 | return_mode: Option<&ReturnMode>, |
1142 | ) -> Result<TokenStream> { |
1143 | let self_arg: TokenStream = spec.tp.self_arg(cls:Some(cls), extract_error_mode); |
1144 | let rust_name: &Ident = spec.name; |
1145 | let args: Vec = extract_proto_arguments(spec, proto_args:arguments, extract_error_mode)?; |
1146 | let call: TokenStream = quote! { _pyo3::callback::convert(py, #cls::#rust_name(#self_arg #(#args),*)) }; |
1147 | Ok(if let Some(return_mode: &ReturnMode) = return_mode { |
1148 | return_mode.return_call_output(call) |
1149 | } else { |
1150 | call |
1151 | }) |
1152 | } |
1153 | |
1154 | struct SlotFragmentDef { |
1155 | fragment: &'static str, |
1156 | arguments: &'static [Ty], |
1157 | extract_error_mode: ExtractErrorMode, |
1158 | ret_ty: Ty, |
1159 | } |
1160 | |
1161 | impl SlotFragmentDef { |
1162 | const fn new(fragment: &'static str, arguments: &'static [Ty]) -> Self { |
1163 | SlotFragmentDef { |
1164 | fragment, |
1165 | arguments, |
1166 | extract_error_mode: ExtractErrorMode::Raise, |
1167 | ret_ty: Ty::Void, |
1168 | } |
1169 | } |
1170 | |
1171 | const fn extract_error_mode(mut self, extract_error_mode: ExtractErrorMode) -> Self { |
1172 | self.extract_error_mode = extract_error_mode; |
1173 | self |
1174 | } |
1175 | |
1176 | const fn ret_ty(mut self, ret_ty: Ty) -> Self { |
1177 | self.ret_ty = ret_ty; |
1178 | self |
1179 | } |
1180 | |
1181 | fn generate_pyproto_fragment(&self, cls: &syn::Type, spec: &FnSpec<'_>) -> Result<TokenStream> { |
1182 | let SlotFragmentDef { |
1183 | fragment, |
1184 | arguments, |
1185 | extract_error_mode, |
1186 | ret_ty, |
1187 | } = self; |
1188 | let fragment_trait = format_ident!("PyClass {}SlotFragment" , fragment); |
1189 | let method = syn::Ident::new(fragment, Span::call_site()); |
1190 | let wrapper_ident = format_ident!("__pymethod_ {}__" , fragment); |
1191 | let arg_types: &Vec<_> = &arguments.iter().map(|arg| arg.ffi_type()).collect(); |
1192 | let arg_idents: &Vec<_> = &(0..arguments.len()) |
1193 | .map(|i| format_ident!("arg {}" , i)) |
1194 | .collect(); |
1195 | let body = generate_method_body(cls, spec, arguments, *extract_error_mode, None)?; |
1196 | let ret_ty = ret_ty.ffi_type(); |
1197 | Ok(quote! { |
1198 | impl _pyo3::impl_::pyclass::#fragment_trait<#cls> for _pyo3::impl_::pyclass::PyClassImplCollector<#cls> { |
1199 | |
1200 | #[inline] |
1201 | unsafe fn #method( |
1202 | self, |
1203 | py: _pyo3::Python, |
1204 | _raw_slf: *mut _pyo3::ffi::PyObject, |
1205 | #(#arg_idents: #arg_types),* |
1206 | ) -> _pyo3::PyResult<#ret_ty> { |
1207 | impl #cls { |
1208 | unsafe fn #wrapper_ident( |
1209 | py: _pyo3::Python, |
1210 | _raw_slf: *mut _pyo3::ffi::PyObject, |
1211 | #(#arg_idents: #arg_types),* |
1212 | ) -> _pyo3::PyResult<#ret_ty> { |
1213 | let _slf = _raw_slf; |
1214 | #body |
1215 | } |
1216 | } |
1217 | #cls::#wrapper_ident(py, _raw_slf, #(#arg_idents),*) |
1218 | } |
1219 | } |
1220 | }) |
1221 | } |
1222 | } |
1223 | |
1224 | const __GETATTRIBUTE__: SlotFragmentDef = |
1225 | SlotFragmentDef::new(fragment:"__getattribute__" , &[Ty::Object]).ret_ty(Ty::Object); |
1226 | const __GETATTR__: SlotFragmentDef = |
1227 | SlotFragmentDef::new(fragment:"__getattr__" , &[Ty::Object]).ret_ty(Ty::Object); |
1228 | const __SETATTR__: SlotFragmentDef = |
1229 | SlotFragmentDef::new(fragment:"__setattr__" , &[Ty::Object, Ty::NonNullObject]); |
1230 | const __DELATTR__: SlotFragmentDef = SlotFragmentDef::new(fragment:"__delattr__" , &[Ty::Object]); |
1231 | const __SET__: SlotFragmentDef = SlotFragmentDef::new(fragment:"__set__" , &[Ty::Object, Ty::NonNullObject]); |
1232 | const __DELETE__: SlotFragmentDef = SlotFragmentDef::new(fragment:"__delete__" , &[Ty::Object]); |
1233 | const __SETITEM__: SlotFragmentDef = |
1234 | SlotFragmentDef::new(fragment:"__setitem__" , &[Ty::Object, Ty::NonNullObject]); |
1235 | const __DELITEM__: SlotFragmentDef = SlotFragmentDef::new(fragment:"__delitem__" , &[Ty::Object]); |
1236 | |
1237 | macro_rules! binary_num_slot_fragment_def { |
1238 | ($ident:ident, $name:literal) => { |
1239 | const $ident: SlotFragmentDef = SlotFragmentDef::new($name, &[Ty::Object]) |
1240 | .extract_error_mode(ExtractErrorMode::NotImplemented) |
1241 | .ret_ty(Ty::Object); |
1242 | }; |
1243 | } |
1244 | |
1245 | binary_num_slot_fragment_def!(__ADD__, "__add__" ); |
1246 | binary_num_slot_fragment_def!(__RADD__, "__radd__" ); |
1247 | binary_num_slot_fragment_def!(__SUB__, "__sub__" ); |
1248 | binary_num_slot_fragment_def!(__RSUB__, "__rsub__" ); |
1249 | binary_num_slot_fragment_def!(__MUL__, "__mul__" ); |
1250 | binary_num_slot_fragment_def!(__RMUL__, "__rmul__" ); |
1251 | binary_num_slot_fragment_def!(__MATMUL__, "__matmul__" ); |
1252 | binary_num_slot_fragment_def!(__RMATMUL__, "__rmatmul__" ); |
1253 | binary_num_slot_fragment_def!(__FLOORDIV__, "__floordiv__" ); |
1254 | binary_num_slot_fragment_def!(__RFLOORDIV__, "__rfloordiv__" ); |
1255 | binary_num_slot_fragment_def!(__TRUEDIV__, "__truediv__" ); |
1256 | binary_num_slot_fragment_def!(__RTRUEDIV__, "__rtruediv__" ); |
1257 | binary_num_slot_fragment_def!(__DIVMOD__, "__divmod__" ); |
1258 | binary_num_slot_fragment_def!(__RDIVMOD__, "__rdivmod__" ); |
1259 | binary_num_slot_fragment_def!(__MOD__, "__mod__" ); |
1260 | binary_num_slot_fragment_def!(__RMOD__, "__rmod__" ); |
1261 | binary_num_slot_fragment_def!(__LSHIFT__, "__lshift__" ); |
1262 | binary_num_slot_fragment_def!(__RLSHIFT__, "__rlshift__" ); |
1263 | binary_num_slot_fragment_def!(__RSHIFT__, "__rshift__" ); |
1264 | binary_num_slot_fragment_def!(__RRSHIFT__, "__rrshift__" ); |
1265 | binary_num_slot_fragment_def!(__AND__, "__and__" ); |
1266 | binary_num_slot_fragment_def!(__RAND__, "__rand__" ); |
1267 | binary_num_slot_fragment_def!(__XOR__, "__xor__" ); |
1268 | binary_num_slot_fragment_def!(__RXOR__, "__rxor__" ); |
1269 | binary_num_slot_fragment_def!(__OR__, "__or__" ); |
1270 | binary_num_slot_fragment_def!(__ROR__, "__ror__" ); |
1271 | |
1272 | const __POW__: SlotFragmentDef = SlotFragmentDefSlotFragmentDef::new(fragment:"__pow__" , &[Ty::Object, Ty::Object]) |
1273 | .extract_error_mode(ExtractErrorMode::NotImplemented) |
1274 | .ret_ty(Ty::Object); |
1275 | const __RPOW__: SlotFragmentDef = SlotFragmentDefSlotFragmentDef::new(fragment:"__rpow__" , &[Ty::Object, Ty::Object]) |
1276 | .extract_error_mode(ExtractErrorMode::NotImplemented) |
1277 | .ret_ty(Ty::Object); |
1278 | |
1279 | const __LT__: SlotFragmentDef = SlotFragmentDefSlotFragmentDef::new(fragment:"__lt__" , &[Ty::Object]) |
1280 | .extract_error_mode(ExtractErrorMode::NotImplemented) |
1281 | .ret_ty(Ty::Object); |
1282 | const __LE__: SlotFragmentDef = SlotFragmentDefSlotFragmentDef::new(fragment:"__le__" , &[Ty::Object]) |
1283 | .extract_error_mode(ExtractErrorMode::NotImplemented) |
1284 | .ret_ty(Ty::Object); |
1285 | const __EQ__: SlotFragmentDef = SlotFragmentDefSlotFragmentDef::new(fragment:"__eq__" , &[Ty::Object]) |
1286 | .extract_error_mode(ExtractErrorMode::NotImplemented) |
1287 | .ret_ty(Ty::Object); |
1288 | const __NE__: SlotFragmentDef = SlotFragmentDefSlotFragmentDef::new(fragment:"__ne__" , &[Ty::Object]) |
1289 | .extract_error_mode(ExtractErrorMode::NotImplemented) |
1290 | .ret_ty(Ty::Object); |
1291 | const __GT__: SlotFragmentDef = SlotFragmentDefSlotFragmentDef::new(fragment:"__gt__" , &[Ty::Object]) |
1292 | .extract_error_mode(ExtractErrorMode::NotImplemented) |
1293 | .ret_ty(Ty::Object); |
1294 | const __GE__: SlotFragmentDef = SlotFragmentDefSlotFragmentDef::new(fragment:"__ge__" , &[Ty::Object]) |
1295 | .extract_error_mode(ExtractErrorMode::NotImplemented) |
1296 | .ret_ty(Ty::Object); |
1297 | |
1298 | fn extract_proto_arguments( |
1299 | spec: &FnSpec<'_>, |
1300 | proto_args: &[Ty], |
1301 | extract_error_mode: ExtractErrorMode, |
1302 | ) -> Result<Vec<TokenStream>> { |
1303 | let mut args: Vec = Vec::with_capacity(spec.signature.arguments.len()); |
1304 | let mut non_python_args: usize = 0; |
1305 | |
1306 | for arg: &FnArg<'_> in &spec.signature.arguments { |
1307 | if arg.py { |
1308 | args.push(quote! { py }); |
1309 | } else { |
1310 | let ident: Ident = syn::Ident::new(&format!("arg {}" , non_python_args), Span::call_site()); |
1311 | let conversions: TokenStream = proto_args&Ty.get(non_python_args) |
1312 | .ok_or_else(|| err_spanned!(arg.ty.span() => format!("Expected at most {} non-python arguments" , proto_args.len())))? |
1313 | .extract(&ident, arg, extract_error_mode); |
1314 | non_python_args += 1; |
1315 | args.push(conversions); |
1316 | } |
1317 | } |
1318 | |
1319 | if non_python_args != proto_args.len() { |
1320 | bail_spanned!(spec.name.span() => format!("Expected {} arguments, got {}" , proto_args.len(), non_python_args)); |
1321 | } |
1322 | Ok(args) |
1323 | } |
1324 | |
1325 | struct StaticIdent(&'static str); |
1326 | |
1327 | impl ToTokens for StaticIdent { |
1328 | fn to_tokens(&self, tokens: &mut TokenStream) { |
1329 | syn::Ident::new(self.0, Span::call_site()).to_tokens(tokens) |
1330 | } |
1331 | } |
1332 | |
1333 | struct TokenGenerator(fn() -> TokenStream); |
1334 | |
1335 | impl ToTokens for TokenGenerator { |
1336 | fn to_tokens(&self, tokens: &mut TokenStream) { |
1337 | self.0().to_tokens(tokens) |
1338 | } |
1339 | } |
1340 | |