1use crate::{
2 attributes::{self, get_pyo3_options, CrateAttribute, FromPyWithAttribute},
3 utils::get_pyo3_crate,
4};
5use proc_macro2::TokenStream;
6use quote::{format_ident, quote};
7use syn::{
8 parenthesized,
9 parse::{Parse, ParseStream},
10 parse_quote,
11 punctuated::Punctuated,
12 spanned::Spanned,
13 Attribute, DataEnum, DeriveInput, Fields, Ident, LitStr, Result, Token,
14};
15
16/// Describes derivation input of an enum.
17struct Enum<'a> {
18 enum_ident: &'a Ident,
19 variants: Vec<Container<'a>>,
20}
21
22impl<'a> Enum<'a> {
23 /// Construct a new enum representation.
24 ///
25 /// `data_enum` is the `syn` representation of the input enum, `ident` is the
26 /// `Identifier` of the enum.
27 fn new(data_enum: &'a DataEnum, ident: &'a Ident) -> Result<Self> {
28 ensure_spanned!(
29 !data_enum.variants.is_empty(),
30 ident.span() => "cannot derive FromPyObject for empty enum"
31 );
32 let variants = data_enum
33 .variants
34 .iter()
35 .map(|variant| {
36 let attrs = ContainerOptions::from_attrs(&variant.attrs)?;
37 let var_ident = &variant.ident;
38 Container::new(&variant.fields, parse_quote!(#ident::#var_ident), attrs)
39 })
40 .collect::<Result<Vec<_>>>()?;
41
42 Ok(Enum {
43 enum_ident: ident,
44 variants,
45 })
46 }
47
48 /// Build derivation body for enums.
49 fn build(&self) -> TokenStream {
50 let mut var_extracts = Vec::new();
51 let mut variant_names = Vec::new();
52 let mut error_names = Vec::new();
53 for var in &self.variants {
54 let struct_derive = var.build();
55 let ext = quote!({
56 let maybe_ret = || -> _pyo3::PyResult<Self> {
57 #struct_derive
58 }();
59
60 match maybe_ret {
61 ok @ ::std::result::Result::Ok(_) => return ok,
62 ::std::result::Result::Err(err) => err
63 }
64 });
65
66 var_extracts.push(ext);
67 variant_names.push(var.path.segments.last().unwrap().ident.to_string());
68 error_names.push(&var.err_name);
69 }
70 let ty_name = self.enum_ident.to_string();
71 quote!(
72 let errors = [
73 #(#var_extracts),*
74 ];
75 ::std::result::Result::Err(
76 _pyo3::impl_::frompyobject::failed_to_extract_enum(
77 obj.py(),
78 #ty_name,
79 &[#(#variant_names),*],
80 &[#(#error_names),*],
81 &errors
82 )
83 )
84 )
85 }
86}
87
88struct NamedStructField<'a> {
89 ident: &'a syn::Ident,
90 getter: Option<FieldGetter>,
91 from_py_with: Option<FromPyWithAttribute>,
92}
93
94struct TupleStructField {
95 from_py_with: Option<FromPyWithAttribute>,
96}
97
98/// Container Style
99///
100/// Covers Structs, Tuplestructs and corresponding Newtypes.
101enum ContainerType<'a> {
102 /// Struct Container, e.g. `struct Foo { a: String }`
103 ///
104 /// Variant contains the list of field identifiers and the corresponding extraction call.
105 Struct(Vec<NamedStructField<'a>>),
106 /// Newtype struct container, e.g. `#[transparent] struct Foo { a: String }`
107 ///
108 /// The field specified by the identifier is extracted directly from the object.
109 StructNewtype(&'a syn::Ident, Option<FromPyWithAttribute>),
110 /// Tuple struct, e.g. `struct Foo(String)`.
111 ///
112 /// Variant contains a list of conversion methods for each of the fields that are directly
113 /// extracted from the tuple.
114 Tuple(Vec<TupleStructField>),
115 /// Tuple newtype, e.g. `#[transparent] struct Foo(String)`
116 ///
117 /// The wrapped field is directly extracted from the object.
118 TupleNewtype(Option<FromPyWithAttribute>),
119}
120
121/// Data container
122///
123/// Either describes a struct or an enum variant.
124struct Container<'a> {
125 path: syn::Path,
126 ty: ContainerType<'a>,
127 err_name: String,
128}
129
130impl<'a> Container<'a> {
131 /// Construct a container based on fields, identifier and attributes.
132 ///
133 /// Fails if the variant has no fields or incompatible attributes.
134 fn new(fields: &'a Fields, path: syn::Path, options: ContainerOptions) -> Result<Self> {
135 let style = match fields {
136 Fields::Unnamed(unnamed) if !unnamed.unnamed.is_empty() => {
137 let mut tuple_fields = unnamed
138 .unnamed
139 .iter()
140 .map(|field| {
141 let attrs = FieldPyO3Attributes::from_attrs(&field.attrs)?;
142 ensure_spanned!(
143 attrs.getter.is_none(),
144 field.span() => "`getter` is not permitted on tuple struct elements."
145 );
146 Ok(TupleStructField {
147 from_py_with: attrs.from_py_with,
148 })
149 })
150 .collect::<Result<Vec<_>>>()?;
151
152 if tuple_fields.len() == 1 {
153 // Always treat a 1-length tuple struct as "transparent", even without the
154 // explicit annotation.
155 let field = tuple_fields.pop().unwrap();
156 ContainerType::TupleNewtype(field.from_py_with)
157 } else if options.transparent {
158 bail_spanned!(
159 fields.span() => "transparent structs and variants can only have 1 field"
160 );
161 } else {
162 ContainerType::Tuple(tuple_fields)
163 }
164 }
165 Fields::Named(named) if !named.named.is_empty() => {
166 let mut struct_fields = named
167 .named
168 .iter()
169 .map(|field| {
170 let ident = field
171 .ident
172 .as_ref()
173 .expect("Named fields should have identifiers");
174 let mut attrs = FieldPyO3Attributes::from_attrs(&field.attrs)?;
175
176 if let Some(ref from_item_all) = options.from_item_all {
177 if let Some(replaced) = attrs.getter.replace(FieldGetter::GetItem(None))
178 {
179 match replaced {
180 FieldGetter::GetItem(Some(item_name)) => {
181 attrs.getter = Some(FieldGetter::GetItem(Some(item_name)));
182 }
183 FieldGetter::GetItem(None) => bail_spanned!(from_item_all.span() => "Useless `item` - the struct is already annotated with `from_item_all`"),
184 FieldGetter::GetAttr(_) => bail_spanned!(
185 from_item_all.span() => "The struct is already annotated with `from_item_all`, `attribute` is not allowed"
186 ),
187 }
188 }
189 }
190
191 Ok(NamedStructField {
192 ident,
193 getter: attrs.getter,
194 from_py_with: attrs.from_py_with,
195 })
196 })
197 .collect::<Result<Vec<_>>>()?;
198 if options.transparent {
199 ensure_spanned!(
200 struct_fields.len() == 1,
201 fields.span() => "transparent structs and variants can only have 1 field"
202 );
203 let field = struct_fields.pop().unwrap();
204 ensure_spanned!(
205 field.getter.is_none(),
206 field.ident.span() => "`transparent` structs may not have a `getter` for the inner field"
207 );
208 ContainerType::StructNewtype(field.ident, field.from_py_with)
209 } else {
210 ContainerType::Struct(struct_fields)
211 }
212 }
213 _ => bail_spanned!(
214 fields.span() => "cannot derive FromPyObject for empty structs and variants"
215 ),
216 };
217 let err_name = options.annotation.map_or_else(
218 || path.segments.last().unwrap().ident.to_string(),
219 |lit_str| lit_str.value(),
220 );
221
222 let v = Container {
223 path,
224 ty: style,
225 err_name,
226 };
227 Ok(v)
228 }
229
230 fn name(&self) -> String {
231 let mut value = String::new();
232 for segment in &self.path.segments {
233 if !value.is_empty() {
234 value.push_str("::");
235 }
236 value.push_str(&segment.ident.to_string());
237 }
238 value
239 }
240
241 /// Build derivation body for a struct.
242 fn build(&self) -> TokenStream {
243 match &self.ty {
244 ContainerType::StructNewtype(ident, from_py_with) => {
245 self.build_newtype_struct(Some(ident), from_py_with)
246 }
247 ContainerType::TupleNewtype(from_py_with) => {
248 self.build_newtype_struct(None, from_py_with)
249 }
250 ContainerType::Tuple(tups) => self.build_tuple_struct(tups),
251 ContainerType::Struct(tups) => self.build_struct(tups),
252 }
253 }
254
255 fn build_newtype_struct(
256 &self,
257 field_ident: Option<&Ident>,
258 from_py_with: &Option<FromPyWithAttribute>,
259 ) -> TokenStream {
260 let self_ty = &self.path;
261 let struct_name = self.name();
262 if let Some(ident) = field_ident {
263 let field_name = ident.to_string();
264 match from_py_with {
265 None => quote! {
266 Ok(#self_ty {
267 #ident: _pyo3::impl_::frompyobject::extract_struct_field(obj, #struct_name, #field_name)?
268 })
269 },
270 Some(FromPyWithAttribute {
271 value: expr_path, ..
272 }) => quote! {
273 Ok(#self_ty {
274 #ident: _pyo3::impl_::frompyobject::extract_struct_field_with(#expr_path, obj, #struct_name, #field_name)?
275 })
276 },
277 }
278 } else {
279 match from_py_with {
280 None => quote!(
281 _pyo3::impl_::frompyobject::extract_tuple_struct_field(obj, #struct_name, 0).map(#self_ty)
282 ),
283 Some(FromPyWithAttribute {
284 value: expr_path, ..
285 }) => quote! (
286 _pyo3::impl_::frompyobject::extract_tuple_struct_field_with(#expr_path, obj, #struct_name, 0).map(#self_ty)
287 ),
288 }
289 }
290 }
291
292 fn build_tuple_struct(&self, struct_fields: &[TupleStructField]) -> TokenStream {
293 let self_ty = &self.path;
294 let struct_name = &self.name();
295 let field_idents: Vec<_> = (0..struct_fields.len())
296 .map(|i| format_ident!("arg{}", i))
297 .collect();
298 let fields = struct_fields.iter().zip(&field_idents).enumerate().map(|(index, (field, ident))| {
299 match &field.from_py_with {
300 None => quote!(
301 _pyo3::impl_::frompyobject::extract_tuple_struct_field(#ident, #struct_name, #index)?
302 ),
303 Some(FromPyWithAttribute {
304 value: expr_path, ..
305 }) => quote! (
306 _pyo3::impl_::frompyobject::extract_tuple_struct_field_with(#expr_path, #ident, #struct_name, #index)?
307 ),
308 }
309 });
310 quote!(
311 match obj.extract() {
312 ::std::result::Result::Ok((#(#field_idents),*)) => ::std::result::Result::Ok(#self_ty(#(#fields),*)),
313 ::std::result::Result::Err(err) => ::std::result::Result::Err(err),
314 }
315 )
316 }
317
318 fn build_struct(&self, struct_fields: &[NamedStructField<'_>]) -> TokenStream {
319 let self_ty = &self.path;
320 let struct_name = &self.name();
321 let mut fields: Punctuated<TokenStream, syn::Token![,]> = Punctuated::new();
322 for field in struct_fields {
323 let ident = &field.ident;
324 let field_name = ident.to_string();
325 let getter = match field.getter.as_ref().unwrap_or(&FieldGetter::GetAttr(None)) {
326 FieldGetter::GetAttr(Some(name)) => {
327 quote!(getattr(_pyo3::intern!(obj.py(), #name)))
328 }
329 FieldGetter::GetAttr(None) => {
330 quote!(getattr(_pyo3::intern!(obj.py(), #field_name)))
331 }
332 FieldGetter::GetItem(Some(syn::Lit::Str(key))) => {
333 quote!(get_item(_pyo3::intern!(obj.py(), #key)))
334 }
335 FieldGetter::GetItem(Some(key)) => quote!(get_item(#key)),
336 FieldGetter::GetItem(None) => {
337 quote!(get_item(_pyo3::intern!(obj.py(), #field_name)))
338 }
339 };
340 let extractor = match &field.from_py_with {
341 None => {
342 quote!(_pyo3::impl_::frompyobject::extract_struct_field(obj.#getter?, #struct_name, #field_name)?)
343 }
344 Some(FromPyWithAttribute {
345 value: expr_path, ..
346 }) => {
347 quote! (_pyo3::impl_::frompyobject::extract_struct_field_with(#expr_path, obj.#getter?, #struct_name, #field_name)?)
348 }
349 };
350
351 fields.push(quote!(#ident: #extractor));
352 }
353 quote!(::std::result::Result::Ok(#self_ty{#fields}))
354 }
355}
356
357#[derive(Default)]
358struct ContainerOptions {
359 /// Treat the Container as a Wrapper, directly extract its fields from the input object.
360 transparent: bool,
361 /// Force every field to be extracted from item of source Python object.
362 from_item_all: Option<attributes::kw::from_item_all>,
363 /// Change the name of an enum variant in the generated error message.
364 annotation: Option<syn::LitStr>,
365 /// Change the path for the pyo3 crate
366 krate: Option<CrateAttribute>,
367}
368
369/// Attributes for deriving FromPyObject scoped on containers.
370enum ContainerPyO3Attribute {
371 /// Treat the Container as a Wrapper, directly extract its fields from the input object.
372 Transparent(attributes::kw::transparent),
373 /// Force every field to be extracted from item of source Python object.
374 ItemAll(attributes::kw::from_item_all),
375 /// Change the name of an enum variant in the generated error message.
376 ErrorAnnotation(LitStr),
377 /// Change the path for the pyo3 crate
378 Crate(CrateAttribute),
379}
380
381impl Parse for ContainerPyO3Attribute {
382 fn parse(input: ParseStream<'_>) -> Result<Self> {
383 let lookahead: Lookahead1<'_> = input.lookahead1();
384 if lookahead.peek(token:attributes::kw::transparent) {
385 let kw: attributes::kw::transparent = input.parse()?;
386 Ok(ContainerPyO3Attribute::Transparent(kw))
387 } else if lookahead.peek(token:attributes::kw::from_item_all) {
388 let kw: attributes::kw::from_item_all = input.parse()?;
389 Ok(ContainerPyO3Attribute::ItemAll(kw))
390 } else if lookahead.peek(token:attributes::kw::annotation) {
391 let _: attributes::kw::annotation = input.parse()?;
392 let _: Token![=] = input.parse()?;
393 input.parse().map(op:ContainerPyO3Attribute::ErrorAnnotation)
394 } else if lookahead.peek(Token![crate]) {
395 input.parse().map(op:ContainerPyO3Attribute::Crate)
396 } else {
397 Err(lookahead.error())
398 }
399 }
400}
401
402impl ContainerOptions {
403 fn from_attrs(attrs: &[Attribute]) -> Result<Self> {
404 let mut options = ContainerOptions::default();
405
406 for attr in attrs {
407 if let Some(pyo3_attrs) = get_pyo3_options(attr)? {
408 for pyo3_attr in pyo3_attrs {
409 match pyo3_attr {
410 ContainerPyO3Attribute::Transparent(kw) => {
411 ensure_spanned!(
412 !options.transparent,
413 kw.span() => "`transparent` may only be provided once"
414 );
415 options.transparent = true;
416 }
417 ContainerPyO3Attribute::ItemAll(kw) => {
418 ensure_spanned!(
419 options.from_item_all.is_none(),
420 kw.span() => "`from_item_all` may only be provided once"
421 );
422 options.from_item_all = Some(kw);
423 }
424 ContainerPyO3Attribute::ErrorAnnotation(lit_str) => {
425 ensure_spanned!(
426 options.annotation.is_none(),
427 lit_str.span() => "`annotation` may only be provided once"
428 );
429 options.annotation = Some(lit_str);
430 }
431 ContainerPyO3Attribute::Crate(path) => {
432 ensure_spanned!(
433 options.krate.is_none(),
434 path.span() => "`crate` may only be provided once"
435 );
436 options.krate = Some(path);
437 }
438 }
439 }
440 }
441 }
442 Ok(options)
443 }
444}
445
446/// Attributes for deriving FromPyObject scoped on fields.
447#[derive(Clone, Debug)]
448struct FieldPyO3Attributes {
449 getter: Option<FieldGetter>,
450 from_py_with: Option<FromPyWithAttribute>,
451}
452
453#[derive(Clone, Debug)]
454enum FieldGetter {
455 GetItem(Option<syn::Lit>),
456 GetAttr(Option<LitStr>),
457}
458
459enum FieldPyO3Attribute {
460 Getter(FieldGetter),
461 FromPyWith(FromPyWithAttribute),
462}
463
464impl Parse for FieldPyO3Attribute {
465 fn parse(input: ParseStream<'_>) -> Result<Self> {
466 let lookahead = input.lookahead1();
467 if lookahead.peek(attributes::kw::attribute) {
468 let _: attributes::kw::attribute = input.parse()?;
469 if input.peek(syn::token::Paren) {
470 let content;
471 let _ = parenthesized!(content in input);
472 let attr_name: LitStr = content.parse()?;
473 if !content.is_empty() {
474 return Err(content.error(
475 "expected at most one argument: `attribute` or `attribute(\"name\")`",
476 ));
477 }
478 ensure_spanned!(
479 !attr_name.value().is_empty(),
480 attr_name.span() => "attribute name cannot be empty"
481 );
482 Ok(FieldPyO3Attribute::Getter(FieldGetter::GetAttr(Some(
483 attr_name,
484 ))))
485 } else {
486 Ok(FieldPyO3Attribute::Getter(FieldGetter::GetAttr(None)))
487 }
488 } else if lookahead.peek(attributes::kw::item) {
489 let _: attributes::kw::item = input.parse()?;
490 if input.peek(syn::token::Paren) {
491 let content;
492 let _ = parenthesized!(content in input);
493 let key = content.parse()?;
494 if !content.is_empty() {
495 return Err(
496 content.error("expected at most one argument: `item` or `item(key)`")
497 );
498 }
499 Ok(FieldPyO3Attribute::Getter(FieldGetter::GetItem(Some(key))))
500 } else {
501 Ok(FieldPyO3Attribute::Getter(FieldGetter::GetItem(None)))
502 }
503 } else if lookahead.peek(attributes::kw::from_py_with) {
504 input.parse().map(FieldPyO3Attribute::FromPyWith)
505 } else {
506 Err(lookahead.error())
507 }
508 }
509}
510
511impl FieldPyO3Attributes {
512 /// Extract the field attributes.
513 fn from_attrs(attrs: &[Attribute]) -> Result<Self> {
514 let mut getter = None;
515 let mut from_py_with = None;
516
517 for attr in attrs {
518 if let Some(pyo3_attrs) = get_pyo3_options(attr)? {
519 for pyo3_attr in pyo3_attrs {
520 match pyo3_attr {
521 FieldPyO3Attribute::Getter(field_getter) => {
522 ensure_spanned!(
523 getter.is_none(),
524 attr.span() => "only one of `attribute` or `item` can be provided"
525 );
526 getter = Some(field_getter);
527 }
528 FieldPyO3Attribute::FromPyWith(from_py_with_attr) => {
529 ensure_spanned!(
530 from_py_with.is_none(),
531 attr.span() => "`from_py_with` may only be provided once"
532 );
533 from_py_with = Some(from_py_with_attr);
534 }
535 }
536 }
537 }
538 }
539
540 Ok(FieldPyO3Attributes {
541 getter,
542 from_py_with,
543 })
544 }
545}
546
547fn verify_and_get_lifetime(generics: &syn::Generics) -> Result<Option<&syn::LifetimeParam>> {
548 let mut lifetimes: Lifetimes<'_> = generics.lifetimes();
549 let lifetime: Option<&LifetimeParam> = lifetimes.next();
550 ensure_spanned!(
551 lifetimes.next().is_none(),
552 generics.span() => "FromPyObject can be derived with at most one lifetime parameter"
553 );
554 Ok(lifetime)
555}
556
557/// Derive FromPyObject for enums and structs.
558///
559/// * Max 1 lifetime specifier, will be tied to `FromPyObject`'s specifier
560/// * At least one field, in case of `#[transparent]`, exactly one field
561/// * At least one variant for enums.
562/// * Fields of input structs and enums must implement `FromPyObject` or be annotated with `from_py_with`
563/// * Derivation for structs with generic fields like `struct<T> Foo(T)`
564/// adds `T: FromPyObject` on the derived implementation.
565pub fn build_derive_from_pyobject(tokens: &DeriveInput) -> Result<TokenStream> {
566 let mut trait_generics = tokens.generics.clone();
567 let generics = &tokens.generics;
568 let lt_param = if let Some(lt) = verify_and_get_lifetime(generics)? {
569 lt.clone()
570 } else {
571 trait_generics.params.push(parse_quote!('source));
572 parse_quote!('source)
573 };
574 let mut where_clause: syn::WhereClause = parse_quote!(where);
575 for param in generics.type_params() {
576 let gen_ident = &param.ident;
577 where_clause
578 .predicates
579 .push(parse_quote!(#gen_ident: FromPyObject<#lt_param>))
580 }
581 let options = ContainerOptions::from_attrs(&tokens.attrs)?;
582 let krate = get_pyo3_crate(&options.krate);
583 let derives = match &tokens.data {
584 syn::Data::Enum(en) => {
585 if options.transparent || options.annotation.is_some() {
586 bail_spanned!(tokens.span() => "`transparent` or `annotation` is not supported \
587 at top level for enums");
588 }
589 let en = Enum::new(en, &tokens.ident)?;
590 en.build()
591 }
592 syn::Data::Struct(st) => {
593 if let Some(lit_str) = &options.annotation {
594 bail_spanned!(lit_str.span() => "`annotation` is unsupported for structs");
595 }
596 let ident = &tokens.ident;
597 let st = Container::new(&st.fields, parse_quote!(#ident), options)?;
598 st.build()
599 }
600 syn::Data::Union(_) => bail_spanned!(
601 tokens.span() => "#[derive(FromPyObject)] is not supported for unions"
602 ),
603 };
604
605 let ident = &tokens.ident;
606 Ok(quote!(
607 const _: () = {
608 use #krate as _pyo3;
609
610 #[automatically_derived]
611 impl #trait_generics _pyo3::FromPyObject<#lt_param> for #ident #generics #where_clause {
612 fn extract(obj: &#lt_param _pyo3::PyAny) -> _pyo3::PyResult<Self> {
613 #derives
614 }
615 }
616 };
617 ))
618}
619