1 | use crate::ast::{Enum, Field, Input, Struct}; |
2 | use crate::attr::Trait; |
3 | use crate::generics::InferredBounds; |
4 | use proc_macro2::TokenStream; |
5 | use quote::{format_ident, quote, quote_spanned, ToTokens}; |
6 | use std::collections::BTreeSet as Set; |
7 | use syn::spanned::Spanned; |
8 | use syn::{ |
9 | Data, DeriveInput, GenericArgument, Member, PathArguments, Result, Token, Type, Visibility, |
10 | }; |
11 | |
12 | pub fn derive(node: &DeriveInput) -> Result<TokenStream> { |
13 | let input: Input<'_> = Input::from_syn(node)?; |
14 | input.validate()?; |
15 | Ok(match input { |
16 | Input::Struct(input: Struct<'_>) => impl_struct(input), |
17 | Input::Enum(input: Enum<'_>) => impl_enum(input), |
18 | }) |
19 | } |
20 | |
21 | fn impl_struct(input: Struct) -> TokenStream { |
22 | let ty = &input.ident; |
23 | let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl(); |
24 | let mut error_inferred_bounds = InferredBounds::new(); |
25 | |
26 | let source_body = if input.attrs.transparent.is_some() { |
27 | let only_field = &input.fields[0]; |
28 | if only_field.contains_generic { |
29 | error_inferred_bounds.insert(only_field.ty, quote!(std::error::Error)); |
30 | } |
31 | let member = &only_field.member; |
32 | Some(quote! { |
33 | std::error::Error::source(self.#member.as_dyn_error()) |
34 | }) |
35 | } else if let Some(source_field) = input.source_field() { |
36 | let source = &source_field.member; |
37 | if source_field.contains_generic { |
38 | let ty = unoptional_type(source_field.ty); |
39 | error_inferred_bounds.insert(ty, quote!(std::error::Error + 'static)); |
40 | } |
41 | let asref = if type_is_option(source_field.ty) { |
42 | Some(quote_spanned!(source.span()=> .as_ref()?)) |
43 | } else { |
44 | None |
45 | }; |
46 | let dyn_error = quote_spanned!(source.span()=> self.#source #asref.as_dyn_error()); |
47 | Some(quote! { |
48 | std::option::Option::Some(#dyn_error) |
49 | }) |
50 | } else { |
51 | None |
52 | }; |
53 | let source_method = source_body.map(|body| { |
54 | quote! { |
55 | fn source(&self) -> std::option::Option<&(dyn std::error::Error + 'static)> { |
56 | use thiserror::__private::AsDynError; |
57 | #body |
58 | } |
59 | } |
60 | }); |
61 | |
62 | let provide_method = input.backtrace_field().map(|backtrace_field| { |
63 | let request = quote!(request); |
64 | let backtrace = &backtrace_field.member; |
65 | let body = if let Some(source_field) = input.source_field() { |
66 | let source = &source_field.member; |
67 | let source_provide = if type_is_option(source_field.ty) { |
68 | quote_spanned! {source.span()=> |
69 | if let std::option::Option::Some(source) = &self.#source { |
70 | source.thiserror_provide(#request); |
71 | } |
72 | } |
73 | } else { |
74 | quote_spanned! {source.span()=> |
75 | self.#source.thiserror_provide(#request); |
76 | } |
77 | }; |
78 | let self_provide = if source == backtrace { |
79 | None |
80 | } else if type_is_option(backtrace_field.ty) { |
81 | Some(quote! { |
82 | if let std::option::Option::Some(backtrace) = &self.#backtrace { |
83 | #request.provide_ref::<std::backtrace::Backtrace>(backtrace); |
84 | } |
85 | }) |
86 | } else { |
87 | Some(quote! { |
88 | #request.provide_ref::<std::backtrace::Backtrace>(&self.#backtrace); |
89 | }) |
90 | }; |
91 | quote! { |
92 | use thiserror::__private::ThiserrorProvide; |
93 | #source_provide |
94 | #self_provide |
95 | } |
96 | } else if type_is_option(backtrace_field.ty) { |
97 | quote! { |
98 | if let std::option::Option::Some(backtrace) = &self.#backtrace { |
99 | #request.provide_ref::<std::backtrace::Backtrace>(backtrace); |
100 | } |
101 | } |
102 | } else { |
103 | quote! { |
104 | #request.provide_ref::<std::backtrace::Backtrace>(&self.#backtrace); |
105 | } |
106 | }; |
107 | quote! { |
108 | fn provide<'_request>(&'_request self, #request: &mut std::error::Request<'_request>) { |
109 | #body |
110 | } |
111 | } |
112 | }); |
113 | |
114 | let mut display_implied_bounds = Set::new(); |
115 | let display_body = if input.attrs.transparent.is_some() { |
116 | let only_field = &input.fields[0].member; |
117 | display_implied_bounds.insert((0, Trait::Display)); |
118 | Some(quote! { |
119 | std::fmt::Display::fmt(&self.#only_field, __formatter) |
120 | }) |
121 | } else if let Some(display) = &input.attrs.display { |
122 | display_implied_bounds = display.implied_bounds.clone(); |
123 | let use_as_display = if display.has_bonus_display { |
124 | Some(quote! { |
125 | #[allow(unused_imports)] |
126 | use thiserror::__private::{DisplayAsDisplay, PathAsDisplay}; |
127 | }) |
128 | } else { |
129 | None |
130 | }; |
131 | let pat = fields_pat(&input.fields); |
132 | Some(quote! { |
133 | #use_as_display |
134 | #[allow(unused_variables, deprecated)] |
135 | let Self #pat = self; |
136 | #display |
137 | }) |
138 | } else { |
139 | None |
140 | }; |
141 | let display_impl = display_body.map(|body| { |
142 | let mut display_inferred_bounds = InferredBounds::new(); |
143 | for (field, bound) in display_implied_bounds { |
144 | let field = &input.fields[field]; |
145 | if field.contains_generic { |
146 | display_inferred_bounds.insert(field.ty, bound); |
147 | } |
148 | } |
149 | let display_where_clause = display_inferred_bounds.augment_where_clause(input.generics); |
150 | quote! { |
151 | #[allow(unused_qualifications)] |
152 | impl #impl_generics std::fmt::Display for #ty #ty_generics #display_where_clause { |
153 | #[allow(clippy::used_underscore_binding)] |
154 | fn fmt(&self, __formatter: &mut std::fmt::Formatter) -> std::fmt::Result { |
155 | #body |
156 | } |
157 | } |
158 | } |
159 | }); |
160 | |
161 | let from_impl = input.from_field().map(|from_field| { |
162 | let backtrace_field = input.distinct_backtrace_field(); |
163 | let from = unoptional_type(from_field.ty); |
164 | let body = from_initializer(from_field, backtrace_field); |
165 | quote! { |
166 | #[allow(unused_qualifications)] |
167 | impl #impl_generics std::convert::From<#from> for #ty #ty_generics #where_clause { |
168 | #[allow(deprecated)] |
169 | fn from(source: #from) -> Self { |
170 | #ty #body |
171 | } |
172 | } |
173 | } |
174 | }); |
175 | |
176 | let error_trait = spanned_error_trait(input.original); |
177 | if input.generics.type_params().next().is_some() { |
178 | let self_token = <Token![Self]>::default(); |
179 | error_inferred_bounds.insert(self_token, Trait::Debug); |
180 | error_inferred_bounds.insert(self_token, Trait::Display); |
181 | } |
182 | let error_where_clause = error_inferred_bounds.augment_where_clause(input.generics); |
183 | |
184 | quote! { |
185 | #[allow(unused_qualifications)] |
186 | impl #impl_generics #error_trait for #ty #ty_generics #error_where_clause { |
187 | #source_method |
188 | #provide_method |
189 | } |
190 | #display_impl |
191 | #from_impl |
192 | } |
193 | } |
194 | |
195 | fn impl_enum(input: Enum) -> TokenStream { |
196 | let ty = &input.ident; |
197 | let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl(); |
198 | let mut error_inferred_bounds = InferredBounds::new(); |
199 | |
200 | let source_method = if input.has_source() { |
201 | let arms = input.variants.iter().map(|variant| { |
202 | let ident = &variant.ident; |
203 | if variant.attrs.transparent.is_some() { |
204 | let only_field = &variant.fields[0]; |
205 | if only_field.contains_generic { |
206 | error_inferred_bounds.insert(only_field.ty, quote!(std::error::Error)); |
207 | } |
208 | let member = &only_field.member; |
209 | let source = quote!(std::error::Error::source(transparent.as_dyn_error())); |
210 | quote! { |
211 | #ty::#ident {#member: transparent} => #source, |
212 | } |
213 | } else if let Some(source_field) = variant.source_field() { |
214 | let source = &source_field.member; |
215 | if source_field.contains_generic { |
216 | let ty = unoptional_type(source_field.ty); |
217 | error_inferred_bounds.insert(ty, quote!(std::error::Error + 'static)); |
218 | } |
219 | let asref = if type_is_option(source_field.ty) { |
220 | Some(quote_spanned!(source.span()=> .as_ref()?)) |
221 | } else { |
222 | None |
223 | }; |
224 | let varsource = quote!(source); |
225 | let dyn_error = quote_spanned!(source.span()=> #varsource #asref.as_dyn_error()); |
226 | quote! { |
227 | #ty::#ident {#source: #varsource, ..} => std::option::Option::Some(#dyn_error), |
228 | } |
229 | } else { |
230 | quote! { |
231 | #ty::#ident {..} => std::option::Option::None, |
232 | } |
233 | } |
234 | }); |
235 | Some(quote! { |
236 | fn source(&self) -> std::option::Option<&(dyn std::error::Error + 'static)> { |
237 | use thiserror::__private::AsDynError; |
238 | #[allow(deprecated)] |
239 | match self { |
240 | #(#arms)* |
241 | } |
242 | } |
243 | }) |
244 | } else { |
245 | None |
246 | }; |
247 | |
248 | let provide_method = if input.has_backtrace() { |
249 | let request = quote!(request); |
250 | let arms = input.variants.iter().map(|variant| { |
251 | let ident = &variant.ident; |
252 | match (variant.backtrace_field(), variant.source_field()) { |
253 | (Some(backtrace_field), Some(source_field)) |
254 | if backtrace_field.attrs.backtrace.is_none() => |
255 | { |
256 | let backtrace = &backtrace_field.member; |
257 | let source = &source_field.member; |
258 | let varsource = quote!(source); |
259 | let source_provide = if type_is_option(source_field.ty) { |
260 | quote_spanned! {source.span()=> |
261 | if let std::option::Option::Some(source) = #varsource { |
262 | source.thiserror_provide(#request); |
263 | } |
264 | } |
265 | } else { |
266 | quote_spanned! {source.span()=> |
267 | #varsource.thiserror_provide(#request); |
268 | } |
269 | }; |
270 | let self_provide = if type_is_option(backtrace_field.ty) { |
271 | quote! { |
272 | if let std::option::Option::Some(backtrace) = backtrace { |
273 | #request.provide_ref::<std::backtrace::Backtrace>(backtrace); |
274 | } |
275 | } |
276 | } else { |
277 | quote! { |
278 | #request.provide_ref::<std::backtrace::Backtrace>(backtrace); |
279 | } |
280 | }; |
281 | quote! { |
282 | #ty::#ident { |
283 | #backtrace: backtrace, |
284 | #source: #varsource, |
285 | .. |
286 | } => { |
287 | use thiserror::__private::ThiserrorProvide; |
288 | #source_provide |
289 | #self_provide |
290 | } |
291 | } |
292 | } |
293 | (Some(backtrace_field), Some(source_field)) |
294 | if backtrace_field.member == source_field.member => |
295 | { |
296 | let backtrace = &backtrace_field.member; |
297 | let varsource = quote!(source); |
298 | let source_provide = if type_is_option(source_field.ty) { |
299 | quote_spanned! {backtrace.span()=> |
300 | if let std::option::Option::Some(source) = #varsource { |
301 | source.thiserror_provide(#request); |
302 | } |
303 | } |
304 | } else { |
305 | quote_spanned! {backtrace.span()=> |
306 | #varsource.thiserror_provide(#request); |
307 | } |
308 | }; |
309 | quote! { |
310 | #ty::#ident {#backtrace: #varsource, ..} => { |
311 | use thiserror::__private::ThiserrorProvide; |
312 | #source_provide |
313 | } |
314 | } |
315 | } |
316 | (Some(backtrace_field), _) => { |
317 | let backtrace = &backtrace_field.member; |
318 | let body = if type_is_option(backtrace_field.ty) { |
319 | quote! { |
320 | if let std::option::Option::Some(backtrace) = backtrace { |
321 | #request.provide_ref::<std::backtrace::Backtrace>(backtrace); |
322 | } |
323 | } |
324 | } else { |
325 | quote! { |
326 | #request.provide_ref::<std::backtrace::Backtrace>(backtrace); |
327 | } |
328 | }; |
329 | quote! { |
330 | #ty::#ident {#backtrace: backtrace, ..} => { |
331 | #body |
332 | } |
333 | } |
334 | } |
335 | (None, _) => quote! { |
336 | #ty::#ident {..} => {} |
337 | }, |
338 | } |
339 | }); |
340 | Some(quote! { |
341 | fn provide<'_request>(&'_request self, #request: &mut std::error::Request<'_request>) { |
342 | #[allow(deprecated)] |
343 | match self { |
344 | #(#arms)* |
345 | } |
346 | } |
347 | }) |
348 | } else { |
349 | None |
350 | }; |
351 | |
352 | let display_impl = if input.has_display() { |
353 | let mut display_inferred_bounds = InferredBounds::new(); |
354 | let use_as_display = if input.variants.iter().any(|v| { |
355 | v.attrs |
356 | .display |
357 | .as_ref() |
358 | .map_or(false, |display| display.has_bonus_display) |
359 | }) { |
360 | Some(quote! { |
361 | #[allow(unused_imports)] |
362 | use thiserror::__private::{DisplayAsDisplay, PathAsDisplay}; |
363 | }) |
364 | } else { |
365 | None |
366 | }; |
367 | let void_deref = if input.variants.is_empty() { |
368 | Some(quote!(*)) |
369 | } else { |
370 | None |
371 | }; |
372 | let arms = input.variants.iter().map(|variant| { |
373 | let mut display_implied_bounds = Set::new(); |
374 | let display = match &variant.attrs.display { |
375 | Some(display) => { |
376 | display_implied_bounds = display.implied_bounds.clone(); |
377 | display.to_token_stream() |
378 | } |
379 | None => { |
380 | let only_field = match &variant.fields[0].member { |
381 | Member::Named(ident) => ident.clone(), |
382 | Member::Unnamed(index) => format_ident!("_ {}" , index), |
383 | }; |
384 | display_implied_bounds.insert((0, Trait::Display)); |
385 | quote!(std::fmt::Display::fmt(#only_field, __formatter)) |
386 | } |
387 | }; |
388 | for (field, bound) in display_implied_bounds { |
389 | let field = &variant.fields[field]; |
390 | if field.contains_generic { |
391 | display_inferred_bounds.insert(field.ty, bound); |
392 | } |
393 | } |
394 | let ident = &variant.ident; |
395 | let pat = fields_pat(&variant.fields); |
396 | quote! { |
397 | #ty::#ident #pat => #display |
398 | } |
399 | }); |
400 | let arms = arms.collect::<Vec<_>>(); |
401 | let display_where_clause = display_inferred_bounds.augment_where_clause(input.generics); |
402 | Some(quote! { |
403 | #[allow(unused_qualifications)] |
404 | impl #impl_generics std::fmt::Display for #ty #ty_generics #display_where_clause { |
405 | fn fmt(&self, __formatter: &mut std::fmt::Formatter) -> std::fmt::Result { |
406 | #use_as_display |
407 | #[allow(unused_variables, deprecated, clippy::used_underscore_binding)] |
408 | match #void_deref self { |
409 | #(#arms,)* |
410 | } |
411 | } |
412 | } |
413 | }) |
414 | } else { |
415 | None |
416 | }; |
417 | |
418 | let from_impls = input.variants.iter().filter_map(|variant| { |
419 | let from_field = variant.from_field()?; |
420 | let backtrace_field = variant.distinct_backtrace_field(); |
421 | let variant = &variant.ident; |
422 | let from = unoptional_type(from_field.ty); |
423 | let body = from_initializer(from_field, backtrace_field); |
424 | Some(quote! { |
425 | #[allow(unused_qualifications)] |
426 | impl #impl_generics std::convert::From<#from> for #ty #ty_generics #where_clause { |
427 | #[allow(deprecated)] |
428 | fn from(source: #from) -> Self { |
429 | #ty::#variant #body |
430 | } |
431 | } |
432 | }) |
433 | }); |
434 | |
435 | let error_trait = spanned_error_trait(input.original); |
436 | if input.generics.type_params().next().is_some() { |
437 | let self_token = <Token![Self]>::default(); |
438 | error_inferred_bounds.insert(self_token, Trait::Debug); |
439 | error_inferred_bounds.insert(self_token, Trait::Display); |
440 | } |
441 | let error_where_clause = error_inferred_bounds.augment_where_clause(input.generics); |
442 | |
443 | quote! { |
444 | #[allow(unused_qualifications)] |
445 | impl #impl_generics #error_trait for #ty #ty_generics #error_where_clause { |
446 | #source_method |
447 | #provide_method |
448 | } |
449 | #display_impl |
450 | #(#from_impls)* |
451 | } |
452 | } |
453 | |
454 | fn fields_pat(fields: &[Field]) -> TokenStream { |
455 | let mut members: impl Iterator = fields.iter().map(|field: &Field<'_>| &field.member).peekable(); |
456 | match members.peek() { |
457 | Some(Member::Named(_)) => quote!({ #(#members),* }), |
458 | Some(Member::Unnamed(_)) => { |
459 | let vars: impl Iterator = members.map(|member: &Member| match member { |
460 | Member::Unnamed(member: &Index) => format_ident!("_ {}" , member), |
461 | Member::Named(_) => unreachable!(), |
462 | }); |
463 | quote!((#(#vars),*)) |
464 | } |
465 | None => quote!({}), |
466 | } |
467 | } |
468 | |
469 | fn from_initializer(from_field: &Field, backtrace_field: Option<&Field>) -> TokenStream { |
470 | let from_member: &Member = &from_field.member; |
471 | let some_source: TokenStream = if type_is_option(from_field.ty) { |
472 | quote!(std::option::Option::Some(source)) |
473 | } else { |
474 | quote!(source) |
475 | }; |
476 | let backtrace: Option = backtrace_field.map(|backtrace_field: &Field<'_>| { |
477 | let backtrace_member: &Member = &backtrace_field.member; |
478 | if type_is_option(backtrace_field.ty) { |
479 | quote! { |
480 | #backtrace_member: std::option::Option::Some(std::backtrace::Backtrace::capture()), |
481 | } |
482 | } else { |
483 | quote! { |
484 | #backtrace_member: std::convert::From::from(std::backtrace::Backtrace::capture()), |
485 | } |
486 | } |
487 | }); |
488 | quote!({ |
489 | #from_member: #some_source, |
490 | #backtrace |
491 | }) |
492 | } |
493 | |
494 | fn type_is_option(ty: &Type) -> bool { |
495 | type_parameter_of_option(ty).is_some() |
496 | } |
497 | |
498 | fn unoptional_type(ty: &Type) -> TokenStream { |
499 | let unoptional: &Type = type_parameter_of_option(ty).unwrap_or(default:ty); |
500 | quote!(#unoptional) |
501 | } |
502 | |
503 | fn type_parameter_of_option(ty: &Type) -> Option<&Type> { |
504 | let path = match ty { |
505 | Type::Path(ty) => &ty.path, |
506 | _ => return None, |
507 | }; |
508 | |
509 | let last = path.segments.last().unwrap(); |
510 | if last.ident != "Option" { |
511 | return None; |
512 | } |
513 | |
514 | let bracketed = match &last.arguments { |
515 | PathArguments::AngleBracketed(bracketed) => bracketed, |
516 | _ => return None, |
517 | }; |
518 | |
519 | if bracketed.args.len() != 1 { |
520 | return None; |
521 | } |
522 | |
523 | match &bracketed.args[0] { |
524 | GenericArgument::Type(arg) => Some(arg), |
525 | _ => None, |
526 | } |
527 | } |
528 | |
529 | fn spanned_error_trait(input: &DeriveInput) -> TokenStream { |
530 | let vis_span: Option = match &input.vis { |
531 | Visibility::Public(vis: &Pub) => Some(vis.span), |
532 | Visibility::Restricted(vis: &VisRestricted) => Some(vis.pub_token.span), |
533 | Visibility::Inherited => None, |
534 | }; |
535 | let data_span: Span = match &input.data { |
536 | Data::Struct(data: &DataStruct) => data.struct_token.span, |
537 | Data::Enum(data: &DataEnum) => data.enum_token.span, |
538 | Data::Union(data: &DataUnion) => data.union_token.span, |
539 | }; |
540 | let first_span: Span = vis_span.unwrap_or(default:data_span); |
541 | let last_span: Span = input.ident.span(); |
542 | let path: TokenStream = quote_spanned!(first_span=> std::error::); |
543 | let error: TokenStream = quote_spanned!(last_span=> Error); |
544 | quote!(#path #error) |
545 | } |
546 | |