1use proc_macro2::{Span, TokenStream};
2use quote::{format_ident, quote, ToTokens};
3use syn::{punctuated::Punctuated, spanned::Spanned, Data, DeriveInput, Error, Field};
4use zvariant_utils::{case, macros};
5
6use crate::utils::*;
7
8fn dict_name_for_field(
9 f: &Field,
10 rename_attr: Option<String>,
11 rename_all_attr: Option<&str>,
12) -> Result<String, Error> {
13 if let Some(name: String) = rename_attr {
14 Ok(name)
15 } else {
16 let ident: String = f.ident.as_ref().unwrap().to_string();
17
18 match rename_all_attr {
19 Some("lowercase") => Ok(ident.to_ascii_lowercase()),
20 Some("UPPERCASE") => Ok(ident.to_ascii_uppercase()),
21 Some("PascalCase") => Ok(case::pascal_or_camel_case(&ident, is_pascal_case:true)),
22 Some("camelCase") => Ok(case::pascal_or_camel_case(&ident, is_pascal_case:false)),
23 Some("snake_case") => Ok(case::snake_or_kebab_case(&ident, is_snake_case:true)),
24 Some("kebab-case") => Ok(case::snake_or_kebab_case(&ident, is_snake_case:false)),
25 None => Ok(ident),
26 Some(other: &str) => Err(Error::new(
27 f.span(),
28 message:format!("invalid `rename_all` attribute value {other}"),
29 )),
30 }
31 }
32}
33
34pub fn expand_serialize_derive(input: DeriveInput) -> Result<TokenStream, Error> {
35 let (name, data) = match input.data {
36 Data::Struct(data) => (input.ident, data),
37 _ => return Err(Error::new(input.span(), "only structs supported")),
38 };
39
40 let StructAttributes { rename_all, .. } = StructAttributes::parse(&input.attrs)?;
41
42 let zv = zvariant_path();
43 let mut entries = quote! {};
44 let mut num_entries: usize = 0;
45
46 for f in &data.fields {
47 let FieldAttributes { rename } = FieldAttributes::parse(&f.attrs)?;
48
49 let name = &f.ident;
50 let dict_name = dict_name_for_field(f, rename, rename_all.as_deref())?;
51
52 let is_option = macros::ty_is_option(&f.ty);
53
54 let e = if is_option {
55 quote! {
56 if self.#name.is_some() {
57 map.serialize_entry(#dict_name, &#zv::SerializeValue(self.#name.as_ref().unwrap()))?;
58 }
59 }
60 } else {
61 quote! {
62 map.serialize_entry(#dict_name, &#zv::SerializeValue(&self.#name))?;
63 }
64 };
65
66 entries.extend(e);
67 num_entries += 1;
68 }
69
70 let generics = input.generics;
71 let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
72
73 let num_entries = num_entries.to_token_stream();
74 Ok(quote! {
75 #[allow(deprecated)]
76 impl #impl_generics #zv::export::serde::ser::Serialize for #name #ty_generics
77 #where_clause
78 {
79 fn serialize<S>(&self, serializer: S) -> ::std::result::Result<S::Ok, S::Error>
80 where
81 S: #zv::export::serde::ser::Serializer,
82 {
83 use #zv::export::serde::ser::SerializeMap;
84
85 // zbus doesn't care about number of entries (it would need bytes instead)
86 let mut map = serializer.serialize_map(::std::option::Option::Some(#num_entries))?;
87 #entries
88 map.end()
89 }
90 }
91 })
92}
93
94pub fn expand_deserialize_derive(input: DeriveInput) -> Result<TokenStream, Error> {
95 let (name, data) = match input.data {
96 Data::Struct(data) => (input.ident, data),
97 _ => return Err(Error::new(input.span(), "only structs supported")),
98 };
99
100 let StructAttributes {
101 rename_all,
102 deny_unknown_fields,
103 ..
104 } = StructAttributes::parse(&input.attrs)?;
105
106 let visitor = format_ident!("{}Visitor", name);
107 let zv = zvariant_path();
108 let mut fields = Vec::new();
109 let mut req_fields = Vec::new();
110 let mut dict_names = Vec::new();
111 let mut entries = Vec::new();
112
113 for f in &data.fields {
114 let FieldAttributes { rename } = FieldAttributes::parse(&f.attrs)?;
115
116 let name = &f.ident;
117 let dict_name = dict_name_for_field(f, rename, rename_all.as_deref())?;
118
119 let is_option = macros::ty_is_option(&f.ty);
120
121 entries.push(quote! {
122 #dict_name => {
123 // FIXME: add an option about strict parsing (instead of silently skipping the field)
124 #name = access.next_value::<#zv::DeserializeValue<_>>().map(|v| v.0).ok();
125 }
126 });
127
128 dict_names.push(dict_name);
129 fields.push(name);
130
131 if !is_option {
132 req_fields.push(name);
133 }
134 }
135
136 let fallback = if deny_unknown_fields {
137 quote! {
138 field => {
139 return ::std::result::Result::Err(
140 <M::Error as #zv::export::serde::de::Error>::unknown_field(
141 field,
142 &[#(#dict_names),*],
143 ),
144 );
145 }
146 }
147 } else {
148 quote! {
149 unknown => {
150 let _ = access.next_value::<#zv::Value>();
151 }
152 }
153 };
154 entries.push(fallback);
155
156 let (_, ty_generics, _) = input.generics.split_for_impl();
157 let mut generics = input.generics.clone();
158 let def = syn::LifetimeParam {
159 attrs: Vec::new(),
160 lifetime: syn::Lifetime::new("'de", Span::call_site()),
161 colon_token: None,
162 bounds: Punctuated::new(),
163 };
164 generics.params = Some(syn::GenericParam::Lifetime(def))
165 .into_iter()
166 .chain(generics.params)
167 .collect();
168
169 let (impl_generics, _, where_clause) = generics.split_for_impl();
170
171 Ok(quote! {
172 #[allow(deprecated)]
173 impl #impl_generics #zv::export::serde::de::Deserialize<'de> for #name #ty_generics
174 #where_clause
175 {
176 fn deserialize<D>(deserializer: D) -> ::std::result::Result<Self, D::Error>
177 where
178 D: #zv::export::serde::de::Deserializer<'de>,
179 {
180 struct #visitor #ty_generics(::std::marker::PhantomData<#name #ty_generics>);
181
182 impl #impl_generics #zv::export::serde::de::Visitor<'de> for #visitor #ty_generics {
183 type Value = #name #ty_generics;
184
185 fn expecting(&self, formatter: &mut ::std::fmt::Formatter) -> ::std::fmt::Result {
186 formatter.write_str("a dictionary")
187 }
188
189 fn visit_map<M>(
190 self,
191 mut access: M,
192 ) -> ::std::result::Result<Self::Value, M::Error>
193 where
194 M: #zv::export::serde::de::MapAccess<'de>,
195 {
196 #( let mut #fields = ::std::default::Default::default(); )*
197
198 // does not check duplicated fields, since those shouldn't exist in stream
199 while let ::std::option::Option::Some(key) = access.next_key::<&str>()? {
200 match key {
201 #(#entries)*
202 }
203 }
204
205 #(let #req_fields = if let ::std::option::Option::Some(val) = #req_fields {
206 val
207 } else {
208 return ::std::result::Result::Err(
209 <M::Error as #zv::export::serde::de::Error>::missing_field(
210 ::std::stringify!(#req_fields),
211 ),
212 );
213 };)*
214
215 ::std::result::Result::Ok(#name { #(#fields),* })
216 }
217 }
218
219
220 deserializer.deserialize_map(#visitor(::std::marker::PhantomData))
221 }
222 }
223 })
224}
225