1use crate::internals::ast::{Container, Data};
2use crate::internals::{attr, ungroup};
3use proc_macro2::Span;
4use std::collections::HashSet;
5use syn::punctuated::{Pair, Punctuated};
6use syn::Token;
7
8// Remove the default from every type parameter because in the generated impls
9// they look like associated types: "error: associated type bindings are not
10// allowed here".
11pub fn without_defaults(generics: &syn::Generics) -> syn::Generics {
12 syn::Generics {
13 params: generics
14 .params
15 .iter()
16 .map(|param| match param {
17 syn::GenericParam::Type(param) => syn::GenericParam::Type(syn::TypeParam {
18 eq_token: None,
19 default: None,
20 ..param.clone()
21 }),
22 _ => param.clone(),
23 })
24 .collect(),
25 ..generics.clone()
26 }
27}
28
29pub fn with_where_predicates(
30 generics: &syn::Generics,
31 predicates: &[syn::WherePredicate],
32) -> syn::Generics {
33 let mut generics = generics.clone();
34 generics
35 .make_where_clause()
36 .predicates
37 .extend(predicates.iter().cloned());
38 generics
39}
40
41pub fn with_where_predicates_from_fields(
42 cont: &Container,
43 generics: &syn::Generics,
44 from_field: fn(&attr::Field) -> Option<&[syn::WherePredicate]>,
45) -> syn::Generics {
46 let predicates = cont
47 .data
48 .all_fields()
49 .filter_map(|field| from_field(&field.attrs))
50 .flat_map(<[syn::WherePredicate]>::to_vec);
51
52 let mut generics = generics.clone();
53 generics.make_where_clause().predicates.extend(predicates);
54 generics
55}
56
57pub fn with_where_predicates_from_variants(
58 cont: &Container,
59 generics: &syn::Generics,
60 from_variant: fn(&attr::Variant) -> Option<&[syn::WherePredicate]>,
61) -> syn::Generics {
62 let variants = match &cont.data {
63 Data::Enum(variants) => variants,
64 Data::Struct(_, _) => {
65 return generics.clone();
66 }
67 };
68
69 let predicates = variants
70 .iter()
71 .filter_map(|variant| from_variant(&variant.attrs))
72 .flat_map(<[syn::WherePredicate]>::to_vec);
73
74 let mut generics = generics.clone();
75 generics.make_where_clause().predicates.extend(predicates);
76 generics
77}
78
// Puts the given bound on any generic type parameters that are used in fields
// for which filter returns true.
//
// For example, the following struct needs the bound `A: Serialize, B:
// Serialize`.
//
//     struct S<'b, A, B: 'b, C> {
//         a: A,
//         b: Option<&'b B>,
//         #[serde(skip_serializing)]
//         c: C,
//     }
//
// Besides plain type parameters, a field whose type is a path rooted at a
// type parameter (such as `T::Assoc`) has that whole type bounded. This
// produces a predicate like `T::Assoc: Serialize`.
//
// Arguments:
// - `cont`: the container whose fields are inspected.
// - `generics`: the generics to extend. It is cloned, not mutated.
// - `filter`: called with each field's attributes and, for enum fields, the
//   enclosing variant's attributes (`None` for struct fields). Only fields for
//   which it returns true are inspected.
// - `bound`: the trait path used as the bound in every new predicate.
//
// Returns a copy of `generics` whose where clause gains:
// - one `Param: bound` predicate per relevant type parameter, in declaration
//   order, followed by
// - one `Param::Assoc: bound` predicate per associated-type field, in the
//   order the fields were visited.
//
// Repeated associated-type fields produce repeated predicates. A where clause
// is created even if no predicate is added.
pub fn with_bound(
    cont: &Container,
    generics: &syn::Generics,
    filter: fn(&attr::Field, Option<&attr::Variant>) -> bool,
    bound: &syn::Path,
) -> syn::Generics {
    // Syntax-tree walker that records which type parameters the filtered
    // fields mention.
    struct FindTyParams<'ast> {
        // Set of all generic type parameters on the current struct (A, B, C in
        // the example). Initialized up front.
        all_type_params: HashSet<syn::Ident>,

        // Set of generic type parameters used in fields for which filter
        // returns true (A and B in the example). Filled in as the visitor sees
        // them.
        relevant_type_params: HashSet<syn::Ident>,

        // Fields whose type is an associated type of one of the generic type
        // parameters. These are borrowed from the field types, hence 'ast.
        associated_type_usage: Vec<&'ast syn::TypePath>,
    }

    impl<'ast> FindTyParams<'ast> {
        // Entry point for one field.
        //
        // Only the field's top-level type is checked for the associated-type
        // pattern. `ungroup` presumably strips invisible `Type::Group`
        // wrappers; that helper is defined elsewhere.
        //
        // `Pair::Punctuated` on the first segment means another segment
        // follows it, so this matches `T::Something` but not a bare `T`.
        //
        // NOTE(review): an associated type nested inside another type, such as
        // `Vec<T::Assoc>`, is neither recorded here nor marks `T` as relevant,
        // because `visit_path` only counts single-segment paths. Confirm this
        // limitation is intended.
        fn visit_field(&mut self, field: &'ast syn::Field) {
            if let syn::Type::Path(ty) = ungroup(&field.ty) {
                if let Some(Pair::Punctuated(t, _)) = ty.path.segments.pairs().next() {
                    if self.all_type_params.contains(&t.ident) {
                        self.associated_type_usage.push(ty);
                    }
                }
            }
            self.visit_type(&field.ty);
        }

        // Marks a type parameter as relevant when the path is exactly that
        // parameter: a single segment with no leading `::`. Then it descends
        // into the generic arguments of every segment.
        fn visit_path(&mut self, path: &'ast syn::Path) {
            if let Some(seg) = path.segments.last() {
                if seg.ident == "PhantomData" {
                    // Hardcoded exception, because PhantomData<T> implements
                    // Serialize and Deserialize whether or not T implements it.
                    // The check is by name only, so any path ending in
                    // `PhantomData` is skipped along with its arguments.
                    return;
                }
            }
            if path.leading_colon.is_none() && path.segments.len() == 1 {
                let id = &path.segments[0].ident;
                if self.all_type_params.contains(id) {
                    self.relevant_type_params.insert(id.clone());
                }
            }
            for segment in &path.segments {
                self.visit_path_segment(segment);
            }
        }

        // Everything below is simply traversing the syntax tree.

        // Recurses into every type position that can contain a path.
        // Inferred (`_`), never (`!`) and verbatim types contain nothing to
        // visit. The trailing catch-all covers syn's non-exhaustive `Type`
        // enum; in exhaustive test builds it is denied so new variants get
        // noticed.
        fn visit_type(&mut self, ty: &'ast syn::Type) {
            match ty {
                syn::Type::Array(ty) => self.visit_type(&ty.elem),
                syn::Type::BareFn(ty) => {
                    for arg in &ty.inputs {
                        self.visit_type(&arg.ty);
                    }
                    self.visit_return_type(&ty.output);
                }
                syn::Type::Group(ty) => self.visit_type(&ty.elem),
                syn::Type::ImplTrait(ty) => {
                    for bound in &ty.bounds {
                        self.visit_type_param_bound(bound);
                    }
                }
                syn::Type::Macro(ty) => self.visit_macro(&ty.mac),
                syn::Type::Paren(ty) => self.visit_type(&ty.elem),
                syn::Type::Path(ty) => {
                    // For `<X as Trait>::Assoc`, also visit the `X` in the
                    // qualified self position.
                    if let Some(qself) = &ty.qself {
                        self.visit_type(&qself.ty);
                    }
                    self.visit_path(&ty.path);
                }
                syn::Type::Ptr(ty) => self.visit_type(&ty.elem),
                syn::Type::Reference(ty) => self.visit_type(&ty.elem),
                syn::Type::Slice(ty) => self.visit_type(&ty.elem),
                syn::Type::TraitObject(ty) => {
                    for bound in &ty.bounds {
                        self.visit_type_param_bound(bound);
                    }
                }
                syn::Type::Tuple(ty) => {
                    for elem in &ty.elems {
                        self.visit_type(elem);
                    }
                }

                syn::Type::Infer(_) | syn::Type::Never(_) | syn::Type::Verbatim(_) => {}

                #[cfg_attr(all(test, exhaustive), deny(non_exhaustive_omitted_patterns))]
                _ => {}
            }
        }

        // Only a segment's generic arguments can mention type parameters.
        fn visit_path_segment(&mut self, segment: &'ast syn::PathSegment) {
            self.visit_path_arguments(&segment.arguments);
        }

        // Visits the types inside `<...>` arguments (including the `Item = Ty`
        // associated-type bindings) and inside `Fn(...) -> ...` sugar.
        //
        // Lifetimes, const arguments and `Trait: Bound` constraints are
        // skipped.
        fn visit_path_arguments(&mut self, arguments: &'ast syn::PathArguments) {
            match arguments {
                syn::PathArguments::None => {}
                syn::PathArguments::AngleBracketed(arguments) => {
                    for arg in &arguments.args {
                        match arg {
                            syn::GenericArgument::Type(arg) => self.visit_type(arg),
                            syn::GenericArgument::AssocType(arg) => self.visit_type(&arg.ty),
                            syn::GenericArgument::Lifetime(_)
                            | syn::GenericArgument::Const(_)
                            | syn::GenericArgument::AssocConst(_)
                            | syn::GenericArgument::Constraint(_) => {}
                            #[cfg_attr(
                                all(test, exhaustive),
                                deny(non_exhaustive_omitted_patterns)
                            )]
                            _ => {}
                        }
                    }
                }
                syn::PathArguments::Parenthesized(arguments) => {
                    for argument in &arguments.inputs {
                        self.visit_type(argument);
                    }
                    self.visit_return_type(&arguments.output);
                }
            }
        }

        // An omitted return type (`ReturnType::Default`) has nothing to visit.
        fn visit_return_type(&mut self, return_type: &'ast syn::ReturnType) {
            match return_type {
                syn::ReturnType::Default => {}
                syn::ReturnType::Type(_, output) => self.visit_type(output),
            }
        }

        // Trait bounds are visited through their path; lifetime and verbatim
        // bounds are ignored.
        fn visit_type_param_bound(&mut self, bound: &'ast syn::TypeParamBound) {
            match bound {
                syn::TypeParamBound::Trait(bound) => self.visit_path(&bound.path),
                syn::TypeParamBound::Lifetime(_) | syn::TypeParamBound::Verbatim(_) => {}
                #[cfg_attr(all(test, exhaustive), deny(non_exhaustive_omitted_patterns))]
                _ => {}
            }
        }

        // Type parameter should not be considered used by a macro path.
        //
        //     struct TypeMacro<T> {
        //         mac: T!(),
        //         marker: PhantomData<T>,
        //     }
        fn visit_macro(&mut self, _mac: &'ast syn::Macro) {}
    }

    let all_type_params = generics
        .type_params()
        .map(|param| param.ident.clone())
        .collect();

    let mut visitor = FindTyParams {
        all_type_params,
        relevant_type_params: HashSet::new(),
        associated_type_usage: Vec::new(),
    };
    // Feed every field accepted by `filter` to the visitor. Enum fields are
    // filtered together with their variant's attributes.
    match &cont.data {
        Data::Enum(variants) => {
            for variant in variants {
                let relevant_fields = variant
                    .fields
                    .iter()
                    .filter(|field| filter(&field.attrs, Some(&variant.attrs)));
                for field in relevant_fields {
                    visitor.visit_field(field.original);
                }
            }
        }
        Data::Struct(_, fields) => {
            for field in fields.iter().filter(|field| filter(&field.attrs, None)) {
                visitor.visit_field(field.original);
            }
        }
    }

    let relevant_type_params = visitor.relevant_type_params;
    let associated_type_usage = visitor.associated_type_usage;
    // Iterate `generics.type_params()` rather than the HashSet so that the
    // predicates come out in declaration order and are deterministic.
    let new_predicates = generics
        .type_params()
        .map(|param| param.ident.clone())
        .filter(|id| relevant_type_params.contains(id))
        .map(|id| syn::TypePath {
            qself: None,
            path: id.into(),
        })
        .chain(associated_type_usage.into_iter().cloned())
        .map(|bounded_ty| {
            syn::WherePredicate::Type(syn::PredicateType {
                lifetimes: None,
                // the type parameter that is being bounded e.g. T
                bounded_ty: syn::Type::Path(bounded_ty),
                colon_token: <Token![:]>::default(),
                // the bound e.g. Serialize
                bounds: vec![syn::TypeParamBound::Trait(syn::TraitBound {
                    paren_token: None,
                    modifier: syn::TraitBoundModifier::None,
                    lifetimes: None,
                    path: bound.clone(),
                })]
                .into_iter()
                .collect(),
            })
        });

    let mut generics = generics.clone();
    generics
        .make_where_clause()
        .predicates
        .extend(new_predicates);
    generics
}
312
313pub fn with_self_bound(
314 cont: &Container,
315 generics: &syn::Generics,
316 bound: &syn::Path,
317) -> syn::Generics {
318 let mut generics = generics.clone();
319 generics
320 .make_where_clause()
321 .predicates
322 .push(syn::WherePredicate::Type(syn::PredicateType {
323 lifetimes: None,
324 // the type that is being bounded e.g. MyStruct<'a, T>
325 bounded_ty: type_of_item(cont),
326 colon_token: <Token![:]>::default(),
327 // the bound e.g. Default
328 bounds: vec![syn::TypeParamBound::Trait(syn::TraitBound {
329 paren_token: None,
330 modifier: syn::TraitBoundModifier::None,
331 lifetimes: None,
332 path: bound.clone(),
333 })]
334 .into_iter()
335 .collect(),
336 }));
337 generics
338}
339
340pub fn with_lifetime_bound(generics: &syn::Generics, lifetime: &str) -> syn::Generics {
341 let bound = syn::Lifetime::new(lifetime, Span::call_site());
342 let def = syn::LifetimeParam {
343 attrs: Vec::new(),
344 lifetime: bound.clone(),
345 colon_token: None,
346 bounds: Punctuated::new(),
347 };
348
349 let params = Some(syn::GenericParam::Lifetime(def))
350 .into_iter()
351 .chain(generics.params.iter().cloned().map(|mut param| {
352 match &mut param {
353 syn::GenericParam::Lifetime(param) => {
354 param.bounds.push(bound.clone());
355 }
356 syn::GenericParam::Type(param) => {
357 param
358 .bounds
359 .push(syn::TypeParamBound::Lifetime(bound.clone()));
360 }
361 syn::GenericParam::Const(_) => {}
362 }
363 param
364 }))
365 .collect();
366
367 syn::Generics {
368 params,
369 ..generics.clone()
370 }
371}
372
373fn type_of_item(cont: &Container) -> syn::Type {
374 syn::Type::Path(syn::TypePath {
375 qself: None,
376 path: syn::Path {
377 leading_colon: None,
378 segments: vec![syn::PathSegment {
379 ident: cont.ident.clone(),
380 arguments: syn::PathArguments::AngleBracketed(
381 syn::AngleBracketedGenericArguments {
382 colon2_token: None,
383 lt_token: <Token![<]>::default(),
384 args: cont
385 .generics
386 .params
387 .iter()
388 .map(|param| match param {
389 syn::GenericParam::Type(param) => {
390 syn::GenericArgument::Type(syn::Type::Path(syn::TypePath {
391 qself: None,
392 path: param.ident.clone().into(),
393 }))
394 }
395 syn::GenericParam::Lifetime(param) => {
396 syn::GenericArgument::Lifetime(param.lifetime.clone())
397 }
398 syn::GenericParam::Const(_) => {
399 panic!("Serde does not support const generics yet");
400 }
401 })
402 .collect(),
403 gt_token: <Token![>]>::default(),
404 },
405 ),
406 }]
407 .into_iter()
408 .collect(),
409 },
410 })
411}
412