1 | use proc_macro2::{Span, TokenStream}; |
2 | use quote::{format_ident, quote, ToTokens}; |
3 | use syn::{punctuated::Punctuated, spanned::Spanned, Data, DeriveInput, Error, Field}; |
4 | use zvariant_utils::{case, macros}; |
5 | |
6 | use crate::utils::*; |
7 | |
8 | fn dict_name_for_field( |
9 | f: &Field, |
10 | rename_attr: Option<String>, |
11 | rename_all_attr: Option<&str>, |
12 | ) -> Result<String, Error> { |
13 | if let Some(name: String) = rename_attr { |
14 | Ok(name) |
15 | } else { |
16 | let ident: String = f.ident.as_ref().unwrap().to_string(); |
17 | |
18 | match rename_all_attr { |
19 | Some("lowercase" ) => Ok(ident.to_ascii_lowercase()), |
20 | Some("UPPERCASE" ) => Ok(ident.to_ascii_uppercase()), |
21 | Some("PascalCase" ) => Ok(case::pascal_or_camel_case(&ident, is_pascal_case:true)), |
22 | Some("camelCase" ) => Ok(case::pascal_or_camel_case(&ident, is_pascal_case:false)), |
23 | Some("snake_case" ) => Ok(case::snake_or_kebab_case(&ident, is_snake_case:true)), |
24 | Some("kebab-case" ) => Ok(case::snake_or_kebab_case(&ident, is_snake_case:false)), |
25 | None => Ok(ident), |
26 | Some(other: &str) => Err(Error::new( |
27 | f.span(), |
28 | message:format!("invalid `rename_all` attribute value {other}" ), |
29 | )), |
30 | } |
31 | } |
32 | } |
33 | |
34 | pub fn expand_serialize_derive(input: DeriveInput) -> Result<TokenStream, Error> { |
35 | let (name, data) = match input.data { |
36 | Data::Struct(data) => (input.ident, data), |
37 | _ => return Err(Error::new(input.span(), "only structs supported" )), |
38 | }; |
39 | |
40 | let StructAttributes { rename_all, .. } = StructAttributes::parse(&input.attrs)?; |
41 | |
42 | let zv = zvariant_path(); |
43 | let mut entries = quote! {}; |
44 | let mut num_entries: usize = 0; |
45 | |
46 | for f in &data.fields { |
47 | let FieldAttributes { rename } = FieldAttributes::parse(&f.attrs)?; |
48 | |
49 | let name = &f.ident; |
50 | let dict_name = dict_name_for_field(f, rename, rename_all.as_deref())?; |
51 | |
52 | let is_option = macros::ty_is_option(&f.ty); |
53 | |
54 | let e = if is_option { |
55 | quote! { |
56 | if self.#name.is_some() { |
57 | map.serialize_entry(#dict_name, &#zv::SerializeValue(self.#name.as_ref().unwrap()))?; |
58 | } |
59 | } |
60 | } else { |
61 | quote! { |
62 | map.serialize_entry(#dict_name, &#zv::SerializeValue(&self.#name))?; |
63 | } |
64 | }; |
65 | |
66 | entries.extend(e); |
67 | num_entries += 1; |
68 | } |
69 | |
70 | let generics = input.generics; |
71 | let (impl_generics, ty_generics, where_clause) = generics.split_for_impl(); |
72 | |
73 | let num_entries = num_entries.to_token_stream(); |
74 | Ok(quote! { |
75 | #[allow(deprecated)] |
76 | impl #impl_generics #zv::export::serde::ser::Serialize for #name #ty_generics |
77 | #where_clause |
78 | { |
79 | fn serialize<S>(&self, serializer: S) -> ::std::result::Result<S::Ok, S::Error> |
80 | where |
81 | S: #zv::export::serde::ser::Serializer, |
82 | { |
83 | use #zv::export::serde::ser::SerializeMap; |
84 | |
85 | // zbus doesn't care about number of entries (it would need bytes instead) |
86 | let mut map = serializer.serialize_map(::std::option::Option::Some(#num_entries))?; |
87 | #entries |
88 | map.end() |
89 | } |
90 | } |
91 | }) |
92 | } |
93 | |
94 | pub fn expand_deserialize_derive(input: DeriveInput) -> Result<TokenStream, Error> { |
95 | let (name, data) = match input.data { |
96 | Data::Struct(data) => (input.ident, data), |
97 | _ => return Err(Error::new(input.span(), "only structs supported" )), |
98 | }; |
99 | |
100 | let StructAttributes { |
101 | rename_all, |
102 | deny_unknown_fields, |
103 | .. |
104 | } = StructAttributes::parse(&input.attrs)?; |
105 | |
106 | let visitor = format_ident!(" {}Visitor" , name); |
107 | let zv = zvariant_path(); |
108 | let mut fields = Vec::new(); |
109 | let mut req_fields = Vec::new(); |
110 | let mut dict_names = Vec::new(); |
111 | let mut entries = Vec::new(); |
112 | |
113 | for f in &data.fields { |
114 | let FieldAttributes { rename } = FieldAttributes::parse(&f.attrs)?; |
115 | |
116 | let name = &f.ident; |
117 | let dict_name = dict_name_for_field(f, rename, rename_all.as_deref())?; |
118 | |
119 | let is_option = macros::ty_is_option(&f.ty); |
120 | |
121 | entries.push(quote! { |
122 | #dict_name => { |
123 | // FIXME: add an option about strict parsing (instead of silently skipping the field) |
124 | #name = access.next_value::<#zv::DeserializeValue<_>>().map(|v| v.0).ok(); |
125 | } |
126 | }); |
127 | |
128 | dict_names.push(dict_name); |
129 | fields.push(name); |
130 | |
131 | if !is_option { |
132 | req_fields.push(name); |
133 | } |
134 | } |
135 | |
136 | let fallback = if deny_unknown_fields { |
137 | quote! { |
138 | field => { |
139 | return ::std::result::Result::Err( |
140 | <M::Error as #zv::export::serde::de::Error>::unknown_field( |
141 | field, |
142 | &[#(#dict_names),*], |
143 | ), |
144 | ); |
145 | } |
146 | } |
147 | } else { |
148 | quote! { |
149 | unknown => { |
150 | let _ = access.next_value::<#zv::Value>(); |
151 | } |
152 | } |
153 | }; |
154 | entries.push(fallback); |
155 | |
156 | let (_, ty_generics, _) = input.generics.split_for_impl(); |
157 | let mut generics = input.generics.clone(); |
158 | let def = syn::LifetimeParam { |
159 | attrs: Vec::new(), |
160 | lifetime: syn::Lifetime::new("'de" , Span::call_site()), |
161 | colon_token: None, |
162 | bounds: Punctuated::new(), |
163 | }; |
164 | generics.params = Some(syn::GenericParam::Lifetime(def)) |
165 | .into_iter() |
166 | .chain(generics.params) |
167 | .collect(); |
168 | |
169 | let (impl_generics, _, where_clause) = generics.split_for_impl(); |
170 | |
171 | Ok(quote! { |
172 | #[allow(deprecated)] |
173 | impl #impl_generics #zv::export::serde::de::Deserialize<'de> for #name #ty_generics |
174 | #where_clause |
175 | { |
176 | fn deserialize<D>(deserializer: D) -> ::std::result::Result<Self, D::Error> |
177 | where |
178 | D: #zv::export::serde::de::Deserializer<'de>, |
179 | { |
180 | struct #visitor #ty_generics(::std::marker::PhantomData<#name #ty_generics>); |
181 | |
182 | impl #impl_generics #zv::export::serde::de::Visitor<'de> for #visitor #ty_generics { |
183 | type Value = #name #ty_generics; |
184 | |
185 | fn expecting(&self, formatter: &mut ::std::fmt::Formatter) -> ::std::fmt::Result { |
186 | formatter.write_str("a dictionary" ) |
187 | } |
188 | |
189 | fn visit_map<M>( |
190 | self, |
191 | mut access: M, |
192 | ) -> ::std::result::Result<Self::Value, M::Error> |
193 | where |
194 | M: #zv::export::serde::de::MapAccess<'de>, |
195 | { |
196 | #( let mut #fields = ::std::default::Default::default(); )* |
197 | |
198 | // does not check duplicated fields, since those shouldn't exist in stream |
199 | while let ::std::option::Option::Some(key) = access.next_key::<&str>()? { |
200 | match key { |
201 | #(#entries)* |
202 | } |
203 | } |
204 | |
205 | #(let #req_fields = if let ::std::option::Option::Some(val) = #req_fields { |
206 | val |
207 | } else { |
208 | return ::std::result::Result::Err( |
209 | <M::Error as #zv::export::serde::de::Error>::missing_field( |
210 | ::std::stringify!(#req_fields), |
211 | ), |
212 | ); |
213 | };)* |
214 | |
215 | ::std::result::Result::Ok(#name { #(#fields),* }) |
216 | } |
217 | } |
218 | |
219 | |
220 | deserializer.deserialize_map(#visitor(::std::marker::PhantomData)) |
221 | } |
222 | } |
223 | }) |
224 | } |
225 | |