1use crate::ast::{Enum, Field, Input, Struct};
2use crate::attr::Trait;
3use crate::generics::InferredBounds;
4use crate::span::MemberSpan;
5use proc_macro2::TokenStream;
6use quote::{format_ident, quote, quote_spanned, ToTokens};
7use std::collections::BTreeSet as Set;
8use syn::{DeriveInput, GenericArgument, Member, PathArguments, Result, Token, Type};
9
10pub fn derive(input: &DeriveInput) -> TokenStream {
11 match try_expand(input) {
12 Ok(expanded: TokenStream) => expanded,
13 // If there are invalid attributes in the input, expand to an Error impl
14 // anyway to minimize spurious knock-on errors in other code that uses
15 // this type as an Error.
16 Err(error: Error) => fallback(input, error),
17 }
18}
19
20fn try_expand(input: &DeriveInput) -> Result<TokenStream> {
21 let input: Input<'_> = Input::from_syn(node:input)?;
22 input.validate()?;
23 Ok(match input {
24 Input::Struct(input: Struct<'_>) => impl_struct(input),
25 Input::Enum(input: Enum<'_>) => impl_enum(input),
26 })
27}
28
29fn fallback(input: &DeriveInput, error: syn::Error) -> TokenStream {
30 let ty = &input.ident;
31 let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
32
33 let error = error.to_compile_error();
34
35 quote! {
36 #error
37
38 #[allow(unused_qualifications)]
39 impl #impl_generics std::error::Error for #ty #ty_generics #where_clause
40 where
41 // Work around trivial bounds being unstable.
42 // https://github.com/rust-lang/rust/issues/48214
43 for<'workaround> #ty #ty_generics: ::core::fmt::Debug,
44 {}
45
46 #[allow(unused_qualifications)]
47 impl #impl_generics ::core::fmt::Display for #ty #ty_generics #where_clause {
48 fn fmt(&self, __formatter: &mut ::core::fmt::Formatter) -> ::core::fmt::Result {
49 ::core::unreachable!()
50 }
51 }
52 }
53}
54
/// Expands `#[derive(Error)]` for a struct.
///
/// The output contains:
/// - an `Error` impl, with a `source()` method when the struct is
///   `#[error(transparent)]` or has a source field, and a `provide()` method
///   when it has a backtrace field;
/// - a `Display` impl when the struct is transparent or has a display
///   attribute;
/// - a `From` impl when the struct has a from field.
///
/// Bounds on generic field types are collected separately for the `Error`
/// impl and the `Display` impl.
fn impl_struct(input: Struct) -> TokenStream {
    let ty = &input.ident;
    let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
    // Extra bounds for the `Error` impl, filled in as the source/transparent
    // handling below encounters fields whose type mentions a generic.
    let mut error_inferred_bounds = InferredBounds::new();

    // Body of `Error::source`, or `None` to keep the trait's default method.
    let source_body = if let Some(transparent_attr) = &input.attrs.transparent {
        // `#[error(transparent)]`: forward to the source of the single field.
        // NOTE(review): indexing `fields[0]` relies on validation having
        // ensured a transparent struct has exactly one field; confirm.
        let only_field = &input.fields[0];
        if only_field.contains_generic {
            error_inferred_bounds.insert(only_field.ty, quote!(std::error::Error));
        }
        let member = &only_field.member;
        Some(quote_spanned! {transparent_attr.span=>
            std::error::Error::source(self.#member.as_dyn_error())
        })
    } else if let Some(source_field) = input.source_field() {
        let source = &source_field.member;
        if source_field.contains_generic {
            // Bound the inner `T` of an `Option<T>` source, not the Option.
            // `'static` is needed because the field itself is returned as
            // `&(dyn Error + 'static)`.
            let ty = unoptional_type(source_field.ty);
            error_inferred_bounds.insert(ty, quote!(std::error::Error + 'static));
        }
        // For an `Option` source, the `?` makes `source()` return `None`
        // when the field is empty.
        let asref = if type_is_option(source_field.ty) {
            Some(quote_spanned!(source.member_span()=> .as_ref()?))
        } else {
            None
        };
        let dyn_error = quote_spanned! {source_field.source_span()=>
            self.#source #asref.as_dyn_error()
        };
        Some(quote! {
            ::core::option::Option::Some(#dyn_error)
        })
    } else {
        None
    };
    let source_method = source_body.map(|body| {
        quote! {
            fn source(&self) -> ::core::option::Option<&(dyn std::error::Error + 'static)> {
                use thiserror::__private::AsDynError as _;
                #body
            }
        }
    });

    // `Error::provide`, emitted only when the struct has a backtrace field.
    let provide_method = input.backtrace_field().map(|backtrace_field| {
        let request = quote!(request);
        let backtrace = &backtrace_field.member;
        let body = if let Some(source_field) = input.source_field() {
            // Let the source provide first...
            let source = &source_field.member;
            let source_provide = if type_is_option(source_field.ty) {
                quote_spanned! {source.member_span()=>
                    if let ::core::option::Option::Some(source) = &self.#source {
                        source.thiserror_provide(#request);
                    }
                }
            } else {
                quote_spanned! {source.member_span()=>
                    self.#source.thiserror_provide(#request);
                }
            };
            // ...then provide our own backtrace, unless the backtrace field
            // is the source field itself (already handled above).
            let self_provide = if source == backtrace {
                None
            } else if type_is_option(backtrace_field.ty) {
                Some(quote! {
                    if let ::core::option::Option::Some(backtrace) = &self.#backtrace {
                        #request.provide_ref::<std::backtrace::Backtrace>(backtrace);
                    }
                })
            } else {
                Some(quote! {
                    #request.provide_ref::<std::backtrace::Backtrace>(&self.#backtrace);
                })
            };
            quote! {
                use thiserror::__private::ThiserrorProvide as _;
                #source_provide
                #self_provide
            }
        } else if type_is_option(backtrace_field.ty) {
            // No source: only provide the (optional) backtrace.
            quote! {
                if let ::core::option::Option::Some(backtrace) = &self.#backtrace {
                    #request.provide_ref::<std::backtrace::Backtrace>(backtrace);
                }
            }
        } else {
            quote! {
                #request.provide_ref::<std::backtrace::Backtrace>(&self.#backtrace);
            }
        };
        quote! {
            fn provide<'_request>(&'_request self, #request: &mut std::error::Request<'_request>) {
                #body
            }
        }
    });

    // `(field index, trait)` pairs that the Display body relies on; turned
    // into where-clause bounds below for fields whose type is generic.
    let mut display_implied_bounds = Set::new();
    let display_body = if input.attrs.transparent.is_some() {
        // Transparent: delegate Display to the single field.
        let only_field = &input.fields[0].member;
        display_implied_bounds.insert((0, Trait::Display));
        Some(quote! {
            ::core::fmt::Display::fmt(&self.#only_field, __formatter)
        })
    } else if let Some(display) = &input.attrs.display {
        display_implied_bounds.clone_from(&display.implied_bounds);
        let use_as_display = use_as_display(display.has_bonus_display);
        let pat = fields_pat(&input.fields);
        // Bind every field to a local so the format args can refer to them.
        Some(quote! {
            #use_as_display
            #[allow(unused_variables, deprecated)]
            let Self #pat = self;
            #display
        })
    } else {
        None
    };
    let display_impl = display_body.map(|body| {
        let mut display_inferred_bounds = InferredBounds::new();
        for (field, bound) in display_implied_bounds {
            let field = &input.fields[field];
            if field.contains_generic {
                display_inferred_bounds.insert(field.ty, bound);
            }
        }
        let display_where_clause = display_inferred_bounds.augment_where_clause(input.generics);
        quote! {
            #[allow(unused_qualifications)]
            impl #impl_generics ::core::fmt::Display for #ty #ty_generics #display_where_clause {
                #[allow(clippy::used_underscore_binding)]
                fn fmt(&self, __formatter: &mut ::core::fmt::Formatter) -> ::core::fmt::Result {
                    #body
                }
            }
        }
    });

    // `From<T>` for the from field's type (the inner `T` if it is an
    // `Option<T>`), also initializing a backtrace field if one is present.
    let from_impl = input.from_field().map(|from_field| {
        let backtrace_field = input.distinct_backtrace_field();
        let from = unoptional_type(from_field.ty);
        let body = from_initializer(from_field, backtrace_field);
        quote! {
            #[allow(unused_qualifications)]
            impl #impl_generics ::core::convert::From<#from> for #ty #ty_generics #where_clause {
                #[allow(deprecated)]
                fn from(source: #from) -> Self {
                    #ty #body
                }
            }
        }
    });

    // `Error` has `Debug + Display` supertraits; with type parameters those
    // need not hold for every instantiation, so require them on `Self`.
    if input.generics.type_params().next().is_some() {
        let self_token = <Token![Self]>::default();
        error_inferred_bounds.insert(self_token, Trait::Debug);
        error_inferred_bounds.insert(self_token, Trait::Display);
    }
    let error_where_clause = error_inferred_bounds.augment_where_clause(input.generics);

    quote! {
        #[allow(unused_qualifications)]
        impl #impl_generics std::error::Error for #ty #ty_generics #error_where_clause {
            #source_method
            #provide_method
        }
        #display_impl
        #from_impl
    }
}
222
/// Expands `#[derive(Error)]` for an enum.
///
/// The output contains:
/// - an `Error` impl, with a `source()` method (one match arm per variant)
///   when any variant has a source, and a `provide()` method when any
///   variant has a backtrace;
/// - a `Display` impl when the enum has display information;
/// - one `From` impl per variant that has a from field.
fn impl_enum(input: Enum) -> TokenStream {
    let ty = &input.ident;
    let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
    // Extra bounds for the `Error` impl, filled in by the `source()` arms
    // below for fields whose type mentions a generic.
    let mut error_inferred_bounds = InferredBounds::new();

    let source_method = if input.has_source() {
        let arms = input.variants.iter().map(|variant| {
            let ident = &variant.ident;
            if let Some(transparent_attr) = &variant.attrs.transparent {
                // Transparent variant: forward to the source of its single
                // field. NOTE(review): `fields[0]` relies on validation
                // ensuring exactly one field; confirm.
                let only_field = &variant.fields[0];
                if only_field.contains_generic {
                    error_inferred_bounds.insert(only_field.ty, quote!(std::error::Error));
                }
                let member = &only_field.member;
                let source = quote_spanned! {transparent_attr.span=>
                    std::error::Error::source(transparent.as_dyn_error())
                };
                quote! {
                    #ty::#ident {#member: transparent} => #source,
                }
            } else if let Some(source_field) = variant.source_field() {
                let source = &source_field.member;
                if source_field.contains_generic {
                    // Bound the inner `T` of an `Option<T>` source; `'static`
                    // because the field is returned as `&(dyn Error + 'static)`.
                    let ty = unoptional_type(source_field.ty);
                    error_inferred_bounds.insert(ty, quote!(std::error::Error + 'static));
                }
                // For an `Option` source, `?` yields `None` when it is empty.
                let asref = if type_is_option(source_field.ty) {
                    Some(quote_spanned!(source.member_span()=> .as_ref()?))
                } else {
                    None
                };
                let varsource = quote!(source);
                let dyn_error = quote_spanned! {source_field.source_span()=>
                    #varsource #asref.as_dyn_error()
                };
                quote! {
                    #ty::#ident {#source: #varsource, ..} => ::core::option::Option::Some(#dyn_error),
                }
            } else {
                quote! {
                    #ty::#ident {..} => ::core::option::Option::None,
                }
            }
        });
        Some(quote! {
            fn source(&self) -> ::core::option::Option<&(dyn std::error::Error + 'static)> {
                use thiserror::__private::AsDynError as _;
                #[allow(deprecated)]
                match self {
                    #(#arms)*
                }
            }
        })
    } else {
        None
    };

    let provide_method = if input.has_backtrace() {
        let request = quote!(request);
        let arms = input.variants.iter().map(|variant| {
            let ident = &variant.ident;
            match (variant.backtrace_field(), variant.source_field()) {
                // Backtrace field without an explicit `#[backtrace]`
                // attribute, plus a source: forward to the source, then
                // provide this variant's own backtrace.
                (Some(backtrace_field), Some(source_field))
                    if backtrace_field.attrs.backtrace.is_none() =>
                {
                    let backtrace = &backtrace_field.member;
                    let source = &source_field.member;
                    let varsource = quote!(source);
                    let source_provide = if type_is_option(source_field.ty) {
                        quote_spanned! {source.member_span()=>
                            if let ::core::option::Option::Some(source) = #varsource {
                                source.thiserror_provide(#request);
                            }
                        }
                    } else {
                        quote_spanned! {source.member_span()=>
                            #varsource.thiserror_provide(#request);
                        }
                    };
                    let self_provide = if type_is_option(backtrace_field.ty) {
                        quote! {
                            if let ::core::option::Option::Some(backtrace) = backtrace {
                                #request.provide_ref::<std::backtrace::Backtrace>(backtrace);
                            }
                        }
                    } else {
                        quote! {
                            #request.provide_ref::<std::backtrace::Backtrace>(backtrace);
                        }
                    };
                    quote! {
                        #ty::#ident {
                            #backtrace: backtrace,
                            #source: #varsource,
                            ..
                        } => {
                            use thiserror::__private::ThiserrorProvide as _;
                            #source_provide
                            #self_provide
                        }
                    }
                }
                // The backtrace field is the source field itself: only
                // forward the request to the source.
                (Some(backtrace_field), Some(source_field))
                    if backtrace_field.member == source_field.member =>
                {
                    let backtrace = &backtrace_field.member;
                    let varsource = quote!(source);
                    let source_provide = if type_is_option(source_field.ty) {
                        quote_spanned! {backtrace.member_span()=>
                            if let ::core::option::Option::Some(source) = #varsource {
                                source.thiserror_provide(#request);
                            }
                        }
                    } else {
                        quote_spanned! {backtrace.member_span()=>
                            #varsource.thiserror_provide(#request);
                        }
                    };
                    quote! {
                        #ty::#ident {#backtrace: #varsource, ..} => {
                            use thiserror::__private::ThiserrorProvide as _;
                            #source_provide
                        }
                    }
                }
                // Only provide this variant's backtrace.
                // NOTE(review): this arm is also reached when the backtrace
                // field has an explicit `#[backtrace]` attribute and is a
                // different field from the source. In that case the source is
                // not forwarded, unlike `impl_struct`, which always forwards
                // to the source. Confirm the asymmetry is intended.
                (Some(backtrace_field), _) => {
                    let backtrace = &backtrace_field.member;
                    let body = if type_is_option(backtrace_field.ty) {
                        quote! {
                            if let ::core::option::Option::Some(backtrace) = backtrace {
                                #request.provide_ref::<std::backtrace::Backtrace>(backtrace);
                            }
                        }
                    } else {
                        quote! {
                            #request.provide_ref::<std::backtrace::Backtrace>(backtrace);
                        }
                    };
                    quote! {
                        #ty::#ident {#backtrace: backtrace, ..} => {
                            #body
                        }
                    }
                }
                // No backtrace in this variant: provide nothing.
                (None, _) => quote! {
                    #ty::#ident {..} => {}
                },
            }
        });
        Some(quote! {
            fn provide<'_request>(&'_request self, #request: &mut std::error::Request<'_request>) {
                #[allow(deprecated)]
                match self {
                    #(#arms)*
                }
            }
        })
    } else {
        None
    };

    let display_impl = if input.has_display() {
        let mut display_inferred_bounds = InferredBounds::new();
        // Import `AsDisplay` once for the whole match if any variant needs it.
        let has_bonus_display = input.variants.iter().any(|v| {
            v.attrs
                .display
                .as_ref()
                .map_or(false, |display| display.has_bonus_display)
        });
        let use_as_display = use_as_display(has_bonus_display);
        // With no variants there are no arms. Matching on `*self` (the
        // uninhabited enum) instead of `self` (a reference) lets the empty
        // match type-check.
        let void_deref = if input.variants.is_empty() {
            Some(quote!(*))
        } else {
            None
        };
        let arms = input.variants.iter().map(|variant| {
            // `(field index, trait)` pairs this arm's body relies on.
            let mut display_implied_bounds = Set::new();
            let display = match &variant.attrs.display {
                Some(display) => {
                    display_implied_bounds.clone_from(&display.implied_bounds);
                    display.to_token_stream()
                }
                None => {
                    // No display attribute: delegate to the first field, bound
                    // by the name `fields_pat` gives it. NOTE(review):
                    // presumably such variants are `#[error(transparent)]`
                    // with one field; confirm in validation.
                    let only_field = match &variant.fields[0].member {
                        Member::Named(ident) => ident.clone(),
                        Member::Unnamed(index) => format_ident!("_{}", index),
                    };
                    display_implied_bounds.insert((0, Trait::Display));
                    quote!(::core::fmt::Display::fmt(#only_field, __formatter))
                }
            };
            for (field, bound) in display_implied_bounds {
                let field = &variant.fields[field];
                if field.contains_generic {
                    display_inferred_bounds.insert(field.ty, bound);
                }
            }
            let ident = &variant.ident;
            let pat = fields_pat(&variant.fields);
            quote! {
                #ty::#ident #pat => #display
            }
        });
        // Collect eagerly so every arm has contributed its bounds to
        // `display_inferred_bounds` before the where clause is built.
        let arms = arms.collect::<Vec<_>>();
        let display_where_clause = display_inferred_bounds.augment_where_clause(input.generics);
        Some(quote! {
            #[allow(unused_qualifications)]
            impl #impl_generics ::core::fmt::Display for #ty #ty_generics #display_where_clause {
                fn fmt(&self, __formatter: &mut ::core::fmt::Formatter) -> ::core::fmt::Result {
                    #use_as_display
                    #[allow(unused_variables, deprecated, clippy::used_underscore_binding)]
                    match #void_deref self {
                        #(#arms,)*
                    }
                }
            }
        })
    } else {
        None
    };

    // One `From<T>` impl per variant with a from field (the inner `T` if the
    // field is an `Option<T>`), also initializing a backtrace field if present.
    let from_impls = input.variants.iter().filter_map(|variant| {
        let from_field = variant.from_field()?;
        let backtrace_field = variant.distinct_backtrace_field();
        let variant = &variant.ident;
        let from = unoptional_type(from_field.ty);
        let body = from_initializer(from_field, backtrace_field);
        Some(quote! {
            #[allow(unused_qualifications)]
            impl #impl_generics ::core::convert::From<#from> for #ty #ty_generics #where_clause {
                #[allow(deprecated)]
                fn from(source: #from) -> Self {
                    #ty::#variant #body
                }
            }
        })
    });

    // `Error` has `Debug + Display` supertraits; with type parameters those
    // need not hold for every instantiation, so require them on `Self`.
    if input.generics.type_params().next().is_some() {
        let self_token = <Token![Self]>::default();
        error_inferred_bounds.insert(self_token, Trait::Debug);
        error_inferred_bounds.insert(self_token, Trait::Display);
    }
    let error_where_clause = error_inferred_bounds.augment_where_clause(input.generics);

    quote! {
        #[allow(unused_qualifications)]
        impl #impl_generics std::error::Error for #ty #ty_generics #error_where_clause {
            #source_method
            #provide_method
        }
        #display_impl
        #(#from_impls)*
    }
}
478
479fn fields_pat(fields: &[Field]) -> TokenStream {
480 let mut members: impl Iterator = fields.iter().map(|field: &Field<'_>| &field.member).peekable();
481 match members.peek() {
482 Some(Member::Named(_)) => quote!({ #(#members),* }),
483 Some(Member::Unnamed(_)) => {
484 let vars: impl Iterator = members.map(|member: &Member| match member {
485 Member::Unnamed(member: &Index) => format_ident!("_{}", member),
486 Member::Named(_) => unreachable!(),
487 });
488 quote!((#(#vars),*))
489 }
490 None => quote!({}),
491 }
492}
493
494fn use_as_display(needs_as_display: bool) -> Option<TokenStream> {
495 if needs_as_display {
496 Some(quote! {
497 use thiserror::__private::AsDisplay as _;
498 })
499 } else {
500 None
501 }
502}
503
504fn from_initializer(from_field: &Field, backtrace_field: Option<&Field>) -> TokenStream {
505 let from_member: &Member = &from_field.member;
506 let some_source: TokenStream = if type_is_option(from_field.ty) {
507 quote!(::core::option::Option::Some(source))
508 } else {
509 quote!(source)
510 };
511 let backtrace: Option = backtrace_field.map(|backtrace_field: &Field<'_>| {
512 let backtrace_member: &Member = &backtrace_field.member;
513 if type_is_option(backtrace_field.ty) {
514 quote! {
515 #backtrace_member: ::core::option::Option::Some(std::backtrace::Backtrace::capture()),
516 }
517 } else {
518 quote! {
519 #backtrace_member: ::core::convert::From::from(std::backtrace::Backtrace::capture()),
520 }
521 }
522 });
523 quote!({
524 #from_member: #some_source,
525 #backtrace
526 })
527}
528
529fn type_is_option(ty: &Type) -> bool {
530 type_parameter_of_option(ty).is_some()
531}
532
533fn unoptional_type(ty: &Type) -> TokenStream {
534 let unoptional: &Type = type_parameter_of_option(ty).unwrap_or(default:ty);
535 quote!(#unoptional)
536}
537
538fn type_parameter_of_option(ty: &Type) -> Option<&Type> {
539 let path = match ty {
540 Type::Path(ty) => &ty.path,
541 _ => return None,
542 };
543
544 let last = path.segments.last().unwrap();
545 if last.ident != "Option" {
546 return None;
547 }
548
549 let bracketed = match &last.arguments {
550 PathArguments::AngleBracketed(bracketed) => bracketed,
551 _ => return None,
552 };
553
554 if bracketed.args.len() != 1 {
555 return None;
556 }
557
558 match &bracketed.args[0] {
559 GenericArgument::Type(arg) => Some(arg),
560 _ => None,
561 }
562}
563