1use crate::ast::{Enum, Field, Input, Struct};
2use crate::attr::Trait;
3use crate::generics::InferredBounds;
4use proc_macro2::TokenStream;
5use quote::{format_ident, quote, quote_spanned, ToTokens};
6use std::collections::BTreeSet as Set;
7use syn::spanned::Spanned;
8use syn::{
9 Data, DeriveInput, GenericArgument, Member, PathArguments, Result, Token, Type, Visibility,
10};
11
12pub fn derive(node: &DeriveInput) -> Result<TokenStream> {
13 let input: Input<'_> = Input::from_syn(node)?;
14 input.validate()?;
15 Ok(match input {
16 Input::Struct(input: Struct<'_>) => impl_struct(input),
17 Input::Enum(input: Enum<'_>) => impl_enum(input),
18 })
19}
20
/// Generates the impls for a struct input:
/// - an `std::error::Error` impl (with `source` and/or `provide` methods when
///   the struct has a transparent field, a source field, or a backtrace field),
/// - a `Display` impl when the struct is transparent or carries a display format,
/// - a `From` impl when the struct has a from-field.
///
/// Where-clause bounds are inferred only for fields whose type mentions a
/// generic parameter (`contains_generic`).
fn impl_struct(input: Struct) -> TokenStream {
    let ty = &input.ident;
    let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
    // Bounds required on the `Error` impl, accumulated while building its methods.
    let mut error_inferred_bounds = InferredBounds::new();

    // Body of `Error::source`:
    // - transparent: delegate to the single field's own `source()`;
    // - source field: return it (unwrapping an `Option` field with `?`);
    // - otherwise: no `source` method is emitted (the trait default applies).
    let source_body = if input.attrs.transparent.is_some() {
        let only_field = &input.fields[0];
        if only_field.contains_generic {
            error_inferred_bounds.insert(only_field.ty, quote!(std::error::Error));
        }
        let member = &only_field.member;
        Some(quote! {
            std::error::Error::source(self.#member.as_dyn_error())
        })
    } else if let Some(source_field) = input.source_field() {
        let source = &source_field.member;
        if source_field.contains_generic {
            // `'static` because `source()` returns `&(dyn Error + 'static)`.
            let ty = unoptional_type(source_field.ty);
            error_inferred_bounds.insert(ty, quote!(std::error::Error + 'static));
        }
        let asref = if type_is_option(source_field.ty) {
            Some(quote_spanned!(source.span()=> .as_ref()?))
        } else {
            None
        };
        // Spanned at the source field so type errors point at that field.
        let dyn_error = quote_spanned!(source.span()=> self.#source #asref.as_dyn_error());
        Some(quote! {
            std::option::Option::Some(#dyn_error)
        })
    } else {
        None
    };
    let source_method = source_body.map(|body| {
        quote! {
            fn source(&self) -> std::option::Option<&(dyn std::error::Error + 'static)> {
                use thiserror::__private::AsDynError;
                #body
            }
        }
    });

    // `Error::provide` is only emitted when the struct has a backtrace field.
    // With a source field, the request is first forwarded to the source, then
    // this struct's own backtrace is provided (skipped when the backtrace field
    // *is* the source field).
    let provide_method = input.backtrace_field().map(|backtrace_field| {
        let request = quote!(request);
        let backtrace = &backtrace_field.member;
        let body = if let Some(source_field) = input.source_field() {
            let source = &source_field.member;
            let source_provide = if type_is_option(source_field.ty) {
                quote_spanned! {source.span()=>
                    if let std::option::Option::Some(source) = &self.#source {
                        source.thiserror_provide(#request);
                    }
                }
            } else {
                quote_spanned! {source.span()=>
                    self.#source.thiserror_provide(#request);
                }
            };
            let self_provide = if source == backtrace {
                None
            } else if type_is_option(backtrace_field.ty) {
                Some(quote! {
                    if let std::option::Option::Some(backtrace) = &self.#backtrace {
                        #request.provide_ref::<std::backtrace::Backtrace>(backtrace);
                    }
                })
            } else {
                Some(quote! {
                    #request.provide_ref::<std::backtrace::Backtrace>(&self.#backtrace);
                })
            };
            quote! {
                use thiserror::__private::ThiserrorProvide;
                #source_provide
                #self_provide
            }
        } else if type_is_option(backtrace_field.ty) {
            quote! {
                if let std::option::Option::Some(backtrace) = &self.#backtrace {
                    #request.provide_ref::<std::backtrace::Backtrace>(backtrace);
                }
            }
        } else {
            quote! {
                #request.provide_ref::<std::backtrace::Backtrace>(&self.#backtrace);
            }
        };
        quote! {
            fn provide<'_request>(&'_request self, #request: &mut std::error::Request<'_request>) {
                #body
            }
        }
    });

    // Set of (field index, trait) pairs the Display body needs; turned into
    // where-clause bounds below for generic-containing fields only.
    let mut display_implied_bounds = Set::new();
    let display_body = if input.attrs.transparent.is_some() {
        // Transparent: forward to the single field's `Display`.
        let only_field = &input.fields[0].member;
        display_implied_bounds.insert((0, Trait::Display));
        Some(quote! {
            std::fmt::Display::fmt(&self.#only_field, __formatter)
        })
    } else if let Some(display) = &input.attrs.display {
        display_implied_bounds = display.implied_bounds.clone();
        let use_as_display = if display.has_bonus_display {
            Some(quote! {
                #[allow(unused_imports)]
                use thiserror::__private::{DisplayAsDisplay, PathAsDisplay};
            })
        } else {
            None
        };
        // Destructure `self` so the display tokens can refer to fields by name
        // (or `_0`, `_1`, ... for tuple structs).
        let pat = fields_pat(&input.fields);
        Some(quote! {
            #use_as_display
            #[allow(unused_variables, deprecated)]
            let Self #pat = self;
            #display
        })
    } else {
        None
    };
    let display_impl = display_body.map(|body| {
        let mut display_inferred_bounds = InferredBounds::new();
        for (field, bound) in display_implied_bounds {
            let field = &input.fields[field];
            if field.contains_generic {
                display_inferred_bounds.insert(field.ty, bound);
            }
        }
        let display_where_clause = display_inferred_bounds.augment_where_clause(input.generics);
        quote! {
            #[allow(unused_qualifications)]
            impl #impl_generics std::fmt::Display for #ty #ty_generics #display_where_clause {
                #[allow(clippy::used_underscore_binding)]
                fn fmt(&self, __formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
                    #body
                }
            }
        }
    });

    // `From<FromFieldType>` (with one `Option` layer stripped); the body also
    // captures a backtrace into a distinct backtrace field when there is one.
    // Uses the struct's original where clause, not the inferred one.
    let from_impl = input.from_field().map(|from_field| {
        let backtrace_field = input.distinct_backtrace_field();
        let from = unoptional_type(from_field.ty);
        let body = from_initializer(from_field, backtrace_field);
        quote! {
            #[allow(unused_qualifications)]
            impl #impl_generics std::convert::From<#from> for #ty #ty_generics #where_clause {
                #[allow(deprecated)]
                fn from(source: #from) -> Self {
                    #ty #body
                }
            }
        }
    });

    let error_trait = spanned_error_trait(input.original);
    // `Error` has `Debug + Display` supertraits; with type parameters those may
    // not hold unconditionally, so require them on `Self` explicitly.
    if input.generics.type_params().next().is_some() {
        let self_token = <Token![Self]>::default();
        error_inferred_bounds.insert(self_token, Trait::Debug);
        error_inferred_bounds.insert(self_token, Trait::Display);
    }
    let error_where_clause = error_inferred_bounds.augment_where_clause(input.generics);

    quote! {
        #[allow(unused_qualifications)]
        impl #impl_generics #error_trait for #ty #ty_generics #error_where_clause {
            #source_method
            #provide_method
        }
        #display_impl
        #from_impl
    }
}
194
/// Generates the impls for an enum input:
/// - an `std::error::Error` impl whose `source` / `provide` methods (when the
///   enum has any source / backtrace) `match` over every variant,
/// - a `Display` impl when `input.has_display()`,
/// - one `From` impl per variant that has a from-field.
fn impl_enum(input: Enum) -> TokenStream {
    let ty = &input.ident;
    let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
    // Filled in by the `source` arm closure below. That closure runs lazily,
    // when `#(#arms)*` is expanded inside `quote!`, which is still before
    // `error_where_clause` is computed at the end.
    let mut error_inferred_bounds = InferredBounds::new();

    let source_method = if input.has_source() {
        // One arm per variant, so the generated `match` is exhaustive.
        let arms = input.variants.iter().map(|variant| {
            let ident = &variant.ident;
            if variant.attrs.transparent.is_some() {
                // Transparent: bind the only field as `transparent` and forward
                // to its `source()`. Brace patterns also work for tuple variants
                // (`Variant { 0: transparent }`).
                let only_field = &variant.fields[0];
                if only_field.contains_generic {
                    error_inferred_bounds.insert(only_field.ty, quote!(std::error::Error));
                }
                let member = &only_field.member;
                let source = quote!(std::error::Error::source(transparent.as_dyn_error()));
                quote! {
                    #ty::#ident {#member: transparent} => #source,
                }
            } else if let Some(source_field) = variant.source_field() {
                let source = &source_field.member;
                if source_field.contains_generic {
                    let ty = unoptional_type(source_field.ty);
                    error_inferred_bounds.insert(ty, quote!(std::error::Error + 'static));
                }
                // For an `Option` source, `?` turns `None` into `source() == None`.
                let asref = if type_is_option(source_field.ty) {
                    Some(quote_spanned!(source.span()=> .as_ref()?))
                } else {
                    None
                };
                let varsource = quote!(source);
                let dyn_error = quote_spanned!(source.span()=> #varsource #asref.as_dyn_error());
                quote! {
                    #ty::#ident {#source: #varsource, ..} => std::option::Option::Some(#dyn_error),
                }
            } else {
                quote! {
                    #ty::#ident {..} => std::option::Option::None,
                }
            }
        });
        Some(quote! {
            fn source(&self) -> std::option::Option<&(dyn std::error::Error + 'static)> {
                use thiserror::__private::AsDynError;
                #[allow(deprecated)]
                match self {
                    #(#arms)*
                }
            }
        })
    } else {
        None
    };

    let provide_method = if input.has_backtrace() {
        let request = quote!(request);
        let arms = input.variants.iter().map(|variant| {
            let ident = &variant.ident;
            match (variant.backtrace_field(), variant.source_field()) {
                // Backtrace field without an explicit `backtrace` attribute
                // (presumably found by its type), plus a source: forward the
                // request to the source, then provide our own backtrace.
                (Some(backtrace_field), Some(source_field))
                    if backtrace_field.attrs.backtrace.is_none() =>
                {
                    let backtrace = &backtrace_field.member;
                    let source = &source_field.member;
                    let varsource = quote!(source);
                    let source_provide = if type_is_option(source_field.ty) {
                        quote_spanned! {source.span()=>
                            if let std::option::Option::Some(source) = #varsource {
                                source.thiserror_provide(#request);
                            }
                        }
                    } else {
                        quote_spanned! {source.span()=>
                            #varsource.thiserror_provide(#request);
                        }
                    };
                    let self_provide = if type_is_option(backtrace_field.ty) {
                        quote! {
                            if let std::option::Option::Some(backtrace) = backtrace {
                                #request.provide_ref::<std::backtrace::Backtrace>(backtrace);
                            }
                        }
                    } else {
                        quote! {
                            #request.provide_ref::<std::backtrace::Backtrace>(backtrace);
                        }
                    };
                    quote! {
                        #ty::#ident {
                            #backtrace: backtrace,
                            #source: #varsource,
                            ..
                        } => {
                            use thiserror::__private::ThiserrorProvide;
                            #source_provide
                            #self_provide
                        }
                    }
                }
                // The backtrace field *is* the source field: just forward.
                (Some(backtrace_field), Some(source_field))
                    if backtrace_field.member == source_field.member =>
                {
                    let backtrace = &backtrace_field.member;
                    let varsource = quote!(source);
                    let source_provide = if type_is_option(source_field.ty) {
                        quote_spanned! {backtrace.span()=>
                            if let std::option::Option::Some(source) = #varsource {
                                source.thiserror_provide(#request);
                            }
                        }
                    } else {
                        quote_spanned! {backtrace.span()=>
                            #varsource.thiserror_provide(#request);
                        }
                    };
                    quote! {
                        #ty::#ident {#backtrace: #varsource, ..} => {
                            use thiserror::__private::ThiserrorProvide;
                            #source_provide
                        }
                    }
                }
                // Only provide this variant's own backtrace.
                // NOTE(review): this arm also catches a variant whose backtrace
                // field has an explicit `backtrace` attribute AND a distinct
                // source field; unlike `impl_struct`, the request is then NOT
                // forwarded to the source. Confirm this asymmetry is intended.
                (Some(backtrace_field), _) => {
                    let backtrace = &backtrace_field.member;
                    let body = if type_is_option(backtrace_field.ty) {
                        quote! {
                            if let std::option::Option::Some(backtrace) = backtrace {
                                #request.provide_ref::<std::backtrace::Backtrace>(backtrace);
                            }
                        }
                    } else {
                        quote! {
                            #request.provide_ref::<std::backtrace::Backtrace>(backtrace);
                        }
                    };
                    quote! {
                        #ty::#ident {#backtrace: backtrace, ..} => {
                            #body
                        }
                    }
                }
                // No backtrace in this variant: provide nothing.
                (None, _) => quote! {
                    #ty::#ident {..} => {}
                },
            }
        });
        Some(quote! {
            fn provide<'_request>(&'_request self, #request: &mut std::error::Request<'_request>) {
                #[allow(deprecated)]
                match self {
                    #(#arms)*
                }
            }
        })
    } else {
        None
    };

    let display_impl = if input.has_display() {
        let mut display_inferred_bounds = InferredBounds::new();
        // Import the bonus-display helper traits once if any variant needs them.
        let use_as_display = if input.variants.iter().any(|v| {
            v.attrs
                .display
                .as_ref()
                .map_or(false, |display| display.has_bonus_display)
        }) {
            Some(quote! {
                #[allow(unused_imports)]
                use thiserror::__private::{DisplayAsDisplay, PathAsDisplay};
            })
        } else {
            None
        };
        // For an enum with no variants, emit `match *self {}`: an empty match
        // is only exhaustive on the uninhabited value itself, not on `&Self`.
        let void_deref = if input.variants.is_empty() {
            Some(quote!(*))
        } else {
            None
        };
        let arms = input.variants.iter().map(|variant| {
            let mut display_implied_bounds = Set::new();
            let display = match &variant.attrs.display {
                Some(display) => {
                    display_implied_bounds = display.implied_bounds.clone();
                    display.to_token_stream()
                }
                // No display format: forward to the first field's `Display`
                // (presumably a transparent variant, checked by `validate`).
                None => {
                    let only_field = match &variant.fields[0].member {
                        Member::Named(ident) => ident.clone(),
                        Member::Unnamed(index) => format_ident!("_{}", index),
                    };
                    display_implied_bounds.insert((0, Trait::Display));
                    quote!(std::fmt::Display::fmt(#only_field, __formatter))
                }
            };
            for (field, bound) in display_implied_bounds {
                let field = &variant.fields[field];
                if field.contains_generic {
                    display_inferred_bounds.insert(field.ty, bound);
                }
            }
            let ident = &variant.ident;
            let pat = fields_pat(&variant.fields);
            quote! {
                #ty::#ident #pat => #display
            }
        });
        // Collect eagerly: the closure above fills `display_inferred_bounds`,
        // which must be complete before the where clause is built.
        let arms = arms.collect::<Vec<_>>();
        let display_where_clause = display_inferred_bounds.augment_where_clause(input.generics);
        Some(quote! {
            #[allow(unused_qualifications)]
            impl #impl_generics std::fmt::Display for #ty #ty_generics #display_where_clause {
                fn fmt(&self, __formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
                    #use_as_display
                    #[allow(unused_variables, deprecated, clippy::used_underscore_binding)]
                    match #void_deref self {
                        #(#arms,)*
                    }
                }
            }
        })
    } else {
        None
    };

    // One `From` impl per variant with a from-field; uses the enum's original
    // where clause.
    let from_impls = input.variants.iter().filter_map(|variant| {
        let from_field = variant.from_field()?;
        let backtrace_field = variant.distinct_backtrace_field();
        let variant = &variant.ident;
        let from = unoptional_type(from_field.ty);
        let body = from_initializer(from_field, backtrace_field);
        Some(quote! {
            #[allow(unused_qualifications)]
            impl #impl_generics std::convert::From<#from> for #ty #ty_generics #where_clause {
                #[allow(deprecated)]
                fn from(source: #from) -> Self {
                    #ty::#variant #body
                }
            }
        })
    });

    let error_trait = spanned_error_trait(input.original);
    // `Error` requires `Debug + Display`; state them on `Self` for generic enums.
    if input.generics.type_params().next().is_some() {
        let self_token = <Token![Self]>::default();
        error_inferred_bounds.insert(self_token, Trait::Debug);
        error_inferred_bounds.insert(self_token, Trait::Display);
    }
    let error_where_clause = error_inferred_bounds.augment_where_clause(input.generics);

    quote! {
        #[allow(unused_qualifications)]
        impl #impl_generics #error_trait for #ty #ty_generics #error_where_clause {
            #source_method
            #provide_method
        }
        #display_impl
        #(#from_impls)*
    }
}
453
454fn fields_pat(fields: &[Field]) -> TokenStream {
455 let mut members: impl Iterator = fields.iter().map(|field: &Field<'_>| &field.member).peekable();
456 match members.peek() {
457 Some(Member::Named(_)) => quote!({ #(#members),* }),
458 Some(Member::Unnamed(_)) => {
459 let vars: impl Iterator = members.map(|member: &Member| match member {
460 Member::Unnamed(member: &Index) => format_ident!("_{}", member),
461 Member::Named(_) => unreachable!(),
462 });
463 quote!((#(#vars),*))
464 }
465 None => quote!({}),
466 }
467}
468
469fn from_initializer(from_field: &Field, backtrace_field: Option<&Field>) -> TokenStream {
470 let from_member: &Member = &from_field.member;
471 let some_source: TokenStream = if type_is_option(from_field.ty) {
472 quote!(std::option::Option::Some(source))
473 } else {
474 quote!(source)
475 };
476 let backtrace: Option = backtrace_field.map(|backtrace_field: &Field<'_>| {
477 let backtrace_member: &Member = &backtrace_field.member;
478 if type_is_option(backtrace_field.ty) {
479 quote! {
480 #backtrace_member: std::option::Option::Some(std::backtrace::Backtrace::capture()),
481 }
482 } else {
483 quote! {
484 #backtrace_member: std::convert::From::from(std::backtrace::Backtrace::capture()),
485 }
486 }
487 });
488 quote!({
489 #from_member: #some_source,
490 #backtrace
491 })
492}
493
494fn type_is_option(ty: &Type) -> bool {
495 type_parameter_of_option(ty).is_some()
496}
497
498fn unoptional_type(ty: &Type) -> TokenStream {
499 let unoptional: &Type = type_parameter_of_option(ty).unwrap_or(default:ty);
500 quote!(#unoptional)
501}
502
503fn type_parameter_of_option(ty: &Type) -> Option<&Type> {
504 let path = match ty {
505 Type::Path(ty) => &ty.path,
506 _ => return None,
507 };
508
509 let last = path.segments.last().unwrap();
510 if last.ident != "Option" {
511 return None;
512 }
513
514 let bracketed = match &last.arguments {
515 PathArguments::AngleBracketed(bracketed) => bracketed,
516 _ => return None,
517 };
518
519 if bracketed.args.len() != 1 {
520 return None;
521 }
522
523 match &bracketed.args[0] {
524 GenericArgument::Type(arg) => Some(arg),
525 _ => None,
526 }
527}
528
529fn spanned_error_trait(input: &DeriveInput) -> TokenStream {
530 let vis_span: Option = match &input.vis {
531 Visibility::Public(vis: &Pub) => Some(vis.span),
532 Visibility::Restricted(vis: &VisRestricted) => Some(vis.pub_token.span),
533 Visibility::Inherited => None,
534 };
535 let data_span: Span = match &input.data {
536 Data::Struct(data: &DataStruct) => data.struct_token.span,
537 Data::Enum(data: &DataEnum) => data.enum_token.span,
538 Data::Union(data: &DataUnion) => data.union_token.span,
539 };
540 let first_span: Span = vis_span.unwrap_or(default:data_span);
541 let last_span: Span = input.ident.span();
542 let path: TokenStream = quote_spanned!(first_span=> std::error::);
543 let error: TokenStream = quote_spanned!(last_span=> Error);
544 quote!(#path #error)
545}
546