| 1 | // SPDX-License-Identifier: Apache-2.0 OR MIT |
| 2 | |
| 3 | use core::mem; |
| 4 | use std::borrow::Cow; |
| 5 | |
| 6 | use proc_macro2::{TokenStream, TokenTree}; |
| 7 | use quote::{quote, ToTokens as _}; |
| 8 | use syn::{ |
| 9 | parse_quote, token, Block, FnArg, GenericParam, Generics, Ident, ImplItem, ImplItemFn, |
| 10 | ItemImpl, ItemTrait, Path, Signature, Stmt, Token, TraitItem, TraitItemFn, TraitItemType, Type, |
| 11 | TypeParamBound, TypePath, Visibility, WherePredicate, |
| 12 | }; |
| 13 | |
| 14 | use crate::ast::EnumData; |
| 15 | |
| 16 | /// A function for creating `proc_macro_derive` like deriving trait to enum so |
| 17 | /// long as all variants are implemented that trait. |
| 18 | /// |
| 19 | /// # Examples |
| 20 | /// |
| 21 | /// ``` |
| 22 | /// # extern crate proc_macro; |
| 23 | /// use derive_utils::derive_trait; |
| 24 | /// use proc_macro::TokenStream; |
| 25 | /// use quote::format_ident; |
| 26 | /// use syn::{parse_macro_input, parse_quote}; |
| 27 | /// |
| 28 | /// # #[cfg (any())] |
| 29 | /// #[proc_macro_derive(Iterator)] |
| 30 | /// # pub fn _derive_iterator(_: TokenStream) -> TokenStream { unimplemented!() } |
| 31 | /// pub fn derive_iterator(input: TokenStream) -> TokenStream { |
| 32 | /// derive_trait( |
| 33 | /// &parse_macro_input!(input), |
| 34 | /// // trait path |
| 35 | /// &parse_quote!(std::iter::Iterator), |
| 36 | /// // super trait's associated types |
| 37 | /// None, |
| 38 | /// // trait definition |
| 39 | /// parse_quote! { |
| 40 | /// trait Iterator { |
| 41 | /// type Item; |
| 42 | /// fn next(&mut self) -> Option<Self::Item>; |
| 43 | /// fn size_hint(&self) -> (usize, Option<usize>); |
| 44 | /// } |
| 45 | /// }, |
| 46 | /// ) |
| 47 | /// .into() |
| 48 | /// } |
| 49 | /// |
| 50 | /// # #[cfg (any())] |
| 51 | /// #[proc_macro_derive(ExactSizeIterator)] |
| 52 | /// # pub fn _derive_exact_size_iterator(_: TokenStream) -> TokenStream { unimplemented!() } |
| 53 | /// pub fn derive_exact_size_iterator(input: TokenStream) -> TokenStream { |
| 54 | /// derive_trait( |
| 55 | /// &parse_macro_input!(input), |
| 56 | /// // trait path |
| 57 | /// &parse_quote!(std::iter::ExactSizeIterator), |
| 58 | /// // super trait's associated types |
| 59 | /// Some(format_ident!("Item" )), |
| 60 | /// // trait definition |
| 61 | /// parse_quote! { |
| 62 | /// trait ExactSizeIterator: Iterator { |
| 63 | /// fn len(&self) -> usize; |
| 64 | /// } |
| 65 | /// }, |
| 66 | /// ) |
| 67 | /// .into() |
| 68 | /// } |
| 69 | /// ``` |
| 70 | pub fn derive_trait<I>( |
| 71 | data: &EnumData, |
| 72 | trait_path: &Path, |
| 73 | supertraits_types: I, |
| 74 | trait_def: ItemTrait, |
| 75 | ) -> TokenStream |
| 76 | where |
| 77 | I: IntoIterator<Item = Ident>, |
| 78 | I::IntoIter: ExactSizeIterator, |
| 79 | { |
| 80 | EnumImpl::from_trait(data, trait_path, supertraits_types, trait_def).build() |
| 81 | } |
| 82 | |
| 83 | /// A builder for implementing a trait for enums. |
| 84 | pub struct EnumImpl<'a> { |
| 85 | data: &'a EnumData, |
| 86 | defaultness: bool, |
| 87 | unsafety: bool, |
| 88 | generics: Generics, |
| 89 | trait_: Option<Path>, |
| 90 | self_ty: Box<Type>, |
| 91 | items: Vec<ImplItem>, |
| 92 | } |
| 93 | |
| 94 | impl<'a> EnumImpl<'a> { |
| 95 | /// Creates a new `EnumImpl`. |
| 96 | pub fn new(data: &'a EnumData) -> Self { |
| 97 | let ident = &data.ident; |
| 98 | let ty_generics = data.generics.split_for_impl().1; |
| 99 | Self { |
| 100 | data, |
| 101 | defaultness: false, |
| 102 | unsafety: false, |
| 103 | generics: data.generics.clone(), |
| 104 | trait_: None, |
| 105 | self_ty: Box::new(parse_quote!(#ident #ty_generics)), |
| 106 | items: vec![], |
| 107 | } |
| 108 | } |
| 109 | |
| 110 | /// Creates a new `EnumImpl` from a trait definition. |
| 111 | /// |
| 112 | /// The following items are ignored: |
| 113 | /// - Generic associated types (GAT) ([`TraitItem::Type`] that has generics) |
| 114 | /// - [`TraitItem::Const`] |
| 115 | /// - [`TraitItem::Macro`] |
| 116 | /// - [`TraitItem::Verbatim`] |
| 117 | /// |
| 118 | /// # Panics |
| 119 | /// |
| 120 | /// Panics if a trait method has a body, no receiver, or a receiver other |
| 121 | /// than the following: |
| 122 | /// |
| 123 | /// - `&self` |
| 124 | /// - `&mut self` |
| 125 | /// - `self` |
| 126 | pub fn from_trait<I>( |
| 127 | data: &'a EnumData, |
| 128 | trait_path: &Path, |
| 129 | supertraits_types: I, |
| 130 | mut trait_def: ItemTrait, |
| 131 | ) -> Self |
| 132 | where |
| 133 | I: IntoIterator<Item = Ident>, |
| 134 | I::IntoIter: ExactSizeIterator, |
| 135 | { |
| 136 | let mut generics = data.generics.clone(); |
| 137 | let trait_ = { |
| 138 | if trait_def.generics.params.is_empty() { |
| 139 | trait_path.clone() |
| 140 | } else { |
| 141 | let ty_generics = trait_def.generics.split_for_impl().1; |
| 142 | parse_quote!(#trait_path #ty_generics) |
| 143 | } |
| 144 | }; |
| 145 | |
| 146 | let fst = data.field_types().next().unwrap(); |
| 147 | let mut types: Vec<_> = trait_def |
| 148 | .items |
| 149 | .iter() |
| 150 | .filter_map(|item| match item { |
| 151 | TraitItem::Type(ty) => Some((false, Cow::Borrowed(&ty.ident))), |
| 152 | _ => None, |
| 153 | }) |
| 154 | .collect(); |
| 155 | |
| 156 | let supertraits_types = supertraits_types.into_iter(); |
| 157 | if supertraits_types.len() > 0 { |
| 158 | if let Some(TypeParamBound::Trait(_)) = trait_def.supertraits.iter().next() { |
| 159 | types.extend(supertraits_types.map(|ident| (true, Cow::Owned(ident)))); |
| 160 | } |
| 161 | } |
| 162 | |
| 163 | // https://github.com/taiki-e/derive_utils/issues/47 |
| 164 | let type_params = generics.type_params().map(|p| p.ident.to_string()).collect::<Vec<_>>(); |
| 165 | let has_method = trait_def.items.iter().any(|i| matches!(i, TraitItem::Fn(..))); |
| 166 | if !has_method || !type_params.is_empty() { |
| 167 | struct HasTypeParam<'a>(&'a [String]); |
| 168 | |
| 169 | impl HasTypeParam<'_> { |
| 170 | fn check_ident(&self, ident: &Ident) -> bool { |
| 171 | let ident = ident.to_string(); |
| 172 | self.0.contains(&ident) |
| 173 | } |
| 174 | |
| 175 | fn visit_type(&self, ty: &Type) -> bool { |
| 176 | if let Type::Path(node) = ty { |
| 177 | if node.qself.is_none() { |
| 178 | if let Some(ident) = node.path.get_ident() { |
| 179 | return self.check_ident(ident); |
| 180 | } |
| 181 | } |
| 182 | } |
| 183 | self.visit_token_stream(ty.to_token_stream()) |
| 184 | } |
| 185 | |
| 186 | fn visit_token_stream(&self, tokens: TokenStream) -> bool { |
| 187 | for tt in tokens { |
| 188 | match tt { |
| 189 | TokenTree::Ident(ident) => { |
| 190 | if self.check_ident(&ident) { |
| 191 | return true; |
| 192 | } |
| 193 | } |
| 194 | TokenTree::Group(group) => { |
| 195 | let content = group.stream(); |
| 196 | if self.visit_token_stream(content) { |
| 197 | return true; |
| 198 | } |
| 199 | } |
| 200 | _ => {} |
| 201 | } |
| 202 | } |
| 203 | false |
| 204 | } |
| 205 | } |
| 206 | |
| 207 | let visitor = HasTypeParam(&type_params); |
| 208 | let where_clause = &mut generics.make_where_clause().predicates; |
| 209 | if !has_method || visitor.visit_type(fst) { |
| 210 | where_clause.push(parse_quote!(#fst: #trait_)); |
| 211 | } |
| 212 | if data.field_types().len() > 1 { |
| 213 | let fst_tokens = fst.to_token_stream().to_string(); |
| 214 | where_clause.extend(data.field_types().skip(1).filter_map( |
| 215 | |variant| -> Option<WherePredicate> { |
| 216 | if has_method && !visitor.visit_type(variant) { |
| 217 | return None; |
| 218 | } |
| 219 | if variant.to_token_stream().to_string() == fst_tokens { |
| 220 | return None; |
| 221 | } |
| 222 | if types.is_empty() { |
| 223 | return Some(parse_quote!(#variant: #trait_)); |
| 224 | } |
| 225 | let types = types.iter().map(|(supertraits, ident)| { |
| 226 | match trait_def.supertraits.iter().next() { |
| 227 | Some(TypeParamBound::Trait(trait_)) if *supertraits => { |
| 228 | quote!(#ident = <#fst as #trait_>::#ident) |
| 229 | } |
| 230 | _ => quote!(#ident = <#fst as #trait_>::#ident), |
| 231 | } |
| 232 | }); |
| 233 | if trait_def.generics.params.is_empty() { |
| 234 | Some(parse_quote!(#variant: #trait_path<#(#types),*>)) |
| 235 | } else { |
| 236 | let generics = |
| 237 | trait_def.generics.params.iter().map(|param| match param { |
| 238 | GenericParam::Lifetime(def) => def.lifetime.to_token_stream(), |
| 239 | GenericParam::Type(param) => param.ident.to_token_stream(), |
| 240 | GenericParam::Const(param) => param.ident.to_token_stream(), |
| 241 | }); |
| 242 | Some(parse_quote!(#variant: #trait_path<#(#generics),*, #(#types),*>)) |
| 243 | } |
| 244 | }, |
| 245 | )); |
| 246 | } |
| 247 | } |
| 248 | |
| 249 | if !trait_def.generics.params.is_empty() { |
| 250 | generics.params.extend(mem::take(&mut trait_def.generics.params)); |
| 251 | } |
| 252 | |
| 253 | if let Some(old) = trait_def.generics.where_clause.as_mut() { |
| 254 | if !old.predicates.is_empty() { |
| 255 | generics.make_where_clause().predicates.extend(mem::take(&mut old.predicates)); |
| 256 | } |
| 257 | } |
| 258 | |
| 259 | let ident = &data.ident; |
| 260 | let ty_generics = data.generics.split_for_impl().1; |
| 261 | let mut impls = Self { |
| 262 | data, |
| 263 | defaultness: false, |
| 264 | unsafety: trait_def.unsafety.is_some(), |
| 265 | generics, |
| 266 | trait_: Some(trait_), |
| 267 | self_ty: Box::new(parse_quote!(#ident #ty_generics)), |
| 268 | items: Vec::with_capacity(trait_def.items.len()), |
| 269 | }; |
| 270 | impls.append_items_from_trait(trait_def); |
| 271 | impls |
| 272 | } |
| 273 | |
| 274 | pub fn set_trait(&mut self, path: Path) { |
| 275 | self.trait_ = Some(path); |
| 276 | } |
| 277 | |
| 278 | /// Appends a generic type parameter to the back of generics. |
| 279 | pub fn push_generic_param(&mut self, param: GenericParam) { |
| 280 | self.generics.params.push(param); |
| 281 | } |
| 282 | |
| 283 | /// Appends a predicate to the back of `where`-clause. |
| 284 | pub fn push_where_predicate(&mut self, predicate: WherePredicate) { |
| 285 | self.generics.make_where_clause().predicates.push(predicate); |
| 286 | } |
| 287 | |
| 288 | /// Appends an item to impl items. |
| 289 | pub fn push_item(&mut self, item: ImplItem) { |
| 290 | self.items.push(item); |
| 291 | } |
| 292 | |
| 293 | /// Appends a method to impl items. |
| 294 | /// |
| 295 | /// # Panics |
| 296 | /// |
| 297 | /// Panics if a trait method has a body, no receiver, or a receiver other |
| 298 | /// than the following: |
| 299 | /// |
| 300 | /// - `&self` |
| 301 | /// - `&mut self` |
| 302 | /// - `self` |
| 303 | pub fn push_method(&mut self, item: TraitItemFn) { |
| 304 | assert!(item.default.is_none(), "method ` {}` has a body" , item.sig.ident); |
| 305 | |
| 306 | let self_ty = ReceiverKind::new(&item.sig); |
| 307 | let mut args = Vec::with_capacity(item.sig.inputs.len()); |
| 308 | item.sig.inputs.iter().skip(1).for_each(|arg| match arg { |
| 309 | FnArg::Typed(arg) => args.push(&arg.pat), |
| 310 | FnArg::Receiver(_) => panic!( |
| 311 | "method ` {}` has a receiver in a position other than the first argument" , |
| 312 | item.sig.ident |
| 313 | ), |
| 314 | }); |
| 315 | |
| 316 | let method = &item.sig.ident; |
| 317 | let ident = &self.data.ident; |
| 318 | let method = match self_ty { |
| 319 | ReceiverKind::Normal => match &self.trait_ { |
| 320 | None => { |
| 321 | let arms = self.data.variant_idents().map(|v| { |
| 322 | quote! { |
| 323 | #ident::#v(x) => x.#method(#(#args),*), |
| 324 | } |
| 325 | }); |
| 326 | parse_quote!(match self { #(#arms)* }) |
| 327 | } |
| 328 | Some(trait_) => { |
| 329 | let arms = |
| 330 | self.data.variant_idents().zip(self.data.field_types()).map(|(v, ty)| { |
| 331 | quote! { |
| 332 | #ident::#v(x) => <#ty as #trait_>::#method(x #(,#args)*), |
| 333 | } |
| 334 | }); |
| 335 | parse_quote!(match self { #(#arms)* }) |
| 336 | } |
| 337 | }, |
| 338 | }; |
| 339 | |
| 340 | self.push_item(ImplItem::Fn(ImplItemFn { |
| 341 | attrs: item.attrs, |
| 342 | vis: Visibility::Inherited, |
| 343 | defaultness: None, |
| 344 | sig: item.sig, |
| 345 | block: Block { |
| 346 | brace_token: token::Brace::default(), |
| 347 | stmts: vec![Stmt::Expr(method, None)], |
| 348 | }, |
| 349 | })); |
| 350 | } |
| 351 | |
| 352 | /// Appends items from a trait definition to impl items. |
| 353 | /// |
| 354 | /// # Panics |
| 355 | /// |
| 356 | /// Panics if a trait method has a body, no receiver, or a receiver other |
| 357 | /// than the following: |
| 358 | /// |
| 359 | /// - `&self` |
| 360 | /// - `&mut self` |
| 361 | /// - `self` |
| 362 | pub fn append_items_from_trait(&mut self, trait_def: ItemTrait) { |
| 363 | let fst = self.data.field_types().next(); |
| 364 | trait_def.items.into_iter().for_each(|item| match item { |
| 365 | // The TraitItemType::generics field (Generic associated types (GAT)) are not supported |
| 366 | TraitItem::Type(TraitItemType { ident, .. }) => { |
| 367 | let trait_ = &self.trait_; |
| 368 | let ty = parse_quote!(type #ident = <#fst as #trait_>::#ident;); |
| 369 | self.push_item(ImplItem::Type(ty)); |
| 370 | } |
| 371 | TraitItem::Fn(method) => self.push_method(method), |
| 372 | _ => {} |
| 373 | }); |
| 374 | } |
| 375 | |
| 376 | pub fn build(self) -> TokenStream { |
| 377 | self.build_impl().to_token_stream() |
| 378 | } |
| 379 | |
| 380 | pub fn build_impl(self) -> ItemImpl { |
| 381 | ItemImpl { |
| 382 | attrs: vec![parse_quote!(#[automatically_derived])], |
| 383 | defaultness: if self.defaultness { Some(<Token![default]>::default()) } else { None }, |
| 384 | unsafety: if self.unsafety { Some(<Token![unsafe]>::default()) } else { None }, |
| 385 | impl_token: token::Impl::default(), |
| 386 | generics: self.generics, |
| 387 | trait_: self.trait_.map(|trait_| (None, trait_, <Token![for]>::default())), |
| 388 | self_ty: self.self_ty, |
| 389 | brace_token: token::Brace::default(), |
| 390 | items: self.items, |
| 391 | } |
| 392 | } |
| 393 | } |
| 394 | |
/// The shape of a forwarded method's receiver.
///
/// Only one shape is supported at the moment; any other receiver is rejected
/// when the receiver is classified.
enum ReceiverKind {
    /// `&(mut) self`, `(mut) self`, `(mut) self: &(mut) Self`, or `(mut) self: Self`.
    /// A receiver of this shape can be matched on directly with `match self { .. }`.
    Normal,
}
| 399 | |
| 400 | impl ReceiverKind { |
| 401 | fn new(sig: &Signature) -> Self { |
| 402 | fn get_ty_path(ty: &Type) -> Option<&Path> { |
| 403 | if let Type::Path(TypePath { qself: None, path }) = ty { |
| 404 | Some(path) |
| 405 | } else { |
| 406 | None |
| 407 | } |
| 408 | } |
| 409 | |
| 410 | match sig.receiver() { |
| 411 | None => panic!("method ` {}` has no receiver" , sig.ident), |
| 412 | Some(receiver) => { |
| 413 | if receiver.colon_token.is_none() { |
| 414 | return ReceiverKind::Normal; |
| 415 | } |
| 416 | match &*receiver.ty { |
| 417 | Type::Path(TypePath { qself: None, path }) => { |
| 418 | // (mut) self: Self |
| 419 | if path.is_ident("Self" ) { |
| 420 | return ReceiverKind::Normal; |
| 421 | } |
| 422 | } |
| 423 | Type::Reference(ty) => { |
| 424 | // (mut) self: &(mut) Self |
| 425 | if get_ty_path(&ty.elem).map_or(false, |path| path.is_ident("Self" )) { |
| 426 | return ReceiverKind::Normal; |
| 427 | } |
| 428 | } |
| 429 | _ => {} |
| 430 | } |
| 431 | panic!( |
| 432 | "method ` {}` has unsupported receiver type: {}" , |
| 433 | sig.ident, |
| 434 | receiver.ty.to_token_stream() |
| 435 | ); |
| 436 | } |
| 437 | } |
| 438 | } |
| 439 | } |
| 440 | |