1 | use proc_macro2::TokenStream as TokenStream2; |
2 | use quote::{quote, ToTokens}; |
3 | use syn::{GenericArgument, PathArguments, PathSegment, ReturnType, Type, TypePath}; |
4 | |
5 | pub fn quote_option<T: ToTokens>(a: &Option<T>) -> TokenStream2 { |
6 | if let Some(a: &T) = a { |
7 | quote! { Some(#a) } |
8 | } else { |
9 | quote! { None } |
10 | } |
11 | } |
12 | |
13 | pub fn remove_lifetime(ty: &mut Type) { |
14 | match ty { |
15 | Type::Path(TypePath { path, .. }) => { |
16 | if let Some(PathSegment { |
17 | arguments: PathArguments::AngleBracketed(inner), |
18 | .. |
19 | }) = path.segments.last_mut() |
20 | { |
21 | for arg in &mut inner.args { |
22 | match arg { |
23 | GenericArgument::Lifetime(l) => { |
24 | // `T::<'a, S>` becomes `T::<'_, S>` |
25 | *l = syn::parse_quote!('_); |
26 | } |
27 | GenericArgument::Type(ty) => { |
28 | remove_lifetime(ty); |
29 | } |
30 | _ => {} |
31 | } |
32 | } |
33 | } |
34 | } |
35 | Type::Reference(rty) => { |
36 | rty.lifetime = None; |
37 | remove_lifetime(rty.elem.as_mut()); |
38 | } |
39 | Type::Tuple(ty) => { |
40 | for elem in &mut ty.elems { |
41 | remove_lifetime(elem); |
42 | } |
43 | } |
44 | Type::Array(ary) => { |
45 | remove_lifetime(ary.elem.as_mut()); |
46 | } |
47 | _ => {} |
48 | } |
49 | } |
50 | |
51 | /// Extract `T` from `PyResult<T>`. |
52 | /// |
53 | /// For `PyResult<&'a T>` case, `'a` will be removed, i.e. returns `&T` for this case. |
54 | pub fn escape_return_type(ret: &ReturnType) -> Option<Type> { |
55 | let ret: &Type = if let ReturnType::Type(_, ty: &Box) = ret { |
56 | unwrap_pyresult(ty) |
57 | } else { |
58 | return None; |
59 | }; |
60 | let mut ret: Type = ret.clone(); |
61 | remove_lifetime(&mut ret); |
62 | Some(ret) |
63 | } |
64 | |
65 | fn unwrap_pyresult(ty: &Type) -> &Type { |
66 | if let Type::Path(TypePath { path: &Path, .. }) = ty { |
67 | if let Some(last: &PathSegment) = path.segments.last() { |
68 | if last.ident == "PyResult" { |
69 | if let PathArguments::AngleBracketed(inner: &AngleBracketedGenericArguments) = &last.arguments { |
70 | for arg: &GenericArgument in &inner.args { |
71 | if let GenericArgument::Type(ty: &Type) = arg { |
72 | return ty; |
73 | } |
74 | } |
75 | } |
76 | } |
77 | } |
78 | } |
79 | ty |
80 | } |
81 | |
#[cfg(test)]
mod test {
    use super::*;
    use syn::{parse_str, Result};

    #[test]
    fn test_unwrap_pyresult() -> Result<()> {
        // (input type, expected unwrapped type)
        let cases = [
            ("PyResult<i32>", "i32"),
            ("PyResult<&PyString>", "&PyString"),
            ("PyResult<&'a PyString>", "&'a PyString"),
            ("::pyo3::PyResult<i32>", "i32"),
            ("::pyo3::PyResult<&PyString>", "&PyString"),
            ("::pyo3::PyResult<&'a PyString>", "&'a PyString"),
            // Types that are not `PyResult<..>` pass through unchanged.
            ("Vec<i32>", "Vec<i32>"),
            ("Option<PyResult<i32>>", "Option<PyResult<i32>>"),
        ];
        for (input, expected) in cases {
            let ty: Type = parse_str(input)?;
            assert_eq!(
                unwrap_pyresult(&ty),
                &parse_str::<Type>(expected)?,
                "input: {}",
                input
            );
        }
        Ok(())
    }

    /// Parse `src` as a type and run `remove_lifetime` on it.
    fn erased(src: &str) -> Result<Type> {
        let mut ty: Type = parse_str(src)?;
        remove_lifetime(&mut ty);
        Ok(ty)
    }

    #[test]
    fn test_remove_lifetime() -> Result<()> {
        let cases = [
            ("&'a str", "&str"),
            ("&'static str", "&str"),
            ("Cow<'a, str>", "Cow<'_, str>"),
            (
                "HashMap<&'a str, Vec<&'b mut i32>>",
                "HashMap<&str, Vec<&mut i32>>",
            ),
            ("(&'a str, [&'b u8; 4])", "(&str, [&u8; 4])"),
            ("i32", "i32"),
        ];
        for (input, expected) in cases {
            assert_eq!(
                erased(input)?,
                parse_str::<Type>(expected)?,
                "input: {}",
                input
            );
        }
        Ok(())
    }

    #[test]
    fn test_escape_return_type() -> Result<()> {
        assert_eq!(escape_return_type(&ReturnType::Default), None);

        let ret: ReturnType = parse_str("-> PyResult<&'a PyString>")?;
        assert_eq!(
            escape_return_type(&ret),
            Some(parse_str::<Type>("&PyString")?)
        );

        let ret: ReturnType = parse_str("-> ::pyo3::PyResult<Vec<&'a str>>")?;
        assert_eq!(
            escape_return_type(&ret),
            Some(parse_str::<Type>("Vec<&str>")?)
        );

        let ret: ReturnType = parse_str("-> &'a PyAny")?;
        assert_eq!(
            escape_return_type(&ret),
            Some(parse_str::<Type>("&PyAny")?)
        );

        Ok(())
    }
}
116 | |