1use crate::internals::ast::{Container, Data};
2use crate::internals::{attr, ungroup};
3use proc_macro2::Span;
4use std::collections::HashSet;
5use syn::punctuated::{Pair, Punctuated};
6use syn::Token;
7
8// Remove the default from every type parameter because in the generated impls
9// they look like associated types: "error: associated type bindings are not
10// allowed here".
11pub fn without_defaults(generics: &syn::Generics) -> syn::Generics {
12 syn::Generics {
13 params: genericsimpl Iterator
14 .params
15 .iter()
16 .map(|param: &GenericParam| match param {
17 syn::GenericParam::Type(param: &TypeParam) => syn::GenericParam::Type(syn::TypeParam {
18 eq_token: None,
19 default: None,
20 ..param.clone()
21 }),
22 _ => param.clone(),
23 })
24 .collect(),
25 ..generics.clone()
26 }
27}
28
29pub fn with_where_predicates(
30 generics: &syn::Generics,
31 predicates: &[syn::WherePredicate],
32) -> syn::Generics {
33 let mut generics: Generics = generics.clone();
34 generics
35 .make_where_clause()
36 .predicates
37 .extend(iter:predicates.iter().cloned());
38 generics
39}
40
41pub fn with_where_predicates_from_fields(
42 cont: &Container,
43 generics: &syn::Generics,
44 from_field: fn(&attr::Field) -> Option<&[syn::WherePredicate]>,
45) -> syn::Generics {
46 let predicates: impl Iterator = contimpl Iterator
47 .data
48 .all_fields()
49 .filter_map(|field: &Field<'_>| from_field(&field.attrs))
50 .flat_map(<[syn::WherePredicate]>::to_vec);
51
52 let mut generics: Generics = generics.clone();
53 generics.make_where_clause().predicates.extend(iter:predicates);
54 generics
55}
56
57pub fn with_where_predicates_from_variants(
58 cont: &Container,
59 generics: &syn::Generics,
60 from_variant: fn(&attr::Variant) -> Option<&[syn::WherePredicate]>,
61) -> syn::Generics {
62 let variants: &Vec> = match &cont.data {
63 Data::Enum(variants: &Vec>) => variants,
64 Data::Struct(_, _) => {
65 return generics.clone();
66 }
67 };
68
69 let predicates: impl Iterator = variantsimpl Iterator
70 .iter()
71 .filter_map(|variant: &Variant<'_>| from_variant(&variant.attrs))
72 .flat_map(<[syn::WherePredicate]>::to_vec);
73
74 let mut generics: Generics = generics.clone();
75 generics.make_where_clause().predicates.extend(iter:predicates);
76 generics
77}
78
// Puts the given bound on any generic type parameters that are used in fields
// for which filter returns true.
//
// For example, the following struct needs the bound `A: Serialize, B:
// Serialize`.
//
//     struct S<'b, A, B: 'b, C> {
//         a: A,
//         b: Option<&'b B>,
//         #[serde(skip_serializing)]
//         c: C,
//     }
//
// In addition, a selected field whose type is a path that starts with a type
// parameter followed by `::` (an associated type such as `T::Item`) gets the
// same bound on that whole path, e.g. `T::Item: Serialize`.
//
// Only type parameters are considered; lifetime and const parameters are never
// bounded here. `filter` receives the field's attributes plus the enclosing
// variant's attributes (or `None` for struct fields). The new predicates are
// appended to the `where` clause of a clone of `generics` (created if absent).
// Parameter predicates come first, in declaration order, followed by the
// associated-type predicates in field order.
pub fn with_bound(
    cont: &Container,
    generics: &syn::Generics,
    filter: fn(&attr::Field, Option<&attr::Variant>) -> bool,
    bound: &syn::Path,
) -> syn::Generics {
    struct FindTyParams<'ast> {
        // Set of all generic type parameters on the current struct (A, B, C in
        // the example). Initialized up front.
        all_type_params: HashSet<syn::Ident>,

        // Set of generic type parameters used in fields for which filter
        // returns true (A and B in the example). Filled in as the visitor sees
        // them.
        relevant_type_params: HashSet<syn::Ident>,

        // Fields whose type is an associated type of one of the generic type
        // parameters.
        associated_type_usage: Vec<&'ast syn::TypePath>,
    }

    impl<'ast> FindTyParams<'ast> {
        // Entry point for one selected field. First record the field's type
        // as an associated-type usage if it is a path whose first segment is
        // a type parameter followed by `::` (e.g. `T::Assoc`). Then walk the
        // whole type looking for bare type parameters.
        fn visit_field(&mut self, field: &'ast syn::Field) {
            if let syn::Type::Path(ty) = ungroup(&field.ty) {
                if let Some(Pair::Punctuated(t, _)) = ty.path.segments.pairs().next() {
                    if self.all_type_params.contains(&t.ident) {
                        self.associated_type_usage.push(ty);
                    }
                }
            }
            self.visit_type(&field.ty);
        }

        // A path consisting of exactly one segment with no leading `::`
        // (e.g. `T`) marks that type parameter as relevant. Multi-segment
        // paths such as `T::Assoc` do not mark `T` itself. The generic
        // arguments of every segment are then visited (e.g. the `T` in `Vec<T>`).
        fn visit_path(&mut self, path: &'ast syn::Path) {
            if let Some(seg) = path.segments.last() {
                if seg.ident == "PhantomData" {
                    // Hardcoded exception, because PhantomData<T> implements
                    // Serialize and Deserialize whether or not T implements it.
                    // NOTE(review): matched by name only, so any type whose
                    // last path segment is `PhantomData` is skipped.
                    return;
                }
            }
            if path.leading_colon.is_none() && path.segments.len() == 1 {
                let id = &path.segments[0].ident;
                if self.all_type_params.contains(id) {
                    self.relevant_type_params.insert(id.clone());
                }
            }
            for segment in &path.segments {
                self.visit_path_segment(segment);
            }
        }

        // Everything below is simply traversing the syntax tree.

        fn visit_type(&mut self, ty: &'ast syn::Type) {
            match ty {
                // Under the `exhaustive` test cfg, forbid silently skipping
                // syn::Type variants via the `_` arm.
                #![cfg_attr(all(test, exhaustive), deny(non_exhaustive_omitted_patterns))]
                syn::Type::Array(ty) => self.visit_type(&ty.elem),
                syn::Type::BareFn(ty) => {
                    for arg in &ty.inputs {
                        self.visit_type(&arg.ty);
                    }
                    self.visit_return_type(&ty.output);
                }
                syn::Type::Group(ty) => self.visit_type(&ty.elem),
                syn::Type::ImplTrait(ty) => {
                    for bound in &ty.bounds {
                        self.visit_type_param_bound(bound);
                    }
                }
                syn::Type::Macro(ty) => self.visit_macro(&ty.mac),
                syn::Type::Paren(ty) => self.visit_type(&ty.elem),
                syn::Type::Path(ty) => {
                    // In `<Q as Trait>::X`, the self type `Q` is visited too.
                    if let Some(qself) = &ty.qself {
                        self.visit_type(&qself.ty);
                    }
                    self.visit_path(&ty.path);
                }
                syn::Type::Ptr(ty) => self.visit_type(&ty.elem),
                syn::Type::Reference(ty) => self.visit_type(&ty.elem),
                syn::Type::Slice(ty) => self.visit_type(&ty.elem),
                syn::Type::TraitObject(ty) => {
                    for bound in &ty.bounds {
                        self.visit_type_param_bound(bound);
                    }
                }
                syn::Type::Tuple(ty) => {
                    for elem in &ty.elems {
                        self.visit_type(elem);
                    }
                }

                syn::Type::Infer(_) | syn::Type::Never(_) | syn::Type::Verbatim(_) => {}

                _ => {}
            }
        }

        fn visit_path_segment(&mut self, segment: &'ast syn::PathSegment) {
            self.visit_path_arguments(&segment.arguments);
        }

        // Only type arguments (including the type side of `Item = T`) are
        // searched; lifetimes, const arguments and trait constraints are ignored.
        fn visit_path_arguments(&mut self, arguments: &'ast syn::PathArguments) {
            match arguments {
                syn::PathArguments::None => {}
                syn::PathArguments::AngleBracketed(arguments) => {
                    for arg in &arguments.args {
                        match arg {
                            #![cfg_attr(all(test, exhaustive), deny(non_exhaustive_omitted_patterns))]
                            syn::GenericArgument::Type(arg) => self.visit_type(arg),
                            syn::GenericArgument::AssocType(arg) => self.visit_type(&arg.ty),
                            syn::GenericArgument::Lifetime(_)
                            | syn::GenericArgument::Const(_)
                            | syn::GenericArgument::AssocConst(_)
                            | syn::GenericArgument::Constraint(_) => {}
                            _ => {}
                        }
                    }
                }
                // `Fn(A, B) -> C` style arguments.
                syn::PathArguments::Parenthesized(arguments) => {
                    for argument in &arguments.inputs {
                        self.visit_type(argument);
                    }
                    self.visit_return_type(&arguments.output);
                }
            }
        }

        fn visit_return_type(&mut self, return_type: &'ast syn::ReturnType) {
            match return_type {
                syn::ReturnType::Default => {}
                syn::ReturnType::Type(_, output) => self.visit_type(output),
            }
        }

        fn visit_type_param_bound(&mut self, bound: &'ast syn::TypeParamBound) {
            match bound {
                #![cfg_attr(all(test, exhaustive), deny(non_exhaustive_omitted_patterns))]
                syn::TypeParamBound::Trait(bound) => self.visit_path(&bound.path),
                syn::TypeParamBound::Lifetime(_) | syn::TypeParamBound::Verbatim(_) => {}
                _ => {}
            }
        }

        // Type parameter should not be considered used by a macro path.
        //
        //     struct TypeMacro<T> {
        //         mac: T!(),
        //         marker: PhantomData<T>,
        //     }
        fn visit_macro(&mut self, _mac: &'ast syn::Macro) {}
    }

    let all_type_params = generics
        .type_params()
        .map(|param| param.ident.clone())
        .collect();

    let mut visitor = FindTyParams {
        all_type_params,
        relevant_type_params: HashSet::new(),
        associated_type_usage: Vec::new(),
    };
    // Visit only the fields selected by `filter`; enum fields also see their
    // variant's attributes.
    match &cont.data {
        Data::Enum(variants) => {
            for variant in variants {
                let relevant_fields = variant
                    .fields
                    .iter()
                    .filter(|field| filter(&field.attrs, Some(&variant.attrs)));
                for field in relevant_fields {
                    visitor.visit_field(field.original);
                }
            }
        }
        Data::Struct(_, fields) => {
            for field in fields.iter().filter(|field| filter(&field.attrs, None)) {
                visitor.visit_field(field.original);
            }
        }
    }

    let relevant_type_params = visitor.relevant_type_params;
    let associated_type_usage = visitor.associated_type_usage;
    // Iterating `generics.type_params()` (not the HashSet) keeps the emitted
    // predicates in declaration order.
    let new_predicates = generics
        .type_params()
        .map(|param| param.ident.clone())
        .filter(|id| relevant_type_params.contains(id))
        .map(|id| syn::TypePath {
            qself: None,
            path: id.into(),
        })
        .chain(associated_type_usage.into_iter().cloned())
        .map(|bounded_ty| {
            syn::WherePredicate::Type(syn::PredicateType {
                lifetimes: None,
                // the type being bounded, e.g. `T` or `T::Assoc`
                bounded_ty: syn::Type::Path(bounded_ty),
                colon_token: <Token![:]>::default(),
                // the bound e.g. Serialize
                bounds: vec![syn::TypeParamBound::Trait(syn::TraitBound {
                    paren_token: None,
                    modifier: syn::TraitBoundModifier::None,
                    lifetimes: None,
                    path: bound.clone(),
                })]
                .into_iter()
                .collect(),
            })
        });

    let mut generics = generics.clone();
    generics
        .make_where_clause()
        .predicates
        .extend(new_predicates);
    generics
}
309
310pub fn with_self_bound(
311 cont: &Container,
312 generics: &syn::Generics,
313 bound: &syn::Path,
314) -> syn::Generics {
315 let mut generics: Generics = generics.clone();
316 genericsPunctuated
317 .make_where_clause()
318 .predicates
319 .push(syn::WherePredicate::Type(syn::PredicateType {
320 lifetimes: None,
321 // the type that is being bounded e.g. MyStruct<'a, T>
322 bounded_ty: type_of_item(cont),
323 colon_token: <Token![:]>::default(),
324 // the bound e.g. Default
325 bounds: vecIntoIter![syn::TypeParamBound::Trait(syn::TraitBound {
326 paren_token: None,
327 modifier: syn::TraitBoundModifier::None,
328 lifetimes: None,
329 path: bound.clone(),
330 })]
331 .into_iter()
332 .collect(),
333 }));
334 generics
335}
336
337pub fn with_lifetime_bound(generics: &syn::Generics, lifetime: &str) -> syn::Generics {
338 let bound = syn::Lifetime::new(lifetime, Span::call_site());
339 let def = syn::LifetimeParam {
340 attrs: Vec::new(),
341 lifetime: bound.clone(),
342 colon_token: None,
343 bounds: Punctuated::new(),
344 };
345
346 let params = Some(syn::GenericParam::Lifetime(def))
347 .into_iter()
348 .chain(generics.params.iter().cloned().map(|mut param| {
349 match &mut param {
350 syn::GenericParam::Lifetime(param) => {
351 param.bounds.push(bound.clone());
352 }
353 syn::GenericParam::Type(param) => {
354 param
355 .bounds
356 .push(syn::TypeParamBound::Lifetime(bound.clone()));
357 }
358 syn::GenericParam::Const(_) => {}
359 }
360 param
361 }))
362 .collect();
363
364 syn::Generics {
365 params,
366 ..generics.clone()
367 }
368}
369
370fn type_of_item(cont: &Container) -> syn::Type {
371 syn::Type::Path(syn::TypePath {
372 qself: None,
373 path: syn::Path {
374 leading_colon: None,
375 segments: vec![syn::PathSegment {
376 ident: cont.ident.clone(),
377 arguments: syn::PathArguments::AngleBracketed(
378 syn::AngleBracketedGenericArguments {
379 colon2_token: None,
380 lt_token: <Token![<]>::default(),
381 args: cont
382 .generics
383 .params
384 .iter()
385 .map(|param| match param {
386 syn::GenericParam::Type(param) => {
387 syn::GenericArgument::Type(syn::Type::Path(syn::TypePath {
388 qself: None,
389 path: param.ident.clone().into(),
390 }))
391 }
392 syn::GenericParam::Lifetime(param) => {
393 syn::GenericArgument::Lifetime(param.lifetime.clone())
394 }
395 syn::GenericParam::Const(_) => {
396 panic!("Serde does not support const generics yet");
397 }
398 })
399 .collect(),
400 gt_token: <Token![>]>::default(),
401 },
402 ),
403 }]
404 .into_iter()
405 .collect(),
406 },
407 })
408}
409