1 | use proc_macro2::TokenStream; |
2 | use quote::quote; |
3 | use syn::{spanned::Spanned as _, Error, Result}; |
4 | |
5 | use crate::utils::{ |
6 | self, AttrParams, DeriveType, FullMetaInfo, HashSet, MetaInfo, MultiFieldData, |
7 | State, |
8 | }; |
9 | |
10 | pub fn expand( |
11 | input: &syn::DeriveInput, |
12 | trait_name: &'static str, |
13 | ) -> Result<TokenStream> { |
14 | let syn::DeriveInput { |
15 | ident, generics, .. |
16 | } = input; |
17 | |
18 | let state = State::with_attr_params( |
19 | input, |
20 | trait_name, |
21 | quote!(::std::error), |
22 | trait_name.to_lowercase(), |
23 | allowed_attr_params(), |
24 | )?; |
25 | |
26 | let type_params: HashSet<_> = generics |
27 | .params |
28 | .iter() |
29 | .filter_map(|generic| match generic { |
30 | syn::GenericParam::Type(ty) => Some(ty.ident.clone()), |
31 | _ => None, |
32 | }) |
33 | .collect(); |
34 | |
35 | let (bounds, source, backtrace) = match state.derive_type { |
36 | DeriveType::Named | DeriveType::Unnamed => render_struct(&type_params, &state)?, |
37 | DeriveType::Enum => render_enum(&type_params, &state)?, |
38 | }; |
39 | |
40 | let source = source.map(|source| { |
41 | quote! { |
42 | fn source(&self) -> Option<&(dyn ::std::error::Error + 'static)> { |
43 | #source |
44 | } |
45 | } |
46 | }); |
47 | |
48 | let backtrace = backtrace.map(|backtrace| { |
49 | quote! { |
50 | fn backtrace(&self) -> Option<&::std::backtrace::Backtrace> { |
51 | #backtrace |
52 | } |
53 | } |
54 | }); |
55 | |
56 | let mut generics = generics.clone(); |
57 | |
58 | if !type_params.is_empty() { |
59 | let generic_parameters = generics.params.iter(); |
60 | generics = utils::add_extra_where_clauses( |
61 | &generics, |
62 | quote! { |
63 | where |
64 | #ident<#(#generic_parameters),*>: ::std::fmt::Debug + ::std::fmt::Display |
65 | }, |
66 | ); |
67 | } |
68 | |
69 | if !bounds.is_empty() { |
70 | let bounds = bounds.iter(); |
71 | generics = utils::add_extra_where_clauses( |
72 | &generics, |
73 | quote! { |
74 | where |
75 | #(#bounds: ::std::fmt::Debug + ::std::fmt::Display + ::std::error::Error + 'static),* |
76 | }, |
77 | ); |
78 | } |
79 | |
80 | let (impl_generics, ty_generics, where_clause) = generics.split_for_impl(); |
81 | |
82 | let render = quote! { |
83 | impl#impl_generics ::std::error::Error for #ident#ty_generics #where_clause { |
84 | #source |
85 | #backtrace |
86 | } |
87 | }; |
88 | |
89 | Ok(render) |
90 | } |
91 | |
92 | fn render_struct( |
93 | type_params: &HashSet<syn::Ident>, |
94 | state: &State, |
95 | ) -> Result<(HashSet<syn::Type>, Option<TokenStream>, Option<TokenStream>)> { |
96 | let parsed_fields: ParsedFields<'_, '_> = parse_fields(type_params, state)?; |
97 | |
98 | let source: Option = parsed_fields.render_source_as_struct(); |
99 | let backtrace: Option = parsed_fields.render_backtrace_as_struct(); |
100 | |
101 | Ok((parsed_fields.bounds, source, backtrace)) |
102 | } |
103 | |
104 | fn render_enum( |
105 | type_params: &HashSet<syn::Ident>, |
106 | state: &State, |
107 | ) -> Result<(HashSet<syn::Type>, Option<TokenStream>, Option<TokenStream>)> { |
108 | let mut bounds = HashSet::default(); |
109 | let mut source_match_arms = Vec::new(); |
110 | let mut backtrace_match_arms = Vec::new(); |
111 | |
112 | for variant in state.enabled_variant_data().variants { |
113 | let default_info = FullMetaInfo { |
114 | enabled: true, |
115 | ..FullMetaInfo::default() |
116 | }; |
117 | |
118 | let state = State::from_variant( |
119 | state.input, |
120 | state.trait_name, |
121 | state.trait_module.clone(), |
122 | state.trait_attr.clone(), |
123 | allowed_attr_params(), |
124 | variant, |
125 | default_info, |
126 | )?; |
127 | |
128 | let parsed_fields = parse_fields(type_params, &state)?; |
129 | |
130 | if let Some(expr) = parsed_fields.render_source_as_enum_variant_match_arm() { |
131 | source_match_arms.push(expr); |
132 | } |
133 | |
134 | if let Some(expr) = parsed_fields.render_backtrace_as_enum_variant_match_arm() { |
135 | backtrace_match_arms.push(expr); |
136 | } |
137 | |
138 | bounds.extend(parsed_fields.bounds.into_iter()); |
139 | } |
140 | |
141 | let render = |match_arms: &mut Vec<TokenStream>| { |
142 | if !match_arms.is_empty() && match_arms.len() < state.variants.len() { |
143 | match_arms.push(quote!(_ => None)); |
144 | } |
145 | |
146 | if !match_arms.is_empty() { |
147 | let expr = quote! { |
148 | match self { |
149 | #(#match_arms),* |
150 | } |
151 | }; |
152 | |
153 | Some(expr) |
154 | } else { |
155 | None |
156 | } |
157 | }; |
158 | |
159 | let source = render(&mut source_match_arms); |
160 | let backtrace = render(&mut backtrace_match_arms); |
161 | |
162 | Ok((bounds, source, backtrace)) |
163 | } |
164 | |
165 | fn allowed_attr_params() -> AttrParams { |
166 | AttrParams { |
167 | enum_: vec!["ignore" ], |
168 | struct_: vec!["ignore" ], |
169 | variant: vec!["ignore" ], |
170 | field: vec!["ignore" , "source" , "backtrace" ], |
171 | } |
172 | } |
173 | |
174 | struct ParsedFields<'input, 'state> { |
175 | data: MultiFieldData<'input, 'state>, |
176 | source: Option<usize>, |
177 | backtrace: Option<usize>, |
178 | bounds: HashSet<syn::Type>, |
179 | } |
180 | |
181 | impl<'input, 'state> ParsedFields<'input, 'state> { |
182 | fn new(data: MultiFieldData<'input, 'state>) -> Self { |
183 | Self { |
184 | data, |
185 | source: None, |
186 | backtrace: None, |
187 | bounds: HashSet::default(), |
188 | } |
189 | } |
190 | } |
191 | |
192 | impl<'input, 'state> ParsedFields<'input, 'state> { |
193 | fn render_source_as_struct(&self) -> Option<TokenStream> { |
194 | let source = self.source?; |
195 | let ident = &self.data.members[source]; |
196 | Some(render_some(quote!(&#ident))) |
197 | } |
198 | |
199 | fn render_source_as_enum_variant_match_arm(&self) -> Option<TokenStream> { |
200 | let source = self.source?; |
201 | let pattern = self.data.matcher(&[source], &[quote!(source)]); |
202 | let expr = render_some(quote!(source)); |
203 | Some(quote!(#pattern => #expr)) |
204 | } |
205 | |
206 | fn render_backtrace_as_struct(&self) -> Option<TokenStream> { |
207 | let backtrace = self.backtrace?; |
208 | let backtrace_expr = &self.data.members[backtrace]; |
209 | Some(quote!(Some(&#backtrace_expr))) |
210 | } |
211 | |
212 | fn render_backtrace_as_enum_variant_match_arm(&self) -> Option<TokenStream> { |
213 | let backtrace = self.backtrace?; |
214 | let pattern = self.data.matcher(&[backtrace], &[quote!(backtrace)]); |
215 | Some(quote!(#pattern => Some(backtrace))) |
216 | } |
217 | } |
218 | |
219 | fn render_some<T>(expr: T) -> TokenStream |
220 | where |
221 | T: quote::ToTokens, |
222 | { |
223 | quote!(Some(#expr as &(dyn ::std::error::Error + 'static))) |
224 | } |
225 | |
226 | fn parse_fields<'input, 'state>( |
227 | type_params: &HashSet<syn::Ident>, |
228 | state: &'state State<'input>, |
229 | ) -> Result<ParsedFields<'input, 'state>> { |
230 | let mut parsed_fields = match state.derive_type { |
231 | DeriveType::Named => { |
232 | parse_fields_impl(state, |attr, field, _| { |
233 | // Unwrapping is safe, cause fields in named struct |
234 | // always have an ident |
235 | let ident = field.ident.as_ref().unwrap(); |
236 | |
237 | match attr { |
238 | "source" => ident == "source" , |
239 | "backtrace" => { |
240 | ident == "backtrace" |
241 | || is_type_path_ends_with_segment(&field.ty, "Backtrace" ) |
242 | } |
243 | _ => unreachable!(), |
244 | } |
245 | }) |
246 | } |
247 | |
248 | DeriveType::Unnamed => { |
249 | let mut parsed_fields = |
250 | parse_fields_impl(state, |attr, field, len| match attr { |
251 | "source" => { |
252 | len == 1 |
253 | && !is_type_path_ends_with_segment(&field.ty, "Backtrace" ) |
254 | } |
255 | "backtrace" => { |
256 | is_type_path_ends_with_segment(&field.ty, "Backtrace" ) |
257 | } |
258 | _ => unreachable!(), |
259 | })?; |
260 | |
261 | parsed_fields.source = parsed_fields |
262 | .source |
263 | .or_else(|| infer_source_field(&state.fields, &parsed_fields)); |
264 | |
265 | Ok(parsed_fields) |
266 | } |
267 | |
268 | _ => unreachable!(), |
269 | }?; |
270 | |
271 | if let Some(source) = parsed_fields.source { |
272 | add_bound_if_type_parameter_used_in_type( |
273 | &mut parsed_fields.bounds, |
274 | type_params, |
275 | &state.fields[source].ty, |
276 | ); |
277 | } |
278 | |
279 | Ok(parsed_fields) |
280 | } |
281 | |
282 | /// Checks if `ty` is [`syn::Type::Path`] and ends with segment matching `tail` |
283 | /// and doesn't contain any generic parameters. |
284 | fn is_type_path_ends_with_segment(ty: &syn::Type, tail: &str) -> bool { |
285 | let ty: &TypePath = match ty { |
286 | syn::Type::Path(ty: &TypePath) => ty, |
287 | _ => return false, |
288 | }; |
289 | |
290 | // Unwrapping is safe, cause 'syn::TypePath.path.segments' |
291 | // have to have at least one segment |
292 | let segment: &PathSegment = ty.path.segments.last().unwrap(); |
293 | |
294 | match segment.arguments { |
295 | syn::PathArguments::None => (), |
296 | _ => return false, |
297 | }; |
298 | |
299 | segment.ident == tail |
300 | } |
301 | |
302 | fn infer_source_field( |
303 | fields: &[&syn::Field], |
304 | parsed_fields: &ParsedFields, |
305 | ) -> Option<usize> { |
306 | // if we have exactly two fields |
307 | if fields.len() != 2 { |
308 | return None; |
309 | } |
310 | |
311 | // no source field was specified/inferred |
312 | if parsed_fields.source.is_some() { |
313 | return None; |
314 | } |
315 | |
316 | // but one of the fields was specified/inferred as backtrace field |
317 | if let Some(backtrace: usize) = parsed_fields.backtrace { |
318 | // then infer *other field* as source field |
319 | let source: usize = (backtrace + 1) % 2; |
320 | // unless it was explicitly marked as non-source |
321 | if parsed_fields.data.infos[source].info.source != Some(false) { |
322 | return Some(source); |
323 | } |
324 | } |
325 | |
326 | None |
327 | } |
328 | |
329 | fn parse_fields_impl<'input, 'state, P>( |
330 | state: &'state State<'input>, |
331 | is_valid_default_field_for_attr: P, |
332 | ) -> Result<ParsedFields<'input, 'state>> |
333 | where |
334 | P: Fn(&str, &syn::Field, usize) -> bool, |
335 | { |
336 | let MultiFieldData { fields, infos, .. } = state.enabled_fields_data(); |
337 | |
338 | let iter = fields |
339 | .iter() |
340 | .zip(infos.iter().map(|info| &info.info)) |
341 | .enumerate() |
342 | .map(|(index, (field, info))| (index, *field, info)); |
343 | |
344 | let source = parse_field_impl( |
345 | &is_valid_default_field_for_attr, |
346 | state.fields.len(), |
347 | iter.clone(), |
348 | "source" , |
349 | |info| info.source, |
350 | )?; |
351 | |
352 | let backtrace = parse_field_impl( |
353 | &is_valid_default_field_for_attr, |
354 | state.fields.len(), |
355 | iter.clone(), |
356 | "backtrace" , |
357 | |info| info.backtrace, |
358 | )?; |
359 | |
360 | let mut parsed_fields = ParsedFields::new(state.enabled_fields_data()); |
361 | |
362 | if let Some((index, _, _)) = source { |
363 | parsed_fields.source = Some(index); |
364 | } |
365 | |
366 | if let Some((index, _, _)) = backtrace { |
367 | parsed_fields.backtrace = Some(index); |
368 | } |
369 | |
370 | Ok(parsed_fields) |
371 | } |
372 | |
373 | fn parse_field_impl<'a, P, V>( |
374 | is_valid_default_field_for_attr: &P, |
375 | len: usize, |
376 | iter: impl Iterator<Item = (usize, &'a syn::Field, &'a MetaInfo)> + Clone, |
377 | attr: &str, |
378 | value: V, |
379 | ) -> Result<Option<(usize, &'a syn::Field, &'a MetaInfo)>> |
380 | where |
381 | P: Fn(&str, &syn::Field, usize) -> bool, |
382 | V: Fn(&MetaInfo) -> Option<bool>, |
383 | { |
384 | let explicit_fields = iter.clone().filter(|(_, _, info)| match value(info) { |
385 | Some(true) => true, |
386 | _ => false, |
387 | }); |
388 | |
389 | let inferred_fields = iter.filter(|(_, field, info)| match value(info) { |
390 | None => is_valid_default_field_for_attr(attr, field, len), |
391 | _ => false, |
392 | }); |
393 | |
394 | let field = assert_iter_contains_zero_or_one_item( |
395 | explicit_fields, |
396 | &format!( |
397 | "Multiple ` {}` attributes specified. \ |
398 | Single attribute per struct/enum variant allowed." , |
399 | attr |
400 | ), |
401 | )?; |
402 | |
403 | let field = match field { |
404 | field @ Some(_) => field, |
405 | None => assert_iter_contains_zero_or_one_item( |
406 | inferred_fields, |
407 | "Conflicting fields found. Consider specifying some \ |
408 | `#[error(...)]` attributes to resolve conflict." , |
409 | )?, |
410 | }; |
411 | |
412 | Ok(field) |
413 | } |
414 | |
415 | fn assert_iter_contains_zero_or_one_item<'a>( |
416 | mut iter: impl Iterator<Item = (usize, &'a syn::Field, &'a MetaInfo)>, |
417 | error_msg: &str, |
418 | ) -> Result<Option<(usize, &'a syn::Field, &'a MetaInfo)>> { |
419 | let item: (usize, &Field, &MetaInfo) = match iter.next() { |
420 | Some(item: (usize, &Field, &MetaInfo)) => item, |
421 | None => return Ok(None), |
422 | }; |
423 | |
424 | if let Some((_, field: &Field, _)) = iter.next() { |
425 | return Err(Error::new(field.span(), message:error_msg)); |
426 | } |
427 | |
428 | Ok(Some(item)) |
429 | } |
430 | |
431 | fn add_bound_if_type_parameter_used_in_type( |
432 | bounds: &mut HashSet<syn::Type>, |
433 | type_params: &HashSet<syn::Ident>, |
434 | ty: &syn::Type, |
435 | ) { |
436 | if let Some(ty: Type) = utils::get_if_type_parameter_used_in_type(type_parameters:type_params, ty) { |
437 | bounds.insert(ty); |
438 | } |
439 | } |
440 | |