1 | use crate::attributes::{ |
2 | self, get_pyo3_options, CrateAttribute, DefaultAttribute, FromPyWithAttribute, |
3 | RenameAllAttribute, RenamingRule, |
4 | }; |
5 | use crate::utils::{self, deprecated_from_py_with, Ctx}; |
6 | use proc_macro2::TokenStream; |
7 | use quote::{format_ident, quote, quote_spanned, ToTokens}; |
8 | use syn::{ |
9 | ext::IdentExt, |
10 | parenthesized, |
11 | parse::{Parse, ParseStream}, |
12 | parse_quote, |
13 | punctuated::Punctuated, |
14 | spanned::Spanned, |
15 | Attribute, DataEnum, DeriveInput, Fields, Ident, LitStr, Result, Token, |
16 | }; |
17 | |
18 | /// Describes derivation input of an enum. |
19 | struct Enum<'a> { |
20 | enum_ident: &'a Ident, |
21 | variants: Vec<Container<'a>>, |
22 | } |
23 | |
24 | impl<'a> Enum<'a> { |
25 | /// Construct a new enum representation. |
26 | /// |
27 | /// `data_enum` is the `syn` representation of the input enum, `ident` is the |
28 | /// `Identifier` of the enum. |
29 | fn new(data_enum: &'a DataEnum, ident: &'a Ident, options: ContainerOptions) -> Result<Self> { |
30 | ensure_spanned!( |
31 | !data_enum.variants.is_empty(), |
32 | ident.span() => "cannot derive FromPyObject for empty enum" |
33 | ); |
34 | let variants = data_enum |
35 | .variants |
36 | .iter() |
37 | .map(|variant| { |
38 | let mut variant_options = ContainerOptions::from_attrs(&variant.attrs)?; |
39 | if let Some(rename_all) = &options.rename_all { |
40 | ensure_spanned!( |
41 | variant_options.rename_all.is_none(), |
42 | variant_options.rename_all.span() => "Useless variant `rename_all` - enum is already annotated with `rename_all" |
43 | ); |
44 | variant_options.rename_all = Some(rename_all.clone()); |
45 | |
46 | } |
47 | let var_ident = &variant.ident; |
48 | Container::new( |
49 | &variant.fields, |
50 | parse_quote!(#ident::#var_ident), |
51 | variant_options, |
52 | ) |
53 | }) |
54 | .collect::<Result<Vec<_>>>()?; |
55 | |
56 | Ok(Enum { |
57 | enum_ident: ident, |
58 | variants, |
59 | }) |
60 | } |
61 | |
62 | /// Build derivation body for enums. |
63 | fn build(&self, ctx: &Ctx) -> TokenStream { |
64 | let Ctx { pyo3_path, .. } = ctx; |
65 | let mut var_extracts = Vec::new(); |
66 | let mut variant_names = Vec::new(); |
67 | let mut error_names = Vec::new(); |
68 | |
69 | for var in &self.variants { |
70 | let struct_derive = var.build(ctx); |
71 | let ext = quote!({ |
72 | let maybe_ret = || -> #pyo3_path::PyResult<Self> { |
73 | #struct_derive |
74 | }(); |
75 | |
76 | match maybe_ret { |
77 | ok @ ::std::result::Result::Ok(_) => return ok, |
78 | ::std::result::Result::Err(err) => err |
79 | } |
80 | }); |
81 | |
82 | var_extracts.push(ext); |
83 | variant_names.push(var.path.segments.last().unwrap().ident.to_string()); |
84 | error_names.push(&var.err_name); |
85 | } |
86 | let ty_name = self.enum_ident.to_string(); |
87 | quote!( |
88 | let errors = [ |
89 | #(#var_extracts),* |
90 | ]; |
91 | ::std::result::Result::Err( |
92 | #pyo3_path::impl_::frompyobject::failed_to_extract_enum( |
93 | obj.py(), |
94 | #ty_name, |
95 | &[#(#variant_names),*], |
96 | &[#(#error_names),*], |
97 | &errors |
98 | ) |
99 | ) |
100 | ) |
101 | } |
102 | } |
103 | |
104 | struct NamedStructField<'a> { |
105 | ident: &'a syn::Ident, |
106 | getter: Option<FieldGetter>, |
107 | from_py_with: Option<FromPyWithAttribute>, |
108 | default: Option<DefaultAttribute>, |
109 | } |
110 | |
111 | struct TupleStructField { |
112 | from_py_with: Option<FromPyWithAttribute>, |
113 | } |
114 | |
115 | /// Container Style |
116 | /// |
117 | /// Covers Structs, Tuplestructs and corresponding Newtypes. |
118 | enum ContainerType<'a> { |
119 | /// Struct Container, e.g. `struct Foo { a: String }` |
120 | /// |
121 | /// Variant contains the list of field identifiers and the corresponding extraction call. |
122 | Struct(Vec<NamedStructField<'a>>), |
123 | /// Newtype struct container, e.g. `#[transparent] struct Foo { a: String }` |
124 | /// |
125 | /// The field specified by the identifier is extracted directly from the object. |
126 | StructNewtype(&'a syn::Ident, Option<FromPyWithAttribute>), |
127 | /// Tuple struct, e.g. `struct Foo(String)`. |
128 | /// |
129 | /// Variant contains a list of conversion methods for each of the fields that are directly |
130 | /// extracted from the tuple. |
131 | Tuple(Vec<TupleStructField>), |
132 | /// Tuple newtype, e.g. `#[transparent] struct Foo(String)` |
133 | /// |
134 | /// The wrapped field is directly extracted from the object. |
135 | TupleNewtype(Option<FromPyWithAttribute>), |
136 | } |
137 | |
138 | /// Data container |
139 | /// |
140 | /// Either describes a struct or an enum variant. |
141 | struct Container<'a> { |
142 | path: syn::Path, |
143 | ty: ContainerType<'a>, |
144 | err_name: String, |
145 | rename_rule: Option<RenamingRule>, |
146 | } |
147 | |
148 | impl<'a> Container<'a> { |
149 | /// Construct a container based on fields, identifier and attributes. |
150 | /// |
151 | /// Fails if the variant has no fields or incompatible attributes. |
152 | fn new(fields: &'a Fields, path: syn::Path, options: ContainerOptions) -> Result<Self> { |
153 | let style = match fields { |
154 | Fields::Unnamed(unnamed) if !unnamed.unnamed.is_empty() => { |
155 | ensure_spanned!( |
156 | options.rename_all.is_none(), |
157 | options.rename_all.span() => "`rename_all` is useless on tuple structs and variants." |
158 | ); |
159 | let mut tuple_fields = unnamed |
160 | .unnamed |
161 | .iter() |
162 | .map(|field| { |
163 | let attrs = FieldPyO3Attributes::from_attrs(&field.attrs)?; |
164 | ensure_spanned!( |
165 | attrs.getter.is_none(), |
166 | field.span() => "`getter` is not permitted on tuple struct elements." |
167 | ); |
168 | ensure_spanned!( |
169 | attrs.default.is_none(), |
170 | field.span() => "`default` is not permitted on tuple struct elements." |
171 | ); |
172 | Ok(TupleStructField { |
173 | from_py_with: attrs.from_py_with, |
174 | }) |
175 | }) |
176 | .collect::<Result<Vec<_>>>()?; |
177 | |
178 | if tuple_fields.len() == 1 { |
179 | // Always treat a 1-length tuple struct as "transparent", even without the |
180 | // explicit annotation. |
181 | let field = tuple_fields.pop().unwrap(); |
182 | ContainerType::TupleNewtype(field.from_py_with) |
183 | } else if options.transparent { |
184 | bail_spanned!( |
185 | fields.span() => "transparent structs and variants can only have 1 field" |
186 | ); |
187 | } else { |
188 | ContainerType::Tuple(tuple_fields) |
189 | } |
190 | } |
191 | Fields::Named(named) if !named.named.is_empty() => { |
192 | let mut struct_fields = named |
193 | .named |
194 | .iter() |
195 | .map(|field| { |
196 | let ident = field |
197 | .ident |
198 | .as_ref() |
199 | .expect("Named fields should have identifiers" ); |
200 | let mut attrs = FieldPyO3Attributes::from_attrs(&field.attrs)?; |
201 | |
202 | if let Some(ref from_item_all) = options.from_item_all { |
203 | if let Some(replaced) = attrs.getter.replace(FieldGetter::GetItem(None)) |
204 | { |
205 | match replaced { |
206 | FieldGetter::GetItem(Some(item_name)) => { |
207 | attrs.getter = Some(FieldGetter::GetItem(Some(item_name))); |
208 | } |
209 | FieldGetter::GetItem(None) => bail_spanned!(from_item_all.span() => "Useless `item` - the struct is already annotated with `from_item_all`" ), |
210 | FieldGetter::GetAttr(_) => bail_spanned!( |
211 | from_item_all.span() => "The struct is already annotated with `from_item_all`, `attribute` is not allowed" |
212 | ), |
213 | } |
214 | } |
215 | } |
216 | |
217 | Ok(NamedStructField { |
218 | ident, |
219 | getter: attrs.getter, |
220 | from_py_with: attrs.from_py_with, |
221 | default: attrs.default, |
222 | }) |
223 | }) |
224 | .collect::<Result<Vec<_>>>()?; |
225 | if struct_fields.iter().all(|field| field.default.is_some()) { |
226 | bail_spanned!( |
227 | fields.span() => "cannot derive FromPyObject for structs and variants with only default values" |
228 | ) |
229 | } else if options.transparent { |
230 | ensure_spanned!( |
231 | struct_fields.len() == 1, |
232 | fields.span() => "transparent structs and variants can only have 1 field" |
233 | ); |
234 | ensure_spanned!( |
235 | options.rename_all.is_none(), |
236 | options.rename_all.span() => "`rename_all` is not permitted on `transparent` structs and variants" |
237 | ); |
238 | let field = struct_fields.pop().unwrap(); |
239 | ensure_spanned!( |
240 | field.getter.is_none(), |
241 | field.ident.span() => "`transparent` structs may not have a `getter` for the inner field" |
242 | ); |
243 | ContainerType::StructNewtype(field.ident, field.from_py_with) |
244 | } else { |
245 | ContainerType::Struct(struct_fields) |
246 | } |
247 | } |
248 | _ => bail_spanned!( |
249 | fields.span() => "cannot derive FromPyObject for empty structs and variants" |
250 | ), |
251 | }; |
252 | let err_name = options.annotation.map_or_else( |
253 | || path.segments.last().unwrap().ident.to_string(), |
254 | |lit_str| lit_str.value(), |
255 | ); |
256 | |
257 | let v = Container { |
258 | path, |
259 | ty: style, |
260 | err_name, |
261 | rename_rule: options.rename_all.map(|v| v.value.rule), |
262 | }; |
263 | Ok(v) |
264 | } |
265 | |
266 | fn name(&self) -> String { |
267 | let mut value = String::new(); |
268 | for segment in &self.path.segments { |
269 | if !value.is_empty() { |
270 | value.push_str("::" ); |
271 | } |
272 | value.push_str(&segment.ident.to_string()); |
273 | } |
274 | value |
275 | } |
276 | |
277 | /// Build derivation body for a struct. |
278 | fn build(&self, ctx: &Ctx) -> TokenStream { |
279 | match &self.ty { |
280 | ContainerType::StructNewtype(ident, from_py_with) => { |
281 | self.build_newtype_struct(Some(ident), from_py_with, ctx) |
282 | } |
283 | ContainerType::TupleNewtype(from_py_with) => { |
284 | self.build_newtype_struct(None, from_py_with, ctx) |
285 | } |
286 | ContainerType::Tuple(tups) => self.build_tuple_struct(tups, ctx), |
287 | ContainerType::Struct(tups) => self.build_struct(tups, ctx), |
288 | } |
289 | } |
290 | |
291 | fn build_newtype_struct( |
292 | &self, |
293 | field_ident: Option<&Ident>, |
294 | from_py_with: &Option<FromPyWithAttribute>, |
295 | ctx: &Ctx, |
296 | ) -> TokenStream { |
297 | let Ctx { pyo3_path, .. } = ctx; |
298 | let self_ty = &self.path; |
299 | let struct_name = self.name(); |
300 | if let Some(ident) = field_ident { |
301 | let field_name = ident.to_string(); |
302 | if let Some(FromPyWithAttribute { |
303 | kw, |
304 | value: expr_path, |
305 | }) = from_py_with |
306 | { |
307 | let deprecation = deprecated_from_py_with(expr_path).unwrap_or_default(); |
308 | |
309 | let extractor = quote_spanned! { kw.span => |
310 | { let from_py_with: fn(_) -> _ = #expr_path; from_py_with } |
311 | }; |
312 | quote! { |
313 | #deprecation |
314 | Ok(#self_ty { |
315 | #ident: #pyo3_path::impl_::frompyobject::extract_struct_field_with(#extractor, obj, #struct_name, #field_name)? |
316 | }) |
317 | } |
318 | } else { |
319 | quote! { |
320 | Ok(#self_ty { |
321 | #ident: #pyo3_path::impl_::frompyobject::extract_struct_field(obj, #struct_name, #field_name)? |
322 | }) |
323 | } |
324 | } |
325 | } else if let Some(FromPyWithAttribute { |
326 | kw, |
327 | value: expr_path, |
328 | }) = from_py_with |
329 | { |
330 | let deprecation = deprecated_from_py_with(expr_path).unwrap_or_default(); |
331 | |
332 | let extractor = quote_spanned! { kw.span => |
333 | { let from_py_with: fn(_) -> _ = #expr_path; from_py_with } |
334 | }; |
335 | quote! { |
336 | #deprecation |
337 | #pyo3_path::impl_::frompyobject::extract_tuple_struct_field_with(#extractor, obj, #struct_name, 0).map(#self_ty) |
338 | } |
339 | } else { |
340 | quote! { |
341 | #pyo3_path::impl_::frompyobject::extract_tuple_struct_field(obj, #struct_name, 0).map(#self_ty) |
342 | } |
343 | } |
344 | } |
345 | |
346 | fn build_tuple_struct(&self, struct_fields: &[TupleStructField], ctx: &Ctx) -> TokenStream { |
347 | let Ctx { pyo3_path, .. } = ctx; |
348 | let self_ty = &self.path; |
349 | let struct_name = &self.name(); |
350 | let field_idents: Vec<_> = (0..struct_fields.len()) |
351 | .map(|i| format_ident!("arg {}" , i)) |
352 | .collect(); |
353 | let fields = struct_fields.iter().zip(&field_idents).enumerate().map(|(index, (field, ident))| { |
354 | if let Some(FromPyWithAttribute { |
355 | kw, |
356 | value: expr_path, .. |
357 | }) = &field.from_py_with { |
358 | let extractor = quote_spanned! { kw.span => |
359 | { let from_py_with: fn(_) -> _ = #expr_path; from_py_with } |
360 | }; |
361 | quote! { |
362 | #pyo3_path::impl_::frompyobject::extract_tuple_struct_field_with(#extractor, &#ident, #struct_name, #index)? |
363 | } |
364 | } else { |
365 | quote!{ |
366 | #pyo3_path::impl_::frompyobject::extract_tuple_struct_field(&#ident, #struct_name, #index)? |
367 | }} |
368 | }); |
369 | |
370 | let deprecations = struct_fields |
371 | .iter() |
372 | .filter_map(|fields| fields.from_py_with.as_ref()) |
373 | .filter_map(|kw| deprecated_from_py_with(&kw.value)) |
374 | .collect::<TokenStream>(); |
375 | |
376 | quote!( |
377 | #deprecations |
378 | match #pyo3_path::types::PyAnyMethods::extract(obj) { |
379 | ::std::result::Result::Ok((#(#field_idents),*)) => ::std::result::Result::Ok(#self_ty(#(#fields),*)), |
380 | ::std::result::Result::Err(err) => ::std::result::Result::Err(err), |
381 | } |
382 | ) |
383 | } |
384 | |
385 | fn build_struct(&self, struct_fields: &[NamedStructField<'_>], ctx: &Ctx) -> TokenStream { |
386 | let Ctx { pyo3_path, .. } = ctx; |
387 | let self_ty = &self.path; |
388 | let struct_name = self.name(); |
389 | let mut fields: Punctuated<TokenStream, Token![,]> = Punctuated::new(); |
390 | for field in struct_fields { |
391 | let ident = field.ident; |
392 | let field_name = ident.unraw().to_string(); |
393 | let getter = match field.getter.as_ref().unwrap_or(&FieldGetter::GetAttr(None)) { |
394 | FieldGetter::GetAttr(Some(name)) => { |
395 | quote!(#pyo3_path::types::PyAnyMethods::getattr(obj, #pyo3_path::intern!(obj.py(), #name))) |
396 | } |
397 | FieldGetter::GetAttr(None) => { |
398 | let name = self |
399 | .rename_rule |
400 | .map(|rule| utils::apply_renaming_rule(rule, &field_name)); |
401 | let name = name.as_deref().unwrap_or(&field_name); |
402 | quote!(#pyo3_path::types::PyAnyMethods::getattr(obj, #pyo3_path::intern!(obj.py(), #name))) |
403 | } |
404 | FieldGetter::GetItem(Some(syn::Lit::Str(key))) => { |
405 | quote!(#pyo3_path::types::PyAnyMethods::get_item(obj, #pyo3_path::intern!(obj.py(), #key))) |
406 | } |
407 | FieldGetter::GetItem(Some(key)) => { |
408 | quote!(#pyo3_path::types::PyAnyMethods::get_item(obj, #key)) |
409 | } |
410 | FieldGetter::GetItem(None) => { |
411 | let name = self |
412 | .rename_rule |
413 | .map(|rule| utils::apply_renaming_rule(rule, &field_name)); |
414 | let name = name.as_deref().unwrap_or(&field_name); |
415 | quote!(#pyo3_path::types::PyAnyMethods::get_item(obj, #pyo3_path::intern!(obj.py(), #name))) |
416 | } |
417 | }; |
418 | let extractor = if let Some(FromPyWithAttribute { |
419 | kw, |
420 | value: expr_path, |
421 | }) = &field.from_py_with |
422 | { |
423 | let extractor = quote_spanned! { kw.span => |
424 | { let from_py_with: fn(_) -> _ = #expr_path; from_py_with } |
425 | }; |
426 | quote! (#pyo3_path::impl_::frompyobject::extract_struct_field_with(#extractor, &#getter?, #struct_name, #field_name)?) |
427 | } else { |
428 | quote!(#pyo3_path::impl_::frompyobject::extract_struct_field(&value, #struct_name, #field_name)?) |
429 | }; |
430 | let extracted = if let Some(default) = &field.default { |
431 | let default_expr = if let Some(default_expr) = &default.value { |
432 | default_expr.to_token_stream() |
433 | } else { |
434 | quote!(::std::default::Default::default()) |
435 | }; |
436 | quote!(if let ::std::result::Result::Ok(value) = #getter { |
437 | #extractor |
438 | } else { |
439 | #default_expr |
440 | }) |
441 | } else { |
442 | quote!({ |
443 | let value = #getter?; |
444 | #extractor |
445 | }) |
446 | }; |
447 | |
448 | fields.push(quote!(#ident: #extracted)); |
449 | } |
450 | |
451 | let d = struct_fields |
452 | .iter() |
453 | .filter_map(|field| field.from_py_with.as_ref()) |
454 | .filter_map(|kw| deprecated_from_py_with(&kw.value)) |
455 | .collect::<TokenStream>(); |
456 | |
457 | quote!(#d ::std::result::Result::Ok(#self_ty{#fields})) |
458 | } |
459 | } |
460 | |
461 | #[derive (Default)] |
462 | struct ContainerOptions { |
463 | /// Treat the Container as a Wrapper, directly extract its fields from the input object. |
464 | transparent: bool, |
465 | /// Force every field to be extracted from item of source Python object. |
466 | from_item_all: Option<attributes::kw::from_item_all>, |
467 | /// Change the name of an enum variant in the generated error message. |
468 | annotation: Option<syn::LitStr>, |
469 | /// Change the path for the pyo3 crate |
470 | krate: Option<CrateAttribute>, |
471 | /// Converts the field idents according to the [RenamingRule] before extraction |
472 | rename_all: Option<RenameAllAttribute>, |
473 | } |
474 | |
475 | /// Attributes for deriving FromPyObject scoped on containers. |
476 | enum ContainerPyO3Attribute { |
477 | /// Treat the Container as a Wrapper, directly extract its fields from the input object. |
478 | Transparent(attributes::kw::transparent), |
479 | /// Force every field to be extracted from item of source Python object. |
480 | ItemAll(attributes::kw::from_item_all), |
481 | /// Change the name of an enum variant in the generated error message. |
482 | ErrorAnnotation(LitStr), |
483 | /// Change the path for the pyo3 crate |
484 | Crate(CrateAttribute), |
485 | /// Converts the field idents according to the [RenamingRule] before extraction |
486 | RenameAll(RenameAllAttribute), |
487 | } |
488 | |
489 | impl Parse for ContainerPyO3Attribute { |
490 | fn parse(input: ParseStream<'_>) -> Result<Self> { |
491 | let lookahead: Lookahead1<'_> = input.lookahead1(); |
492 | if lookahead.peek(token:attributes::kw::transparent) { |
493 | let kw: attributes::kw::transparent = input.parse()?; |
494 | Ok(ContainerPyO3Attribute::Transparent(kw)) |
495 | } else if lookahead.peek(token:attributes::kw::from_item_all) { |
496 | let kw: attributes::kw::from_item_all = input.parse()?; |
497 | Ok(ContainerPyO3Attribute::ItemAll(kw)) |
498 | } else if lookahead.peek(token:attributes::kw::annotation) { |
499 | let _: attributes::kw::annotation = input.parse()?; |
500 | let _: Token![=] = input.parse()?; |
501 | input.parse().map(op:ContainerPyO3Attribute::ErrorAnnotation) |
502 | } else if lookahead.peek(Token![crate]) { |
503 | input.parse().map(op:ContainerPyO3Attribute::Crate) |
504 | } else if lookahead.peek(token:attributes::kw::rename_all) { |
505 | input.parse().map(op:ContainerPyO3Attribute::RenameAll) |
506 | } else { |
507 | Err(lookahead.error()) |
508 | } |
509 | } |
510 | } |
511 | |
512 | impl ContainerOptions { |
513 | fn from_attrs(attrs: &[Attribute]) -> Result<Self> { |
514 | let mut options = ContainerOptions::default(); |
515 | |
516 | for attr in attrs { |
517 | if let Some(pyo3_attrs) = get_pyo3_options(attr)? { |
518 | for pyo3_attr in pyo3_attrs { |
519 | match pyo3_attr { |
520 | ContainerPyO3Attribute::Transparent(kw) => { |
521 | ensure_spanned!( |
522 | !options.transparent, |
523 | kw.span() => "`transparent` may only be provided once" |
524 | ); |
525 | options.transparent = true; |
526 | } |
527 | ContainerPyO3Attribute::ItemAll(kw) => { |
528 | ensure_spanned!( |
529 | options.from_item_all.is_none(), |
530 | kw.span() => "`from_item_all` may only be provided once" |
531 | ); |
532 | options.from_item_all = Some(kw); |
533 | } |
534 | ContainerPyO3Attribute::ErrorAnnotation(lit_str) => { |
535 | ensure_spanned!( |
536 | options.annotation.is_none(), |
537 | lit_str.span() => "`annotation` may only be provided once" |
538 | ); |
539 | options.annotation = Some(lit_str); |
540 | } |
541 | ContainerPyO3Attribute::Crate(path) => { |
542 | ensure_spanned!( |
543 | options.krate.is_none(), |
544 | path.span() => "`crate` may only be provided once" |
545 | ); |
546 | options.krate = Some(path); |
547 | } |
548 | ContainerPyO3Attribute::RenameAll(rename_all) => { |
549 | ensure_spanned!( |
550 | options.rename_all.is_none(), |
551 | rename_all.span() => "`rename_all` may only be provided once" |
552 | ); |
553 | options.rename_all = Some(rename_all); |
554 | } |
555 | } |
556 | } |
557 | } |
558 | } |
559 | Ok(options) |
560 | } |
561 | } |
562 | |
563 | /// Attributes for deriving FromPyObject scoped on fields. |
564 | #[derive (Clone, Debug)] |
565 | struct FieldPyO3Attributes { |
566 | getter: Option<FieldGetter>, |
567 | from_py_with: Option<FromPyWithAttribute>, |
568 | default: Option<DefaultAttribute>, |
569 | } |
570 | |
571 | #[derive (Clone, Debug)] |
572 | enum FieldGetter { |
573 | GetItem(Option<syn::Lit>), |
574 | GetAttr(Option<LitStr>), |
575 | } |
576 | |
577 | enum FieldPyO3Attribute { |
578 | Getter(FieldGetter), |
579 | FromPyWith(FromPyWithAttribute), |
580 | Default(DefaultAttribute), |
581 | } |
582 | |
583 | impl Parse for FieldPyO3Attribute { |
584 | fn parse(input: ParseStream<'_>) -> Result<Self> { |
585 | let lookahead = input.lookahead1(); |
586 | if lookahead.peek(attributes::kw::attribute) { |
587 | let _: attributes::kw::attribute = input.parse()?; |
588 | if input.peek(syn::token::Paren) { |
589 | let content; |
590 | let _ = parenthesized!(content in input); |
591 | let attr_name: LitStr = content.parse()?; |
592 | if !content.is_empty() { |
593 | return Err(content.error( |
594 | "expected at most one argument: `attribute` or `attribute( \"name \")`" , |
595 | )); |
596 | } |
597 | ensure_spanned!( |
598 | !attr_name.value().is_empty(), |
599 | attr_name.span() => "attribute name cannot be empty" |
600 | ); |
601 | Ok(FieldPyO3Attribute::Getter(FieldGetter::GetAttr(Some( |
602 | attr_name, |
603 | )))) |
604 | } else { |
605 | Ok(FieldPyO3Attribute::Getter(FieldGetter::GetAttr(None))) |
606 | } |
607 | } else if lookahead.peek(attributes::kw::item) { |
608 | let _: attributes::kw::item = input.parse()?; |
609 | if input.peek(syn::token::Paren) { |
610 | let content; |
611 | let _ = parenthesized!(content in input); |
612 | let key = content.parse()?; |
613 | if !content.is_empty() { |
614 | return Err( |
615 | content.error("expected at most one argument: `item` or `item(key)`" ) |
616 | ); |
617 | } |
618 | Ok(FieldPyO3Attribute::Getter(FieldGetter::GetItem(Some(key)))) |
619 | } else { |
620 | Ok(FieldPyO3Attribute::Getter(FieldGetter::GetItem(None))) |
621 | } |
622 | } else if lookahead.peek(attributes::kw::from_py_with) { |
623 | input.parse().map(FieldPyO3Attribute::FromPyWith) |
624 | } else if lookahead.peek(Token![default]) { |
625 | input.parse().map(FieldPyO3Attribute::Default) |
626 | } else { |
627 | Err(lookahead.error()) |
628 | } |
629 | } |
630 | } |
631 | |
632 | impl FieldPyO3Attributes { |
633 | /// Extract the field attributes. |
634 | fn from_attrs(attrs: &[Attribute]) -> Result<Self> { |
635 | let mut getter = None; |
636 | let mut from_py_with = None; |
637 | let mut default = None; |
638 | |
639 | for attr in attrs { |
640 | if let Some(pyo3_attrs) = get_pyo3_options(attr)? { |
641 | for pyo3_attr in pyo3_attrs { |
642 | match pyo3_attr { |
643 | FieldPyO3Attribute::Getter(field_getter) => { |
644 | ensure_spanned!( |
645 | getter.is_none(), |
646 | attr.span() => "only one of `attribute` or `item` can be provided" |
647 | ); |
648 | getter = Some(field_getter); |
649 | } |
650 | FieldPyO3Attribute::FromPyWith(from_py_with_attr) => { |
651 | ensure_spanned!( |
652 | from_py_with.is_none(), |
653 | attr.span() => "`from_py_with` may only be provided once" |
654 | ); |
655 | from_py_with = Some(from_py_with_attr); |
656 | } |
657 | FieldPyO3Attribute::Default(default_attr) => { |
658 | ensure_spanned!( |
659 | default.is_none(), |
660 | attr.span() => "`default` may only be provided once" |
661 | ); |
662 | default = Some(default_attr); |
663 | } |
664 | } |
665 | } |
666 | } |
667 | } |
668 | |
669 | Ok(FieldPyO3Attributes { |
670 | getter, |
671 | from_py_with, |
672 | default, |
673 | }) |
674 | } |
675 | } |
676 | |
677 | fn verify_and_get_lifetime(generics: &syn::Generics) -> Result<Option<&syn::LifetimeParam>> { |
678 | let mut lifetimes: Lifetimes<'_> = generics.lifetimes(); |
679 | let lifetime: Option<&LifetimeParam> = lifetimes.next(); |
680 | ensure_spanned!( |
681 | lifetimes.next().is_none(), |
682 | generics.span() => "FromPyObject can be derived with at most one lifetime parameter" |
683 | ); |
684 | Ok(lifetime) |
685 | } |
686 | |
687 | /// Derive FromPyObject for enums and structs. |
688 | /// |
689 | /// * Max 1 lifetime specifier, will be tied to `FromPyObject`'s specifier |
690 | /// * At least one field, in case of `#[transparent]`, exactly one field |
691 | /// * At least one variant for enums. |
692 | /// * Fields of input structs and enums must implement `FromPyObject` or be annotated with `from_py_with` |
693 | /// * Derivation for structs with generic fields like `struct<T> Foo(T)` |
694 | /// adds `T: FromPyObject` on the derived implementation. |
695 | pub fn build_derive_from_pyobject(tokens: &DeriveInput) -> Result<TokenStream> { |
696 | let options = ContainerOptions::from_attrs(&tokens.attrs)?; |
697 | let ctx = &Ctx::new(&options.krate, None); |
698 | let Ctx { pyo3_path, .. } = &ctx; |
699 | |
700 | let (_, ty_generics, _) = tokens.generics.split_for_impl(); |
701 | let mut trait_generics = tokens.generics.clone(); |
702 | let lt_param = if let Some(lt) = verify_and_get_lifetime(&trait_generics)? { |
703 | lt.clone() |
704 | } else { |
705 | trait_generics.params.push(parse_quote!('py)); |
706 | parse_quote!('py) |
707 | }; |
708 | let (impl_generics, _, where_clause) = trait_generics.split_for_impl(); |
709 | |
710 | let mut where_clause = where_clause.cloned().unwrap_or_else(|| parse_quote!(where)); |
711 | for param in trait_generics.type_params() { |
712 | let gen_ident = ¶m.ident; |
713 | where_clause |
714 | .predicates |
715 | .push(parse_quote!(#gen_ident: #pyo3_path::FromPyObject<'py>)) |
716 | } |
717 | |
718 | let derives = match &tokens.data { |
719 | syn::Data::Enum(en) => { |
720 | if options.transparent || options.annotation.is_some() { |
721 | bail_spanned!(tokens.span() => "`transparent` or `annotation` is not supported \ |
722 | at top level for enums" ); |
723 | } |
724 | let en = Enum::new(en, &tokens.ident, options)?; |
725 | en.build(ctx) |
726 | } |
727 | syn::Data::Struct(st) => { |
728 | if let Some(lit_str) = &options.annotation { |
729 | bail_spanned!(lit_str.span() => "`annotation` is unsupported for structs" ); |
730 | } |
731 | let ident = &tokens.ident; |
732 | let st = Container::new(&st.fields, parse_quote!(#ident), options)?; |
733 | st.build(ctx) |
734 | } |
735 | syn::Data::Union(_) => bail_spanned!( |
736 | tokens.span() => "#[derive(FromPyObject)] is not supported for unions" |
737 | ), |
738 | }; |
739 | |
740 | let ident = &tokens.ident; |
741 | Ok(quote!( |
742 | #[automatically_derived] |
743 | impl #impl_generics #pyo3_path::FromPyObject<#lt_param> for #ident #ty_generics #where_clause { |
744 | fn extract_bound(obj: &#pyo3_path::Bound<#lt_param, #pyo3_path::PyAny>) -> #pyo3_path::PyResult<Self> { |
745 | #derives |
746 | } |
747 | } |
748 | )) |
749 | } |
750 | |