| 1 | use proc_macro2::{Span, TokenStream}; |
| 2 | use quote::{format_ident, quote, ToTokens}; |
| 3 | use syn::{punctuated::Punctuated, spanned::Spanned, Data, DeriveInput, Error, Field}; |
| 4 | use zvariant_utils::{case, macros}; |
| 5 | |
| 6 | use crate::utils::*; |
| 7 | |
| 8 | fn dict_name_for_field( |
| 9 | f: &Field, |
| 10 | rename_attr: Option<String>, |
| 11 | rename_all_attr: Option<&str>, |
| 12 | ) -> Result<String, Error> { |
| 13 | if let Some(name: String) = rename_attr { |
| 14 | Ok(name) |
| 15 | } else { |
| 16 | let ident: String = f.ident.as_ref().unwrap().to_string(); |
| 17 | |
| 18 | match rename_all_attr { |
| 19 | Some("lowercase" ) => Ok(ident.to_ascii_lowercase()), |
| 20 | Some("UPPERCASE" ) => Ok(ident.to_ascii_uppercase()), |
| 21 | Some("PascalCase" ) => Ok(case::pascal_or_camel_case(&ident, is_pascal_case:true)), |
| 22 | Some("camelCase" ) => Ok(case::pascal_or_camel_case(&ident, is_pascal_case:false)), |
| 23 | Some("snake_case" ) => Ok(case::snake_or_kebab_case(&ident, is_snake_case:true)), |
| 24 | Some("kebab-case" ) => Ok(case::snake_or_kebab_case(&ident, is_snake_case:false)), |
| 25 | None => Ok(ident), |
| 26 | Some(other: &str) => Err(Error::new( |
| 27 | f.span(), |
| 28 | message:format!("invalid `rename_all` attribute value {other}" ), |
| 29 | )), |
| 30 | } |
| 31 | } |
| 32 | } |
| 33 | |
| 34 | pub fn expand_serialize_derive(input: DeriveInput) -> Result<TokenStream, Error> { |
| 35 | let (name, data) = match input.data { |
| 36 | Data::Struct(data) => (input.ident, data), |
| 37 | _ => return Err(Error::new(input.span(), "only structs supported" )), |
| 38 | }; |
| 39 | |
| 40 | let StructAttributes { rename_all, .. } = StructAttributes::parse(&input.attrs)?; |
| 41 | |
| 42 | let zv = zvariant_path(); |
| 43 | let mut entries = quote! {}; |
| 44 | let mut num_entries: usize = 0; |
| 45 | |
| 46 | for f in &data.fields { |
| 47 | let FieldAttributes { rename } = FieldAttributes::parse(&f.attrs)?; |
| 48 | |
| 49 | let name = &f.ident; |
| 50 | let dict_name = dict_name_for_field(f, rename, rename_all.as_deref())?; |
| 51 | |
| 52 | let is_option = macros::ty_is_option(&f.ty); |
| 53 | |
| 54 | let e = if is_option { |
| 55 | quote! { |
| 56 | if self.#name.is_some() { |
| 57 | map.serialize_entry(#dict_name, &#zv::SerializeValue(self.#name.as_ref().unwrap()))?; |
| 58 | } |
| 59 | } |
| 60 | } else { |
| 61 | quote! { |
| 62 | map.serialize_entry(#dict_name, &#zv::SerializeValue(&self.#name))?; |
| 63 | } |
| 64 | }; |
| 65 | |
| 66 | entries.extend(e); |
| 67 | num_entries += 1; |
| 68 | } |
| 69 | |
| 70 | let generics = input.generics; |
| 71 | let (impl_generics, ty_generics, where_clause) = generics.split_for_impl(); |
| 72 | |
| 73 | let num_entries = num_entries.to_token_stream(); |
| 74 | Ok(quote! { |
| 75 | #[allow(deprecated)] |
| 76 | impl #impl_generics #zv::export::serde::ser::Serialize for #name #ty_generics |
| 77 | #where_clause |
| 78 | { |
| 79 | fn serialize<S>(&self, serializer: S) -> ::std::result::Result<S::Ok, S::Error> |
| 80 | where |
| 81 | S: #zv::export::serde::ser::Serializer, |
| 82 | { |
| 83 | use #zv::export::serde::ser::SerializeMap; |
| 84 | |
| 85 | // zbus doesn't care about number of entries (it would need bytes instead) |
| 86 | let mut map = serializer.serialize_map(::std::option::Option::Some(#num_entries))?; |
| 87 | #entries |
| 88 | map.end() |
| 89 | } |
| 90 | } |
| 91 | }) |
| 92 | } |
| 93 | |
| 94 | pub fn expand_deserialize_derive(input: DeriveInput) -> Result<TokenStream, Error> { |
| 95 | let (name, data) = match input.data { |
| 96 | Data::Struct(data) => (input.ident, data), |
| 97 | _ => return Err(Error::new(input.span(), "only structs supported" )), |
| 98 | }; |
| 99 | |
| 100 | let StructAttributes { |
| 101 | rename_all, |
| 102 | deny_unknown_fields, |
| 103 | .. |
| 104 | } = StructAttributes::parse(&input.attrs)?; |
| 105 | |
| 106 | let visitor = format_ident!(" {}Visitor" , name); |
| 107 | let zv = zvariant_path(); |
| 108 | let mut fields = Vec::new(); |
| 109 | let mut req_fields = Vec::new(); |
| 110 | let mut dict_names = Vec::new(); |
| 111 | let mut entries = Vec::new(); |
| 112 | |
| 113 | for f in &data.fields { |
| 114 | let FieldAttributes { rename } = FieldAttributes::parse(&f.attrs)?; |
| 115 | |
| 116 | let name = &f.ident; |
| 117 | let dict_name = dict_name_for_field(f, rename, rename_all.as_deref())?; |
| 118 | |
| 119 | let is_option = macros::ty_is_option(&f.ty); |
| 120 | |
| 121 | entries.push(quote! { |
| 122 | #dict_name => { |
| 123 | // FIXME: add an option about strict parsing (instead of silently skipping the field) |
| 124 | #name = access.next_value::<#zv::DeserializeValue<_>>().map(|v| v.0).ok(); |
| 125 | } |
| 126 | }); |
| 127 | |
| 128 | dict_names.push(dict_name); |
| 129 | fields.push(name); |
| 130 | |
| 131 | if !is_option { |
| 132 | req_fields.push(name); |
| 133 | } |
| 134 | } |
| 135 | |
| 136 | let fallback = if deny_unknown_fields { |
| 137 | quote! { |
| 138 | field => { |
| 139 | return ::std::result::Result::Err( |
| 140 | <M::Error as #zv::export::serde::de::Error>::unknown_field( |
| 141 | field, |
| 142 | &[#(#dict_names),*], |
| 143 | ), |
| 144 | ); |
| 145 | } |
| 146 | } |
| 147 | } else { |
| 148 | quote! { |
| 149 | unknown => { |
| 150 | let _ = access.next_value::<#zv::Value>(); |
| 151 | } |
| 152 | } |
| 153 | }; |
| 154 | entries.push(fallback); |
| 155 | |
| 156 | let (_, ty_generics, _) = input.generics.split_for_impl(); |
| 157 | let mut generics = input.generics.clone(); |
| 158 | let def = syn::LifetimeParam { |
| 159 | attrs: Vec::new(), |
| 160 | lifetime: syn::Lifetime::new("'de" , Span::call_site()), |
| 161 | colon_token: None, |
| 162 | bounds: Punctuated::new(), |
| 163 | }; |
| 164 | generics.params = Some(syn::GenericParam::Lifetime(def)) |
| 165 | .into_iter() |
| 166 | .chain(generics.params) |
| 167 | .collect(); |
| 168 | |
| 169 | let (impl_generics, _, where_clause) = generics.split_for_impl(); |
| 170 | |
| 171 | Ok(quote! { |
| 172 | #[allow(deprecated)] |
| 173 | impl #impl_generics #zv::export::serde::de::Deserialize<'de> for #name #ty_generics |
| 174 | #where_clause |
| 175 | { |
| 176 | fn deserialize<D>(deserializer: D) -> ::std::result::Result<Self, D::Error> |
| 177 | where |
| 178 | D: #zv::export::serde::de::Deserializer<'de>, |
| 179 | { |
| 180 | struct #visitor #ty_generics(::std::marker::PhantomData<#name #ty_generics>); |
| 181 | |
| 182 | impl #impl_generics #zv::export::serde::de::Visitor<'de> for #visitor #ty_generics { |
| 183 | type Value = #name #ty_generics; |
| 184 | |
| 185 | fn expecting(&self, formatter: &mut ::std::fmt::Formatter) -> ::std::fmt::Result { |
| 186 | formatter.write_str("a dictionary" ) |
| 187 | } |
| 188 | |
| 189 | fn visit_map<M>( |
| 190 | self, |
| 191 | mut access: M, |
| 192 | ) -> ::std::result::Result<Self::Value, M::Error> |
| 193 | where |
| 194 | M: #zv::export::serde::de::MapAccess<'de>, |
| 195 | { |
| 196 | #( let mut #fields = ::std::default::Default::default(); )* |
| 197 | |
| 198 | // does not check duplicated fields, since those shouldn't exist in stream |
| 199 | while let ::std::option::Option::Some(key) = access.next_key::<&str>()? { |
| 200 | match key { |
| 201 | #(#entries)* |
| 202 | } |
| 203 | } |
| 204 | |
| 205 | #(let #req_fields = if let ::std::option::Option::Some(val) = #req_fields { |
| 206 | val |
| 207 | } else { |
| 208 | return ::std::result::Result::Err( |
| 209 | <M::Error as #zv::export::serde::de::Error>::missing_field( |
| 210 | ::std::stringify!(#req_fields), |
| 211 | ), |
| 212 | ); |
| 213 | };)* |
| 214 | |
| 215 | ::std::result::Result::Ok(#name { #(#fields),* }) |
| 216 | } |
| 217 | } |
| 218 | |
| 219 | |
| 220 | deserializer.deserialize_map(#visitor(::std::marker::PhantomData)) |
| 221 | } |
| 222 | } |
| 223 | }) |
| 224 | } |
| 225 | |