| 1 | use proc_macro2::TokenStream; |
| 2 | use quote::quote; |
| 3 | use syn::{spanned::Spanned as _, Error, Result}; |
| 4 | |
| 5 | use crate::utils::{ |
| 6 | self, AttrParams, DeriveType, FullMetaInfo, HashSet, MetaInfo, MultiFieldData, |
| 7 | State, |
| 8 | }; |
| 9 | |
| 10 | pub fn expand( |
| 11 | input: &syn::DeriveInput, |
| 12 | trait_name: &'static str, |
| 13 | ) -> Result<TokenStream> { |
| 14 | let syn::DeriveInput { |
| 15 | ident, generics, .. |
| 16 | } = input; |
| 17 | |
| 18 | let state = State::with_attr_params( |
| 19 | input, |
| 20 | trait_name, |
| 21 | trait_name.to_lowercase(), |
| 22 | allowed_attr_params(), |
| 23 | )?; |
| 24 | |
| 25 | let type_params: HashSet<_> = generics |
| 26 | .params |
| 27 | .iter() |
| 28 | .filter_map(|generic| match generic { |
| 29 | syn::GenericParam::Type(ty) => Some(ty.ident.clone()), |
| 30 | _ => None, |
| 31 | }) |
| 32 | .collect(); |
| 33 | |
| 34 | let (bounds, source, provide) = match state.derive_type { |
| 35 | DeriveType::Named | DeriveType::Unnamed => render_struct(&type_params, &state)?, |
| 36 | DeriveType::Enum => render_enum(&type_params, &state)?, |
| 37 | }; |
| 38 | |
| 39 | let source = source.map(|source| { |
| 40 | // Not using `#[inline]` here on purpose, since this is almost never part |
| 41 | // of a hot codepath. |
| 42 | quote! { |
| 43 | // TODO: Use `derive_more::core::error::Error` once `error_in_core` Rust feature is |
| 44 | // stabilized. |
| 45 | fn source(&self) -> Option<&(dyn derive_more::with_trait::Error + 'static)> { |
| 46 | use derive_more::__private::AsDynError; |
| 47 | #source |
| 48 | } |
| 49 | } |
| 50 | }); |
| 51 | |
| 52 | let provide = provide.map(|provide| { |
| 53 | // Not using `#[inline]` here on purpose, since this is almost never part |
| 54 | // of a hot codepath. |
| 55 | quote! { |
| 56 | fn provide<'_request>( |
| 57 | &'_request self, |
| 58 | request: &mut derive_more::core::error::Request<'_request>, |
| 59 | ) { |
| 60 | #provide |
| 61 | } |
| 62 | } |
| 63 | }); |
| 64 | |
| 65 | let mut generics = generics.clone(); |
| 66 | |
| 67 | if !type_params.is_empty() { |
| 68 | let (_, ty_generics, _) = generics.split_for_impl(); |
| 69 | generics = utils::add_extra_where_clauses( |
| 70 | &generics, |
| 71 | quote! { |
| 72 | where |
| 73 | #ident #ty_generics: derive_more::core::fmt::Debug |
| 74 | + derive_more::core::fmt::Display |
| 75 | }, |
| 76 | ); |
| 77 | } |
| 78 | |
| 79 | if !bounds.is_empty() { |
| 80 | let bounds = bounds.iter(); |
| 81 | generics = utils::add_extra_where_clauses( |
| 82 | &generics, |
| 83 | quote! { |
| 84 | where #( |
| 85 | #bounds: derive_more::core::fmt::Debug |
| 86 | + derive_more::core::fmt::Display |
| 87 | // TODO: Use `derive_more::core::error::Error` once `error_in_core` |
| 88 | // Rust feature is stabilized. |
| 89 | + derive_more::with_trait::Error |
| 90 | + 'static |
| 91 | ),* |
| 92 | }, |
| 93 | ); |
| 94 | } |
| 95 | |
| 96 | let (impl_generics, ty_generics, where_clause) = generics.split_for_impl(); |
| 97 | |
| 98 | let render = quote! { |
| 99 | #[automatically_derived] |
| 100 | // TODO: Use `derive_more::core::error::Error` once `error_in_core` Rust feature is |
| 101 | // stabilized. |
| 102 | impl #impl_generics derive_more::with_trait::Error for #ident #ty_generics #where_clause { |
| 103 | #source |
| 104 | #provide |
| 105 | } |
| 106 | }; |
| 107 | |
| 108 | Ok(render) |
| 109 | } |
| 110 | |
| 111 | fn render_struct( |
| 112 | type_params: &HashSet<syn::Ident>, |
| 113 | state: &State, |
| 114 | ) -> Result<(HashSet<syn::Type>, Option<TokenStream>, Option<TokenStream>)> { |
| 115 | let parsed_fields: ParsedFields<'_, '_> = parse_fields(type_params, state)?; |
| 116 | |
| 117 | let source: Option = parsed_fields.render_source_as_struct(); |
| 118 | let provide: Option = parsed_fields.render_provide_as_struct(); |
| 119 | |
| 120 | Ok((parsed_fields.bounds, source, provide)) |
| 121 | } |
| 122 | |
| 123 | fn render_enum( |
| 124 | type_params: &HashSet<syn::Ident>, |
| 125 | state: &State, |
| 126 | ) -> Result<(HashSet<syn::Type>, Option<TokenStream>, Option<TokenStream>)> { |
| 127 | let mut bounds = HashSet::default(); |
| 128 | let mut source_match_arms = Vec::new(); |
| 129 | let mut provide_match_arms = Vec::new(); |
| 130 | |
| 131 | for variant in state.enabled_variant_data().variants { |
| 132 | let default_info = FullMetaInfo { |
| 133 | enabled: true, |
| 134 | ..FullMetaInfo::default() |
| 135 | }; |
| 136 | |
| 137 | let state = State::from_variant( |
| 138 | state.input, |
| 139 | state.trait_name, |
| 140 | state.trait_attr.clone(), |
| 141 | allowed_attr_params(), |
| 142 | variant, |
| 143 | default_info, |
| 144 | )?; |
| 145 | |
| 146 | let parsed_fields = parse_fields(type_params, &state)?; |
| 147 | |
| 148 | if let Some(expr) = parsed_fields.render_source_as_enum_variant_match_arm() { |
| 149 | source_match_arms.push(expr); |
| 150 | } |
| 151 | |
| 152 | if let Some(expr) = parsed_fields.render_provide_as_enum_variant_match_arm() { |
| 153 | provide_match_arms.push(expr); |
| 154 | } |
| 155 | |
| 156 | bounds.extend(parsed_fields.bounds.into_iter()); |
| 157 | } |
| 158 | |
| 159 | let render = |match_arms: &mut Vec<TokenStream>, unmatched| { |
| 160 | if !match_arms.is_empty() && match_arms.len() < state.variants.len() { |
| 161 | match_arms.push(quote! { _ => #unmatched }); |
| 162 | } |
| 163 | |
| 164 | (!match_arms.is_empty()).then(|| { |
| 165 | quote! { |
| 166 | match self { |
| 167 | #(#match_arms),* |
| 168 | } |
| 169 | } |
| 170 | }) |
| 171 | }; |
| 172 | |
| 173 | let source = render(&mut source_match_arms, quote! { None }); |
| 174 | let provide = render(&mut provide_match_arms, quote! { () }); |
| 175 | |
| 176 | Ok((bounds, source, provide)) |
| 177 | } |
| 178 | |
| 179 | fn allowed_attr_params() -> AttrParams { |
| 180 | AttrParams { |
| 181 | enum_: vec!["ignore" ], |
| 182 | struct_: vec!["ignore" ], |
| 183 | variant: vec!["ignore" ], |
| 184 | field: vec!["ignore" , "source" , "backtrace" ], |
| 185 | } |
| 186 | } |
| 187 | |
| 188 | struct ParsedFields<'input, 'state> { |
| 189 | data: MultiFieldData<'input, 'state>, |
| 190 | source: Option<usize>, |
| 191 | backtrace: Option<usize>, |
| 192 | bounds: HashSet<syn::Type>, |
| 193 | } |
| 194 | |
| 195 | impl<'input, 'state> ParsedFields<'input, 'state> { |
| 196 | fn new(data: MultiFieldData<'input, 'state>) -> Self { |
| 197 | Self { |
| 198 | data, |
| 199 | source: None, |
| 200 | backtrace: None, |
| 201 | bounds: HashSet::default(), |
| 202 | } |
| 203 | } |
| 204 | } |
| 205 | |
| 206 | impl ParsedFields<'_, '_> { |
| 207 | fn render_source_as_struct(&self) -> Option<TokenStream> { |
| 208 | let source = self.source?; |
| 209 | let ident = &self.data.members[source]; |
| 210 | Some(render_some(quote! { #ident })) |
| 211 | } |
| 212 | |
| 213 | fn render_source_as_enum_variant_match_arm(&self) -> Option<TokenStream> { |
| 214 | let source = self.source?; |
| 215 | let pattern = self.data.matcher(&[source], &[quote! { source }]); |
| 216 | let expr = render_some(quote! { source }); |
| 217 | Some(quote! { #pattern => #expr }) |
| 218 | } |
| 219 | |
| 220 | fn render_provide_as_struct(&self) -> Option<TokenStream> { |
| 221 | let backtrace = self.backtrace?; |
| 222 | |
| 223 | let source_provider = self.source.map(|source| { |
| 224 | let source_expr = &self.data.members[source]; |
| 225 | quote! { |
| 226 | // TODO: Use `derive_more::core::error::Error` once `error_in_core` Rust feature is |
| 227 | // stabilized. |
| 228 | derive_more::with_trait::Error::provide(&#source_expr, request); |
| 229 | } |
| 230 | }); |
| 231 | let backtrace_provider = self |
| 232 | .source |
| 233 | .filter(|source| *source == backtrace) |
| 234 | .is_none() |
| 235 | .then(|| { |
| 236 | let backtrace_expr = &self.data.members[backtrace]; |
| 237 | quote! { |
| 238 | request.provide_ref::<::std::backtrace::Backtrace>(&#backtrace_expr); |
| 239 | } |
| 240 | }); |
| 241 | |
| 242 | (source_provider.is_some() || backtrace_provider.is_some()).then(|| { |
| 243 | quote! { |
| 244 | #backtrace_provider |
| 245 | #source_provider |
| 246 | } |
| 247 | }) |
| 248 | } |
| 249 | |
| 250 | fn render_provide_as_enum_variant_match_arm(&self) -> Option<TokenStream> { |
| 251 | let backtrace = self.backtrace?; |
| 252 | |
| 253 | match self.source { |
| 254 | Some(source) if source == backtrace => { |
| 255 | let pattern = self.data.matcher(&[source], &[quote! { source }]); |
| 256 | Some(quote! { |
| 257 | #pattern => { |
| 258 | // TODO: Use `derive_more::core::error::Error` once `error_in_core` Rust |
| 259 | // feature is stabilized. |
| 260 | derive_more::with_trait::Error::provide(source, request); |
| 261 | } |
| 262 | }) |
| 263 | } |
| 264 | Some(source) => { |
| 265 | let pattern = self.data.matcher( |
| 266 | &[source, backtrace], |
| 267 | &[quote! { source }, quote! { backtrace }], |
| 268 | ); |
| 269 | Some(quote! { |
| 270 | #pattern => { |
| 271 | request.provide_ref::<::std::backtrace::Backtrace>(backtrace); |
| 272 | // TODO: Use `derive_more::core::error::Error` once `error_in_core` Rust |
| 273 | // feature is stabilized. |
| 274 | derive_more::with_trait::Error::provide(source, request); |
| 275 | } |
| 276 | }) |
| 277 | } |
| 278 | None => { |
| 279 | let pattern = self.data.matcher(&[backtrace], &[quote! { backtrace }]); |
| 280 | Some(quote! { |
| 281 | #pattern => { |
| 282 | request.provide_ref::<::std::backtrace::Backtrace>(backtrace); |
| 283 | } |
| 284 | }) |
| 285 | } |
| 286 | } |
| 287 | } |
| 288 | } |
| 289 | |
| 290 | fn render_some<T>(expr: T) -> TokenStream |
| 291 | where |
| 292 | T: quote::ToTokens, |
| 293 | { |
| 294 | quote! { Some(#expr.as_dyn_error()) } |
| 295 | } |
| 296 | |
| 297 | fn parse_fields<'input, 'state>( |
| 298 | type_params: &HashSet<syn::Ident>, |
| 299 | state: &'state State<'input>, |
| 300 | ) -> Result<ParsedFields<'input, 'state>> { |
| 301 | let mut parsed_fields = match state.derive_type { |
| 302 | DeriveType::Named => { |
| 303 | parse_fields_impl(state, |attr, field, _| { |
| 304 | // Unwrapping is safe, cause fields in named struct |
| 305 | // always have an ident |
| 306 | let ident = field.ident.as_ref().unwrap(); |
| 307 | |
| 308 | match attr { |
| 309 | "source" => ident == "source" , |
| 310 | "backtrace" => { |
| 311 | ident == "backtrace" |
| 312 | || is_type_path_ends_with_segment(&field.ty, "Backtrace" ) |
| 313 | } |
| 314 | _ => unreachable!(), |
| 315 | } |
| 316 | }) |
| 317 | } |
| 318 | |
| 319 | DeriveType::Unnamed => { |
| 320 | let mut parsed_fields = |
| 321 | parse_fields_impl(state, |attr, field, len| match attr { |
| 322 | "source" => { |
| 323 | len == 1 |
| 324 | && !is_type_path_ends_with_segment(&field.ty, "Backtrace" ) |
| 325 | } |
| 326 | "backtrace" => { |
| 327 | is_type_path_ends_with_segment(&field.ty, "Backtrace" ) |
| 328 | } |
| 329 | _ => unreachable!(), |
| 330 | })?; |
| 331 | |
| 332 | parsed_fields.source = parsed_fields |
| 333 | .source |
| 334 | .or_else(|| infer_source_field(&state.fields, &parsed_fields)); |
| 335 | |
| 336 | Ok(parsed_fields) |
| 337 | } |
| 338 | |
| 339 | _ => unreachable!(), |
| 340 | }?; |
| 341 | |
| 342 | if let Some(source) = parsed_fields.source { |
| 343 | add_bound_if_type_parameter_used_in_type( |
| 344 | &mut parsed_fields.bounds, |
| 345 | type_params, |
| 346 | &state.fields[source].ty, |
| 347 | ); |
| 348 | } |
| 349 | |
| 350 | Ok(parsed_fields) |
| 351 | } |
| 352 | |
| 353 | /// Checks if `ty` is [`syn::Type::Path`] and ends with segment matching `tail` |
| 354 | /// and doesn't contain any generic parameters. |
| 355 | fn is_type_path_ends_with_segment(ty: &syn::Type, tail: &str) -> bool { |
| 356 | let syn::Type::Path(ty: &TypePath) = ty else { |
| 357 | return false; |
| 358 | }; |
| 359 | |
| 360 | // Unwrapping is safe, cause 'syn::TypePath.path.segments' |
| 361 | // have to have at least one segment |
| 362 | let segment: &PathSegment = ty.path.segments.last().unwrap(); |
| 363 | |
| 364 | if !matches!(segment.arguments, syn::PathArguments::None) { |
| 365 | return false; |
| 366 | } |
| 367 | |
| 368 | segment.ident == tail |
| 369 | } |
| 370 | |
| 371 | fn infer_source_field( |
| 372 | fields: &[&syn::Field], |
| 373 | parsed_fields: &ParsedFields, |
| 374 | ) -> Option<usize> { |
| 375 | // if we have exactly two fields |
| 376 | if fields.len() != 2 { |
| 377 | return None; |
| 378 | } |
| 379 | |
| 380 | // no source field was specified/inferred |
| 381 | if parsed_fields.source.is_some() { |
| 382 | return None; |
| 383 | } |
| 384 | |
| 385 | // but one of the fields was specified/inferred as backtrace field |
| 386 | if let Some(backtrace: usize) = parsed_fields.backtrace { |
| 387 | // then infer *other field* as source field |
| 388 | let source: usize = (backtrace + 1) % 2; |
| 389 | // unless it was explicitly marked as non-source |
| 390 | if parsed_fields.data.infos[source].info.source != Some(false) { |
| 391 | return Some(source); |
| 392 | } |
| 393 | } |
| 394 | |
| 395 | None |
| 396 | } |
| 397 | |
| 398 | fn parse_fields_impl<'input, 'state, P>( |
| 399 | state: &'state State<'input>, |
| 400 | is_valid_default_field_for_attr: P, |
| 401 | ) -> Result<ParsedFields<'input, 'state>> |
| 402 | where |
| 403 | P: Fn(&str, &syn::Field, usize) -> bool, |
| 404 | { |
| 405 | let MultiFieldData { fields, infos, .. } = state.enabled_fields_data(); |
| 406 | |
| 407 | let iter = fields |
| 408 | .iter() |
| 409 | .zip(infos.iter().map(|info| &info.info)) |
| 410 | .enumerate() |
| 411 | .map(|(index, (field, info))| (index, *field, info)); |
| 412 | |
| 413 | let source = parse_field_impl( |
| 414 | &is_valid_default_field_for_attr, |
| 415 | state.fields.len(), |
| 416 | iter.clone(), |
| 417 | "source" , |
| 418 | |info| info.source, |
| 419 | )?; |
| 420 | |
| 421 | let backtrace = parse_field_impl( |
| 422 | &is_valid_default_field_for_attr, |
| 423 | state.fields.len(), |
| 424 | iter.clone(), |
| 425 | "backtrace" , |
| 426 | |info| info.backtrace, |
| 427 | )?; |
| 428 | |
| 429 | let mut parsed_fields = ParsedFields::new(state.enabled_fields_data()); |
| 430 | |
| 431 | if let Some((index, _, _)) = source { |
| 432 | parsed_fields.source = Some(index); |
| 433 | } |
| 434 | |
| 435 | if let Some((index, _, _)) = backtrace { |
| 436 | parsed_fields.backtrace = Some(index); |
| 437 | } |
| 438 | |
| 439 | Ok(parsed_fields) |
| 440 | } |
| 441 | |
| 442 | fn parse_field_impl<'a, P, V>( |
| 443 | is_valid_default_field_for_attr: &P, |
| 444 | len: usize, |
| 445 | iter: impl Iterator<Item = (usize, &'a syn::Field, &'a MetaInfo)> + Clone, |
| 446 | attr: &str, |
| 447 | value: V, |
| 448 | ) -> Result<Option<(usize, &'a syn::Field, &'a MetaInfo)>> |
| 449 | where |
| 450 | P: Fn(&str, &syn::Field, usize) -> bool, |
| 451 | V: Fn(&MetaInfo) -> Option<bool>, |
| 452 | { |
| 453 | let explicit_fields = iter |
| 454 | .clone() |
| 455 | .filter(|(_, _, info)| matches!(value(info), Some(true))); |
| 456 | |
| 457 | let inferred_fields = iter.filter(|(_, field, info)| match value(info) { |
| 458 | None => is_valid_default_field_for_attr(attr, field, len), |
| 459 | _ => false, |
| 460 | }); |
| 461 | |
| 462 | let field = assert_iter_contains_zero_or_one_item( |
| 463 | explicit_fields, |
| 464 | &format!( |
| 465 | "Multiple ` {attr}` attributes specified. \ |
| 466 | Single attribute per struct/enum variant allowed." , |
| 467 | ), |
| 468 | )?; |
| 469 | |
| 470 | let field = match field { |
| 471 | field @ Some(_) => field, |
| 472 | None => assert_iter_contains_zero_or_one_item( |
| 473 | inferred_fields, |
| 474 | "Conflicting fields found. Consider specifying some \ |
| 475 | `#[error(...)]` attributes to resolve conflict." , |
| 476 | )?, |
| 477 | }; |
| 478 | |
| 479 | Ok(field) |
| 480 | } |
| 481 | |
| 482 | fn assert_iter_contains_zero_or_one_item<'a>( |
| 483 | mut iter: impl Iterator<Item = (usize, &'a syn::Field, &'a MetaInfo)>, |
| 484 | error_msg: &str, |
| 485 | ) -> Result<Option<(usize, &'a syn::Field, &'a MetaInfo)>> { |
| 486 | let Some(item: (usize, &'a Field, &'a MetaInfo)) = iter.next() else { |
| 487 | return Ok(None); |
| 488 | }; |
| 489 | |
| 490 | if let Some((_, field: &'a Field, _)) = iter.next() { |
| 491 | return Err(Error::new(field.span(), message:error_msg)); |
| 492 | } |
| 493 | |
| 494 | Ok(Some(item)) |
| 495 | } |
| 496 | |
| 497 | fn add_bound_if_type_parameter_used_in_type( |
| 498 | bounds: &mut HashSet<syn::Type>, |
| 499 | type_params: &HashSet<syn::Ident>, |
| 500 | ty: &syn::Type, |
| 501 | ) { |
| 502 | if let Some(ty: Type) = utils::get_if_type_parameter_used_in_type(type_parameters:type_params, ty) { |
| 503 | bounds.insert(ty); |
| 504 | } |
| 505 | } |
| 506 | |