| 1 | use proc_macro2::TokenStream as TokenStream2; |
| 2 | use quote::{quote, ToTokens}; |
| 3 | use syn::{GenericArgument, PathArguments, PathSegment, ReturnType, Type, TypePath}; |
| 4 | |
| 5 | pub fn quote_option<T: ToTokens>(a: &Option<T>) -> TokenStream2 { |
| 6 | if let Some(a: &T) = a { |
| 7 | quote! { Some(#a) } |
| 8 | } else { |
| 9 | quote! { None } |
| 10 | } |
| 11 | } |
| 12 | |
| 13 | pub fn remove_lifetime(ty: &mut Type) { |
| 14 | match ty { |
| 15 | Type::Path(TypePath { path, .. }) => { |
| 16 | if let Some(PathSegment { |
| 17 | arguments: PathArguments::AngleBracketed(inner), |
| 18 | .. |
| 19 | }) = path.segments.last_mut() |
| 20 | { |
| 21 | for arg in &mut inner.args { |
| 22 | match arg { |
| 23 | GenericArgument::Lifetime(l) => { |
| 24 | // `T::<'a, S>` becomes `T::<'_, S>` |
| 25 | *l = syn::parse_quote!('_); |
| 26 | } |
| 27 | GenericArgument::Type(ty) => { |
| 28 | remove_lifetime(ty); |
| 29 | } |
| 30 | _ => {} |
| 31 | } |
| 32 | } |
| 33 | } |
| 34 | } |
| 35 | Type::Reference(rty) => { |
| 36 | rty.lifetime = None; |
| 37 | remove_lifetime(rty.elem.as_mut()); |
| 38 | } |
| 39 | Type::Tuple(ty) => { |
| 40 | for elem in &mut ty.elems { |
| 41 | remove_lifetime(elem); |
| 42 | } |
| 43 | } |
| 44 | Type::Array(ary) => { |
| 45 | remove_lifetime(ary.elem.as_mut()); |
| 46 | } |
| 47 | _ => {} |
| 48 | } |
| 49 | } |
| 50 | |
| 51 | /// Extract `T` from `PyResult<T>`. |
| 52 | /// |
| 53 | /// For `PyResult<&'a T>` case, `'a` will be removed, i.e. returns `&T` for this case. |
| 54 | pub fn escape_return_type(ret: &ReturnType) -> Option<Type> { |
| 55 | let ret: &Type = if let ReturnType::Type(_, ty: &Box) = ret { |
| 56 | unwrap_pyresult(ty) |
| 57 | } else { |
| 58 | return None; |
| 59 | }; |
| 60 | let mut ret: Type = ret.clone(); |
| 61 | remove_lifetime(&mut ret); |
| 62 | Some(ret) |
| 63 | } |
| 64 | |
| 65 | fn unwrap_pyresult(ty: &Type) -> &Type { |
| 66 | if let Type::Path(TypePath { path: &Path, .. }) = ty { |
| 67 | if let Some(last: &PathSegment) = path.segments.last() { |
| 68 | if last.ident == "PyResult" { |
| 69 | if let PathArguments::AngleBracketed(inner: &AngleBracketedGenericArguments) = &last.arguments { |
| 70 | for arg: &GenericArgument in &inner.args { |
| 71 | if let GenericArgument::Type(ty: &Type) = arg { |
| 72 | return ty; |
| 73 | } |
| 74 | } |
| 75 | } |
| 76 | } |
| 77 | } |
| 78 | } |
| 79 | ty |
| 80 | } |
| 81 | |
#[cfg(test)]
mod test {
    use super::*;
    use syn::{parse_str, Result};

    #[test]
    fn test_unwrap_pyresult() -> Result<()> {
        // (input, expected) pairs. Both bare and path-qualified `PyResult`
        // must be recognized, and lifetimes are *not* stripped at this stage.
        let cases: &[(&str, &str)] = &[
            ("PyResult<i32>", "i32"),
            ("PyResult<&PyString>", "&PyString"),
            ("PyResult<&'a PyString>", "&'a PyString"),
            ("::pyo3::PyResult<i32>", "i32"),
            ("::pyo3::PyResult<&PyString>", "&PyString"),
            ("::pyo3::PyResult<&'a PyString>", "&'a PyString"),
            // Anything that is not `PyResult<T>` passes through unchanged.
            ("Vec<i32>", "Vec<i32>"),
            ("&'a str", "&'a str"),
            ("PyResult", "PyResult"),
        ];
        for &(input, expected) in cases {
            let ty: Type = parse_str(input)?;
            let want: Type = parse_str(expected)?;
            assert_eq!(unwrap_pyresult(&ty), &want, "input: {}", input);
        }
        Ok(())
    }

    #[test]
    fn test_remove_lifetime() -> Result<()> {
        let cases: &[(&str, &str)] = &[
            ("&'a str", "&str"),
            ("&'static mut [u8; 4]", "&mut [u8; 4]"),
            ("Cow<'a, str>", "Cow<'_, str>"),
            ("Vec<&'a str>", "Vec<&str>"),
            ("(&'a str, [&'b u8; 4])", "(&str, [&u8; 4])"),
            ("i32", "i32"),
        ];
        for &(input, expected) in cases {
            let mut ty: Type = parse_str(input)?;
            remove_lifetime(&mut ty);
            let want: Type = parse_str(expected)?;
            assert_eq!(ty, want, "input: {}", input);
        }
        Ok(())
    }

    #[test]
    fn test_escape_return_type() -> Result<()> {
        let cases: &[(&str, &str)] = &[
            ("-> PyResult<&'a PyString>", "&PyString"),
            ("-> ::pyo3::PyResult<Vec<&'a str>>", "Vec<&str>"),
            ("-> &'a str", "&str"),
        ];
        for &(input, expected) in cases {
            let ret: ReturnType = parse_str(input)?;
            let want: Type = parse_str(expected)?;
            assert_eq!(escape_return_type(&ret), Some(want), "input: {}", input);
        }
        // A function without `-> ...` has nothing to extract.
        assert!(escape_return_type(&ReturnType::Default).is_none());
        Ok(())
    }
}
| 116 | |