1 | use crate::ast::{Enum, Field, Input, Struct}; |
2 | use crate::attr::Trait; |
3 | use crate::generics::InferredBounds; |
4 | use crate::span::MemberSpan; |
5 | use proc_macro2::TokenStream; |
6 | use quote::{format_ident, quote, quote_spanned, ToTokens}; |
7 | use std::collections::BTreeSet as Set; |
8 | use syn::{DeriveInput, GenericArgument, Member, PathArguments, Result, Token, Type}; |
9 | |
10 | pub fn derive(input: &DeriveInput) -> TokenStream { |
11 | match try_expand(input) { |
12 | Ok(expanded: TokenStream) => expanded, |
13 | // If there are invalid attributes in the input, expand to an Error impl |
14 | // anyway to minimize spurious knock-on errors in other code that uses |
15 | // this type as an Error. |
16 | Err(error: Error) => fallback(input, error), |
17 | } |
18 | } |
19 | |
20 | fn try_expand(input: &DeriveInput) -> Result<TokenStream> { |
21 | let input: Input<'_> = Input::from_syn(node:input)?; |
22 | input.validate()?; |
23 | Ok(match input { |
24 | Input::Struct(input: Struct<'_>) => impl_struct(input), |
25 | Input::Enum(input: Enum<'_>) => impl_enum(input), |
26 | }) |
27 | } |
28 | |
29 | fn fallback(input: &DeriveInput, error: syn::Error) -> TokenStream { |
30 | let ty = &input.ident; |
31 | let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl(); |
32 | |
33 | let error = error.to_compile_error(); |
34 | |
35 | quote! { |
36 | #error |
37 | |
38 | #[allow(unused_qualifications)] |
39 | impl #impl_generics std::error::Error for #ty #ty_generics #where_clause |
40 | where |
41 | // Work around trivial bounds being unstable. |
42 | // https://github.com/rust-lang/rust/issues/48214 |
43 | for<'workaround> #ty #ty_generics: ::core::fmt::Debug, |
44 | {} |
45 | |
46 | #[allow(unused_qualifications)] |
47 | impl #impl_generics ::core::fmt::Display for #ty #ty_generics #where_clause { |
48 | fn fmt(&self, __formatter: &mut ::core::fmt::Formatter) -> ::core::fmt::Result { |
49 | ::core::unreachable!() |
50 | } |
51 | } |
52 | } |
53 | } |
54 | |
/// Generates the trait impls for the `Error` derive on a struct.
///
/// The output always contains an `impl std::error::Error`. That impl gets a
/// `source` method when the struct is transparent or has a source field, and a
/// `provide` method when it has a backtrace field. The output also contains an
/// `impl Display` when the struct is transparent or carries a display
/// attribute, and an `impl From<_>` when it has a from field.
///
/// Extra where-clause bounds are inferred only for fields whose type mentions
/// one of the struct's generic parameters (`contains_generic`).
fn impl_struct(input: Struct) -> TokenStream {
    let ty = &input.ident;
    let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
    // Bounds collected below and appended to the `Error` impl's where clause.
    let mut error_inferred_bounds = InferredBounds::new();

    // Body of `Error::source`, or `None` when no `source` method is generated.
    let source_body = if let Some(transparent_attr) = &input.attrs.transparent {
        // Transparent: delegate to the single field's own `source()`.
        let only_field = &input.fields[0];
        if only_field.contains_generic {
            error_inferred_bounds.insert(only_field.ty, quote!(std::error::Error));
        }
        let member = &only_field.member;
        Some(quote_spanned! {transparent_attr.span=>
            std::error::Error::source(self.#member.as_dyn_error())
        })
    } else if let Some(source_field) = input.source_field() {
        // Expose the source field itself as the error's source.
        let source = &source_field.member;
        if source_field.contains_generic {
            // Bound the inner `T` of an `Option<T>` field, not the Option.
            let ty = unoptional_type(source_field.ty);
            error_inferred_bounds.insert(ty, quote!(std::error::Error + 'static));
        }
        // For an `Option` field, `.as_ref()?` makes `source()` return `None`
        // when the field is empty.
        let asref = if type_is_option(source_field.ty) {
            Some(quote_spanned!(source.member_span()=> .as_ref()?))
        } else {
            None
        };
        let dyn_error = quote_spanned! {source_field.source_span()=>
            self.#source #asref.as_dyn_error()
        };
        Some(quote! {
            ::core::option::Option::Some(#dyn_error)
        })
    } else {
        None
    };
    let source_method = source_body.map(|body| {
        quote! {
            fn source(&self) -> ::core::option::Option<&(dyn std::error::Error + 'static)> {
                use thiserror::__private::AsDynError as _;
                #body
            }
        }
    });

    // `Error::provide` is only generated when the struct has a backtrace field.
    let provide_method = input.backtrace_field().map(|backtrace_field| {
        let request = quote!(request);
        let backtrace = &backtrace_field.member;
        let body = if let Some(source_field) = input.source_field() {
            // First forward the request to the source (skipping an empty
            // `Option`), then provide this struct's own backtrace.
            let source = &source_field.member;
            let source_provide = if type_is_option(source_field.ty) {
                quote_spanned! {source.member_span()=>
                    if let ::core::option::Option::Some(source) = &self.#source {
                        source.thiserror_provide(#request);
                    }
                }
            } else {
                quote_spanned! {source.member_span()=>
                    self.#source.thiserror_provide(#request);
                }
            };
            // When the backtrace field *is* the source field, only the
            // forwarding call above is emitted.
            let self_provide = if source == backtrace {
                None
            } else if type_is_option(backtrace_field.ty) {
                Some(quote! {
                    if let ::core::option::Option::Some(backtrace) = &self.#backtrace {
                        #request.provide_ref::<std::backtrace::Backtrace>(backtrace);
                    }
                })
            } else {
                Some(quote! {
                    #request.provide_ref::<std::backtrace::Backtrace>(&self.#backtrace);
                })
            };
            quote! {
                use thiserror::__private::ThiserrorProvide as _;
                #source_provide
                #self_provide
            }
        } else if type_is_option(backtrace_field.ty) {
            // No source: just provide the (optional) backtrace.
            quote! {
                if let ::core::option::Option::Some(backtrace) = &self.#backtrace {
                    #request.provide_ref::<std::backtrace::Backtrace>(backtrace);
                }
            }
        } else {
            quote! {
                #request.provide_ref::<std::backtrace::Backtrace>(&self.#backtrace);
            }
        };
        quote! {
            fn provide<'_request>(&'_request self, #request: &mut std::error::Request<'_request>) {
                #body
            }
        }
    });

    // `(field index, trait)` pairs that the Display body relies on; turned
    // into where-clause bounds for generic fields below.
    let mut display_implied_bounds = Set::new();
    let display_body = if input.attrs.transparent.is_some() {
        // Transparent: forward to the single field's Display.
        let only_field = &input.fields[0].member;
        display_implied_bounds.insert((0, Trait::Display));
        Some(quote! {
            ::core::fmt::Display::fmt(&self.#only_field, __formatter)
        })
    } else if let Some(display) = &input.attrs.display {
        // Destructure `self` so the display tokens can refer to fields by
        // their binding names.
        display_implied_bounds.clone_from(&display.implied_bounds);
        let use_as_display = use_as_display(display.has_bonus_display);
        let pat = fields_pat(&input.fields);
        Some(quote! {
            #use_as_display
            #[allow(unused_variables, deprecated)]
            let Self #pat = self;
            #display
        })
    } else {
        None
    };
    let display_impl = display_body.map(|body| {
        let mut display_inferred_bounds = InferredBounds::new();
        for (field, bound) in display_implied_bounds {
            let field = &input.fields[field];
            if field.contains_generic {
                display_inferred_bounds.insert(field.ty, bound);
            }
        }
        let display_where_clause = display_inferred_bounds.augment_where_clause(input.generics);
        quote! {
            #[allow(unused_qualifications)]
            impl #impl_generics ::core::fmt::Display for #ty #ty_generics #display_where_clause {
                #[allow(clippy::used_underscore_binding)]
                fn fmt(&self, __formatter: &mut ::core::fmt::Formatter) -> ::core::fmt::Result {
                    #body
                }
            }
        }
    });

    // `From<T>` for the from field's type. An `Option<T>` field converts from
    // `T`, and a distinct backtrace field is captured at conversion time.
    let from_impl = input.from_field().map(|from_field| {
        let backtrace_field = input.distinct_backtrace_field();
        let from = unoptional_type(from_field.ty);
        let body = from_initializer(from_field, backtrace_field);
        quote! {
            #[allow(unused_qualifications)]
            impl #impl_generics ::core::convert::From<#from> for #ty #ty_generics #where_clause {
                #[allow(deprecated)]
                fn from(source: #from) -> Self {
                    #ty #body
                }
            }
        }
    });

    // `Error` has `Debug + Display` supertraits. With type parameters, those
    // impls may carry their own bounds, so require them on `Self` explicitly.
    // NOTE(review): rationale inferred; confirm against the InferredBounds design.
    if input.generics.type_params().next().is_some() {
        let self_token = <Token![Self]>::default();
        error_inferred_bounds.insert(self_token, Trait::Debug);
        error_inferred_bounds.insert(self_token, Trait::Display);
    }
    let error_where_clause = error_inferred_bounds.augment_where_clause(input.generics);

    quote! {
        #[allow(unused_qualifications)]
        impl #impl_generics std::error::Error for #ty #ty_generics #error_where_clause {
            #source_method
            #provide_method
        }
        #display_impl
        #from_impl
    }
}
222 | |
223 | fn impl_enum(input: Enum) -> TokenStream { |
224 | let ty = &input.ident; |
225 | let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl(); |
226 | let mut error_inferred_bounds = InferredBounds::new(); |
227 | |
228 | let source_method = if input.has_source() { |
229 | let arms = input.variants.iter().map(|variant| { |
230 | let ident = &variant.ident; |
231 | if let Some(transparent_attr) = &variant.attrs.transparent { |
232 | let only_field = &variant.fields[0]; |
233 | if only_field.contains_generic { |
234 | error_inferred_bounds.insert(only_field.ty, quote!(std::error::Error)); |
235 | } |
236 | let member = &only_field.member; |
237 | let source = quote_spanned! {transparent_attr.span=> |
238 | std::error::Error::source(transparent.as_dyn_error()) |
239 | }; |
240 | quote! { |
241 | #ty::#ident {#member: transparent} => #source, |
242 | } |
243 | } else if let Some(source_field) = variant.source_field() { |
244 | let source = &source_field.member; |
245 | if source_field.contains_generic { |
246 | let ty = unoptional_type(source_field.ty); |
247 | error_inferred_bounds.insert(ty, quote!(std::error::Error + 'static)); |
248 | } |
249 | let asref = if type_is_option(source_field.ty) { |
250 | Some(quote_spanned!(source.member_span()=> .as_ref()?)) |
251 | } else { |
252 | None |
253 | }; |
254 | let varsource = quote!(source); |
255 | let dyn_error = quote_spanned! {source_field.source_span()=> |
256 | #varsource #asref.as_dyn_error() |
257 | }; |
258 | quote! { |
259 | #ty::#ident {#source: #varsource, ..} => ::core::option::Option::Some(#dyn_error), |
260 | } |
261 | } else { |
262 | quote! { |
263 | #ty::#ident {..} => ::core::option::Option::None, |
264 | } |
265 | } |
266 | }); |
267 | Some(quote! { |
268 | fn source(&self) -> ::core::option::Option<&(dyn std::error::Error + 'static)> { |
269 | use thiserror::__private::AsDynError as _; |
270 | #[allow(deprecated)] |
271 | match self { |
272 | #(#arms)* |
273 | } |
274 | } |
275 | }) |
276 | } else { |
277 | None |
278 | }; |
279 | |
280 | let provide_method = if input.has_backtrace() { |
281 | let request = quote!(request); |
282 | let arms = input.variants.iter().map(|variant| { |
283 | let ident = &variant.ident; |
284 | match (variant.backtrace_field(), variant.source_field()) { |
285 | (Some(backtrace_field), Some(source_field)) |
286 | if backtrace_field.attrs.backtrace.is_none() => |
287 | { |
288 | let backtrace = &backtrace_field.member; |
289 | let source = &source_field.member; |
290 | let varsource = quote!(source); |
291 | let source_provide = if type_is_option(source_field.ty) { |
292 | quote_spanned! {source.member_span()=> |
293 | if let ::core::option::Option::Some(source) = #varsource { |
294 | source.thiserror_provide(#request); |
295 | } |
296 | } |
297 | } else { |
298 | quote_spanned! {source.member_span()=> |
299 | #varsource.thiserror_provide(#request); |
300 | } |
301 | }; |
302 | let self_provide = if type_is_option(backtrace_field.ty) { |
303 | quote! { |
304 | if let ::core::option::Option::Some(backtrace) = backtrace { |
305 | #request.provide_ref::<std::backtrace::Backtrace>(backtrace); |
306 | } |
307 | } |
308 | } else { |
309 | quote! { |
310 | #request.provide_ref::<std::backtrace::Backtrace>(backtrace); |
311 | } |
312 | }; |
313 | quote! { |
314 | #ty::#ident { |
315 | #backtrace: backtrace, |
316 | #source: #varsource, |
317 | .. |
318 | } => { |
319 | use thiserror::__private::ThiserrorProvide as _; |
320 | #source_provide |
321 | #self_provide |
322 | } |
323 | } |
324 | } |
325 | (Some(backtrace_field), Some(source_field)) |
326 | if backtrace_field.member == source_field.member => |
327 | { |
328 | let backtrace = &backtrace_field.member; |
329 | let varsource = quote!(source); |
330 | let source_provide = if type_is_option(source_field.ty) { |
331 | quote_spanned! {backtrace.member_span()=> |
332 | if let ::core::option::Option::Some(source) = #varsource { |
333 | source.thiserror_provide(#request); |
334 | } |
335 | } |
336 | } else { |
337 | quote_spanned! {backtrace.member_span()=> |
338 | #varsource.thiserror_provide(#request); |
339 | } |
340 | }; |
341 | quote! { |
342 | #ty::#ident {#backtrace: #varsource, ..} => { |
343 | use thiserror::__private::ThiserrorProvide as _; |
344 | #source_provide |
345 | } |
346 | } |
347 | } |
348 | (Some(backtrace_field), _) => { |
349 | let backtrace = &backtrace_field.member; |
350 | let body = if type_is_option(backtrace_field.ty) { |
351 | quote! { |
352 | if let ::core::option::Option::Some(backtrace) = backtrace { |
353 | #request.provide_ref::<std::backtrace::Backtrace>(backtrace); |
354 | } |
355 | } |
356 | } else { |
357 | quote! { |
358 | #request.provide_ref::<std::backtrace::Backtrace>(backtrace); |
359 | } |
360 | }; |
361 | quote! { |
362 | #ty::#ident {#backtrace: backtrace, ..} => { |
363 | #body |
364 | } |
365 | } |
366 | } |
367 | (None, _) => quote! { |
368 | #ty::#ident {..} => {} |
369 | }, |
370 | } |
371 | }); |
372 | Some(quote! { |
373 | fn provide<'_request>(&'_request self, #request: &mut std::error::Request<'_request>) { |
374 | #[allow(deprecated)] |
375 | match self { |
376 | #(#arms)* |
377 | } |
378 | } |
379 | }) |
380 | } else { |
381 | None |
382 | }; |
383 | |
384 | let display_impl = if input.has_display() { |
385 | let mut display_inferred_bounds = InferredBounds::new(); |
386 | let has_bonus_display = input.variants.iter().any(|v| { |
387 | v.attrs |
388 | .display |
389 | .as_ref() |
390 | .map_or(false, |display| display.has_bonus_display) |
391 | }); |
392 | let use_as_display = use_as_display(has_bonus_display); |
393 | let void_deref = if input.variants.is_empty() { |
394 | Some(quote!(*)) |
395 | } else { |
396 | None |
397 | }; |
398 | let arms = input.variants.iter().map(|variant| { |
399 | let mut display_implied_bounds = Set::new(); |
400 | let display = match &variant.attrs.display { |
401 | Some(display) => { |
402 | display_implied_bounds.clone_from(&display.implied_bounds); |
403 | display.to_token_stream() |
404 | } |
405 | None => { |
406 | let only_field = match &variant.fields[0].member { |
407 | Member::Named(ident) => ident.clone(), |
408 | Member::Unnamed(index) => format_ident!("_ {}" , index), |
409 | }; |
410 | display_implied_bounds.insert((0, Trait::Display)); |
411 | quote!(::core::fmt::Display::fmt(#only_field, __formatter)) |
412 | } |
413 | }; |
414 | for (field, bound) in display_implied_bounds { |
415 | let field = &variant.fields[field]; |
416 | if field.contains_generic { |
417 | display_inferred_bounds.insert(field.ty, bound); |
418 | } |
419 | } |
420 | let ident = &variant.ident; |
421 | let pat = fields_pat(&variant.fields); |
422 | quote! { |
423 | #ty::#ident #pat => #display |
424 | } |
425 | }); |
426 | let arms = arms.collect::<Vec<_>>(); |
427 | let display_where_clause = display_inferred_bounds.augment_where_clause(input.generics); |
428 | Some(quote! { |
429 | #[allow(unused_qualifications)] |
430 | impl #impl_generics ::core::fmt::Display for #ty #ty_generics #display_where_clause { |
431 | fn fmt(&self, __formatter: &mut ::core::fmt::Formatter) -> ::core::fmt::Result { |
432 | #use_as_display |
433 | #[allow(unused_variables, deprecated, clippy::used_underscore_binding)] |
434 | match #void_deref self { |
435 | #(#arms,)* |
436 | } |
437 | } |
438 | } |
439 | }) |
440 | } else { |
441 | None |
442 | }; |
443 | |
444 | let from_impls = input.variants.iter().filter_map(|variant| { |
445 | let from_field = variant.from_field()?; |
446 | let backtrace_field = variant.distinct_backtrace_field(); |
447 | let variant = &variant.ident; |
448 | let from = unoptional_type(from_field.ty); |
449 | let body = from_initializer(from_field, backtrace_field); |
450 | Some(quote! { |
451 | #[allow(unused_qualifications)] |
452 | impl #impl_generics ::core::convert::From<#from> for #ty #ty_generics #where_clause { |
453 | #[allow(deprecated)] |
454 | fn from(source: #from) -> Self { |
455 | #ty::#variant #body |
456 | } |
457 | } |
458 | }) |
459 | }); |
460 | |
461 | if input.generics.type_params().next().is_some() { |
462 | let self_token = <Token![Self]>::default(); |
463 | error_inferred_bounds.insert(self_token, Trait::Debug); |
464 | error_inferred_bounds.insert(self_token, Trait::Display); |
465 | } |
466 | let error_where_clause = error_inferred_bounds.augment_where_clause(input.generics); |
467 | |
468 | quote! { |
469 | #[allow(unused_qualifications)] |
470 | impl #impl_generics std::error::Error for #ty #ty_generics #error_where_clause { |
471 | #source_method |
472 | #provide_method |
473 | } |
474 | #display_impl |
475 | #(#from_impls)* |
476 | } |
477 | } |
478 | |
479 | fn fields_pat(fields: &[Field]) -> TokenStream { |
480 | let mut members: impl Iterator = fields.iter().map(|field: &Field<'_>| &field.member).peekable(); |
481 | match members.peek() { |
482 | Some(Member::Named(_)) => quote!({ #(#members),* }), |
483 | Some(Member::Unnamed(_)) => { |
484 | let vars: impl Iterator = members.map(|member: &Member| match member { |
485 | Member::Unnamed(member: &Index) => format_ident!("_ {}" , member), |
486 | Member::Named(_) => unreachable!(), |
487 | }); |
488 | quote!((#(#vars),*)) |
489 | } |
490 | None => quote!({}), |
491 | } |
492 | } |
493 | |
494 | fn use_as_display(needs_as_display: bool) -> Option<TokenStream> { |
495 | if needs_as_display { |
496 | Some(quote! { |
497 | use thiserror::__private::AsDisplay as _; |
498 | }) |
499 | } else { |
500 | None |
501 | } |
502 | } |
503 | |
504 | fn from_initializer(from_field: &Field, backtrace_field: Option<&Field>) -> TokenStream { |
505 | let from_member: &Member = &from_field.member; |
506 | let some_source: TokenStream = if type_is_option(from_field.ty) { |
507 | quote!(::core::option::Option::Some(source)) |
508 | } else { |
509 | quote!(source) |
510 | }; |
511 | let backtrace: Option = backtrace_field.map(|backtrace_field: &Field<'_>| { |
512 | let backtrace_member: &Member = &backtrace_field.member; |
513 | if type_is_option(backtrace_field.ty) { |
514 | quote! { |
515 | #backtrace_member: ::core::option::Option::Some(std::backtrace::Backtrace::capture()), |
516 | } |
517 | } else { |
518 | quote! { |
519 | #backtrace_member: ::core::convert::From::from(std::backtrace::Backtrace::capture()), |
520 | } |
521 | } |
522 | }); |
523 | quote!({ |
524 | #from_member: #some_source, |
525 | #backtrace |
526 | }) |
527 | } |
528 | |
529 | fn type_is_option(ty: &Type) -> bool { |
530 | type_parameter_of_option(ty).is_some() |
531 | } |
532 | |
533 | fn unoptional_type(ty: &Type) -> TokenStream { |
534 | let unoptional: &Type = type_parameter_of_option(ty).unwrap_or(default:ty); |
535 | quote!(#unoptional) |
536 | } |
537 | |
538 | fn type_parameter_of_option(ty: &Type) -> Option<&Type> { |
539 | let path = match ty { |
540 | Type::Path(ty) => &ty.path, |
541 | _ => return None, |
542 | }; |
543 | |
544 | let last = path.segments.last().unwrap(); |
545 | if last.ident != "Option" { |
546 | return None; |
547 | } |
548 | |
549 | let bracketed = match &last.arguments { |
550 | PathArguments::AngleBracketed(bracketed) => bracketed, |
551 | _ => return None, |
552 | }; |
553 | |
554 | if bracketed.args.len() != 1 { |
555 | return None; |
556 | } |
557 | |
558 | match &bracketed.args[0] { |
559 | GenericArgument::Type(arg) => Some(arg), |
560 | _ => None, |
561 | } |
562 | } |
563 | |