1use proc_macro2::TokenStream;
2use quote::quote;
3use syn::{spanned::Spanned as _, Error, Result};
4
5use crate::utils::{
6 self, AttrParams, DeriveType, FullMetaInfo, HashSet, MetaInfo, MultiFieldData,
7 State,
8};
9
10pub fn expand(
11 input: &syn::DeriveInput,
12 trait_name: &'static str,
13) -> Result<TokenStream> {
14 let syn::DeriveInput {
15 ident, generics, ..
16 } = input;
17
18 let state = State::with_attr_params(
19 input,
20 trait_name,
21 quote!(::std::error),
22 trait_name.to_lowercase(),
23 allowed_attr_params(),
24 )?;
25
26 let type_params: HashSet<_> = generics
27 .params
28 .iter()
29 .filter_map(|generic| match generic {
30 syn::GenericParam::Type(ty) => Some(ty.ident.clone()),
31 _ => None,
32 })
33 .collect();
34
35 let (bounds, source, backtrace) = match state.derive_type {
36 DeriveType::Named | DeriveType::Unnamed => render_struct(&type_params, &state)?,
37 DeriveType::Enum => render_enum(&type_params, &state)?,
38 };
39
40 let source = source.map(|source| {
41 quote! {
42 fn source(&self) -> Option<&(dyn ::std::error::Error + 'static)> {
43 #source
44 }
45 }
46 });
47
48 let backtrace = backtrace.map(|backtrace| {
49 quote! {
50 fn backtrace(&self) -> Option<&::std::backtrace::Backtrace> {
51 #backtrace
52 }
53 }
54 });
55
56 let mut generics = generics.clone();
57
58 if !type_params.is_empty() {
59 let generic_parameters = generics.params.iter();
60 generics = utils::add_extra_where_clauses(
61 &generics,
62 quote! {
63 where
64 #ident<#(#generic_parameters),*>: ::std::fmt::Debug + ::std::fmt::Display
65 },
66 );
67 }
68
69 if !bounds.is_empty() {
70 let bounds = bounds.iter();
71 generics = utils::add_extra_where_clauses(
72 &generics,
73 quote! {
74 where
75 #(#bounds: ::std::fmt::Debug + ::std::fmt::Display + ::std::error::Error + 'static),*
76 },
77 );
78 }
79
80 let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
81
82 let render = quote! {
83 impl#impl_generics ::std::error::Error for #ident#ty_generics #where_clause {
84 #source
85 #backtrace
86 }
87 };
88
89 Ok(render)
90}
91
92fn render_struct(
93 type_params: &HashSet<syn::Ident>,
94 state: &State,
95) -> Result<(HashSet<syn::Type>, Option<TokenStream>, Option<TokenStream>)> {
96 let parsed_fields: ParsedFields<'_, '_> = parse_fields(type_params, state)?;
97
98 let source: Option = parsed_fields.render_source_as_struct();
99 let backtrace: Option = parsed_fields.render_backtrace_as_struct();
100
101 Ok((parsed_fields.bounds, source, backtrace))
102}
103
104fn render_enum(
105 type_params: &HashSet<syn::Ident>,
106 state: &State,
107) -> Result<(HashSet<syn::Type>, Option<TokenStream>, Option<TokenStream>)> {
108 let mut bounds = HashSet::default();
109 let mut source_match_arms = Vec::new();
110 let mut backtrace_match_arms = Vec::new();
111
112 for variant in state.enabled_variant_data().variants {
113 let default_info = FullMetaInfo {
114 enabled: true,
115 ..FullMetaInfo::default()
116 };
117
118 let state = State::from_variant(
119 state.input,
120 state.trait_name,
121 state.trait_module.clone(),
122 state.trait_attr.clone(),
123 allowed_attr_params(),
124 variant,
125 default_info,
126 )?;
127
128 let parsed_fields = parse_fields(type_params, &state)?;
129
130 if let Some(expr) = parsed_fields.render_source_as_enum_variant_match_arm() {
131 source_match_arms.push(expr);
132 }
133
134 if let Some(expr) = parsed_fields.render_backtrace_as_enum_variant_match_arm() {
135 backtrace_match_arms.push(expr);
136 }
137
138 bounds.extend(parsed_fields.bounds.into_iter());
139 }
140
141 let render = |match_arms: &mut Vec<TokenStream>| {
142 if !match_arms.is_empty() && match_arms.len() < state.variants.len() {
143 match_arms.push(quote!(_ => None));
144 }
145
146 if !match_arms.is_empty() {
147 let expr = quote! {
148 match self {
149 #(#match_arms),*
150 }
151 };
152
153 Some(expr)
154 } else {
155 None
156 }
157 };
158
159 let source = render(&mut source_match_arms);
160 let backtrace = render(&mut backtrace_match_arms);
161
162 Ok((bounds, source, backtrace))
163}
164
165fn allowed_attr_params() -> AttrParams {
166 AttrParams {
167 enum_: vec!["ignore"],
168 struct_: vec!["ignore"],
169 variant: vec!["ignore"],
170 field: vec!["ignore", "source", "backtrace"],
171 }
172}
173
174struct ParsedFields<'input, 'state> {
175 data: MultiFieldData<'input, 'state>,
176 source: Option<usize>,
177 backtrace: Option<usize>,
178 bounds: HashSet<syn::Type>,
179}
180
181impl<'input, 'state> ParsedFields<'input, 'state> {
182 fn new(data: MultiFieldData<'input, 'state>) -> Self {
183 Self {
184 data,
185 source: None,
186 backtrace: None,
187 bounds: HashSet::default(),
188 }
189 }
190}
191
192impl<'input, 'state> ParsedFields<'input, 'state> {
193 fn render_source_as_struct(&self) -> Option<TokenStream> {
194 let source = self.source?;
195 let ident = &self.data.members[source];
196 Some(render_some(quote!(&#ident)))
197 }
198
199 fn render_source_as_enum_variant_match_arm(&self) -> Option<TokenStream> {
200 let source = self.source?;
201 let pattern = self.data.matcher(&[source], &[quote!(source)]);
202 let expr = render_some(quote!(source));
203 Some(quote!(#pattern => #expr))
204 }
205
206 fn render_backtrace_as_struct(&self) -> Option<TokenStream> {
207 let backtrace = self.backtrace?;
208 let backtrace_expr = &self.data.members[backtrace];
209 Some(quote!(Some(&#backtrace_expr)))
210 }
211
212 fn render_backtrace_as_enum_variant_match_arm(&self) -> Option<TokenStream> {
213 let backtrace = self.backtrace?;
214 let pattern = self.data.matcher(&[backtrace], &[quote!(backtrace)]);
215 Some(quote!(#pattern => Some(backtrace)))
216 }
217}
218
219fn render_some<T>(expr: T) -> TokenStream
220where
221 T: quote::ToTokens,
222{
223 quote!(Some(#expr as &(dyn ::std::error::Error + 'static)))
224}
225
226fn parse_fields<'input, 'state>(
227 type_params: &HashSet<syn::Ident>,
228 state: &'state State<'input>,
229) -> Result<ParsedFields<'input, 'state>> {
230 let mut parsed_fields = match state.derive_type {
231 DeriveType::Named => {
232 parse_fields_impl(state, |attr, field, _| {
233 // Unwrapping is safe, cause fields in named struct
234 // always have an ident
235 let ident = field.ident.as_ref().unwrap();
236
237 match attr {
238 "source" => ident == "source",
239 "backtrace" => {
240 ident == "backtrace"
241 || is_type_path_ends_with_segment(&field.ty, "Backtrace")
242 }
243 _ => unreachable!(),
244 }
245 })
246 }
247
248 DeriveType::Unnamed => {
249 let mut parsed_fields =
250 parse_fields_impl(state, |attr, field, len| match attr {
251 "source" => {
252 len == 1
253 && !is_type_path_ends_with_segment(&field.ty, "Backtrace")
254 }
255 "backtrace" => {
256 is_type_path_ends_with_segment(&field.ty, "Backtrace")
257 }
258 _ => unreachable!(),
259 })?;
260
261 parsed_fields.source = parsed_fields
262 .source
263 .or_else(|| infer_source_field(&state.fields, &parsed_fields));
264
265 Ok(parsed_fields)
266 }
267
268 _ => unreachable!(),
269 }?;
270
271 if let Some(source) = parsed_fields.source {
272 add_bound_if_type_parameter_used_in_type(
273 &mut parsed_fields.bounds,
274 type_params,
275 &state.fields[source].ty,
276 );
277 }
278
279 Ok(parsed_fields)
280}
281
282/// Checks if `ty` is [`syn::Type::Path`] and ends with segment matching `tail`
283/// and doesn't contain any generic parameters.
284fn is_type_path_ends_with_segment(ty: &syn::Type, tail: &str) -> bool {
285 let ty: &TypePath = match ty {
286 syn::Type::Path(ty: &TypePath) => ty,
287 _ => return false,
288 };
289
290 // Unwrapping is safe, cause 'syn::TypePath.path.segments'
291 // have to have at least one segment
292 let segment: &PathSegment = ty.path.segments.last().unwrap();
293
294 match segment.arguments {
295 syn::PathArguments::None => (),
296 _ => return false,
297 };
298
299 segment.ident == tail
300}
301
302fn infer_source_field(
303 fields: &[&syn::Field],
304 parsed_fields: &ParsedFields,
305) -> Option<usize> {
306 // if we have exactly two fields
307 if fields.len() != 2 {
308 return None;
309 }
310
311 // no source field was specified/inferred
312 if parsed_fields.source.is_some() {
313 return None;
314 }
315
316 // but one of the fields was specified/inferred as backtrace field
317 if let Some(backtrace: usize) = parsed_fields.backtrace {
318 // then infer *other field* as source field
319 let source: usize = (backtrace + 1) % 2;
320 // unless it was explicitly marked as non-source
321 if parsed_fields.data.infos[source].info.source != Some(false) {
322 return Some(source);
323 }
324 }
325
326 None
327}
328
329fn parse_fields_impl<'input, 'state, P>(
330 state: &'state State<'input>,
331 is_valid_default_field_for_attr: P,
332) -> Result<ParsedFields<'input, 'state>>
333where
334 P: Fn(&str, &syn::Field, usize) -> bool,
335{
336 let MultiFieldData { fields, infos, .. } = state.enabled_fields_data();
337
338 let iter = fields
339 .iter()
340 .zip(infos.iter().map(|info| &info.info))
341 .enumerate()
342 .map(|(index, (field, info))| (index, *field, info));
343
344 let source = parse_field_impl(
345 &is_valid_default_field_for_attr,
346 state.fields.len(),
347 iter.clone(),
348 "source",
349 |info| info.source,
350 )?;
351
352 let backtrace = parse_field_impl(
353 &is_valid_default_field_for_attr,
354 state.fields.len(),
355 iter.clone(),
356 "backtrace",
357 |info| info.backtrace,
358 )?;
359
360 let mut parsed_fields = ParsedFields::new(state.enabled_fields_data());
361
362 if let Some((index, _, _)) = source {
363 parsed_fields.source = Some(index);
364 }
365
366 if let Some((index, _, _)) = backtrace {
367 parsed_fields.backtrace = Some(index);
368 }
369
370 Ok(parsed_fields)
371}
372
373fn parse_field_impl<'a, P, V>(
374 is_valid_default_field_for_attr: &P,
375 len: usize,
376 iter: impl Iterator<Item = (usize, &'a syn::Field, &'a MetaInfo)> + Clone,
377 attr: &str,
378 value: V,
379) -> Result<Option<(usize, &'a syn::Field, &'a MetaInfo)>>
380where
381 P: Fn(&str, &syn::Field, usize) -> bool,
382 V: Fn(&MetaInfo) -> Option<bool>,
383{
384 let explicit_fields = iter.clone().filter(|(_, _, info)| match value(info) {
385 Some(true) => true,
386 _ => false,
387 });
388
389 let inferred_fields = iter.filter(|(_, field, info)| match value(info) {
390 None => is_valid_default_field_for_attr(attr, field, len),
391 _ => false,
392 });
393
394 let field = assert_iter_contains_zero_or_one_item(
395 explicit_fields,
396 &format!(
397 "Multiple `{}` attributes specified. \
398 Single attribute per struct/enum variant allowed.",
399 attr
400 ),
401 )?;
402
403 let field = match field {
404 field @ Some(_) => field,
405 None => assert_iter_contains_zero_or_one_item(
406 inferred_fields,
407 "Conflicting fields found. Consider specifying some \
408 `#[error(...)]` attributes to resolve conflict.",
409 )?,
410 };
411
412 Ok(field)
413}
414
415fn assert_iter_contains_zero_or_one_item<'a>(
416 mut iter: impl Iterator<Item = (usize, &'a syn::Field, &'a MetaInfo)>,
417 error_msg: &str,
418) -> Result<Option<(usize, &'a syn::Field, &'a MetaInfo)>> {
419 let item: (usize, &Field, &MetaInfo) = match iter.next() {
420 Some(item: (usize, &Field, &MetaInfo)) => item,
421 None => return Ok(None),
422 };
423
424 if let Some((_, field: &Field, _)) = iter.next() {
425 return Err(Error::new(field.span(), message:error_msg));
426 }
427
428 Ok(Some(item))
429}
430
431fn add_bound_if_type_parameter_used_in_type(
432 bounds: &mut HashSet<syn::Type>,
433 type_params: &HashSet<syn::Ident>,
434 ty: &syn::Type,
435) {
436 if let Some(ty: Type) = utils::get_if_type_parameter_used_in_type(type_parameters:type_params, ty) {
437 bounds.insert(ty);
438 }
439}
440