1 | // SPDX-License-Identifier: Apache-2.0 OR MIT |
2 | |
3 | use core::mem; |
4 | use std::borrow::Cow; |
5 | |
6 | use proc_macro2::{TokenStream, TokenTree}; |
7 | use quote::{quote, ToTokens as _}; |
8 | use syn::{ |
9 | parse_quote, token, Block, FnArg, GenericParam, Generics, Ident, ImplItem, ImplItemFn, |
10 | ItemImpl, ItemTrait, Path, Signature, Stmt, Token, TraitItem, TraitItemFn, TraitItemType, Type, |
11 | TypeParamBound, TypePath, Visibility, WherePredicate, |
12 | }; |
13 | |
14 | use crate::ast::EnumData; |
15 | |
16 | /// A function for creating `proc_macro_derive` like deriving trait to enum so |
17 | /// long as all variants are implemented that trait. |
18 | /// |
19 | /// # Examples |
20 | /// |
21 | /// ``` |
22 | /// # extern crate proc_macro; |
23 | /// use derive_utils::derive_trait; |
24 | /// use proc_macro::TokenStream; |
25 | /// use quote::format_ident; |
26 | /// use syn::{parse_macro_input, parse_quote}; |
27 | /// |
28 | /// # #[cfg (any())] |
29 | /// #[proc_macro_derive(Iterator)] |
30 | /// # pub fn _derive_iterator(_: TokenStream) -> TokenStream { unimplemented!() } |
31 | /// pub fn derive_iterator(input: TokenStream) -> TokenStream { |
32 | /// derive_trait( |
33 | /// &parse_macro_input!(input), |
34 | /// // trait path |
35 | /// &parse_quote!(std::iter::Iterator), |
36 | /// // super trait's associated types |
37 | /// None, |
38 | /// // trait definition |
39 | /// parse_quote! { |
40 | /// trait Iterator { |
41 | /// type Item; |
42 | /// fn next(&mut self) -> Option<Self::Item>; |
43 | /// fn size_hint(&self) -> (usize, Option<usize>); |
44 | /// } |
45 | /// }, |
46 | /// ) |
47 | /// .into() |
48 | /// } |
49 | /// |
50 | /// # #[cfg (any())] |
51 | /// #[proc_macro_derive(ExactSizeIterator)] |
52 | /// # pub fn _derive_exact_size_iterator(_: TokenStream) -> TokenStream { unimplemented!() } |
53 | /// pub fn derive_exact_size_iterator(input: TokenStream) -> TokenStream { |
54 | /// derive_trait( |
55 | /// &parse_macro_input!(input), |
56 | /// // trait path |
57 | /// &parse_quote!(std::iter::ExactSizeIterator), |
58 | /// // super trait's associated types |
59 | /// Some(format_ident!("Item" )), |
60 | /// // trait definition |
61 | /// parse_quote! { |
62 | /// trait ExactSizeIterator: Iterator { |
63 | /// fn len(&self) -> usize; |
64 | /// } |
65 | /// }, |
66 | /// ) |
67 | /// .into() |
68 | /// } |
69 | /// ``` |
70 | pub fn derive_trait<I>( |
71 | data: &EnumData, |
72 | trait_path: &Path, |
73 | supertraits_types: I, |
74 | trait_def: ItemTrait, |
75 | ) -> TokenStream |
76 | where |
77 | I: IntoIterator<Item = Ident>, |
78 | I::IntoIter: ExactSizeIterator, |
79 | { |
80 | EnumImpl::from_trait(data, trait_path, supertraits_types, trait_def).build() |
81 | } |
82 | |
83 | /// A builder for implementing a trait for enums. |
84 | pub struct EnumImpl<'a> { |
85 | data: &'a EnumData, |
86 | defaultness: bool, |
87 | unsafety: bool, |
88 | generics: Generics, |
89 | trait_: Option<Path>, |
90 | self_ty: Box<Type>, |
91 | items: Vec<ImplItem>, |
92 | } |
93 | |
94 | impl<'a> EnumImpl<'a> { |
95 | /// Creates a new `EnumImpl`. |
96 | pub fn new(data: &'a EnumData) -> Self { |
97 | let ident = &data.ident; |
98 | let ty_generics = data.generics.split_for_impl().1; |
99 | Self { |
100 | data, |
101 | defaultness: false, |
102 | unsafety: false, |
103 | generics: data.generics.clone(), |
104 | trait_: None, |
105 | self_ty: Box::new(parse_quote!(#ident #ty_generics)), |
106 | items: vec![], |
107 | } |
108 | } |
109 | |
110 | /// Creates a new `EnumImpl` from a trait definition. |
111 | /// |
112 | /// The following items are ignored: |
113 | /// - Generic associated types (GAT) ([`TraitItem::Type`] that has generics) |
114 | /// - [`TraitItem::Const`] |
115 | /// - [`TraitItem::Macro`] |
116 | /// - [`TraitItem::Verbatim`] |
117 | /// |
118 | /// # Panics |
119 | /// |
120 | /// Panics if a trait method has a body, no receiver, or a receiver other |
121 | /// than the following: |
122 | /// |
123 | /// - `&self` |
124 | /// - `&mut self` |
125 | /// - `self` |
126 | pub fn from_trait<I>( |
127 | data: &'a EnumData, |
128 | trait_path: &Path, |
129 | supertraits_types: I, |
130 | mut trait_def: ItemTrait, |
131 | ) -> Self |
132 | where |
133 | I: IntoIterator<Item = Ident>, |
134 | I::IntoIter: ExactSizeIterator, |
135 | { |
136 | let mut generics = data.generics.clone(); |
137 | let trait_ = { |
138 | if trait_def.generics.params.is_empty() { |
139 | trait_path.clone() |
140 | } else { |
141 | let ty_generics = trait_def.generics.split_for_impl().1; |
142 | parse_quote!(#trait_path #ty_generics) |
143 | } |
144 | }; |
145 | |
146 | let fst = data.field_types().next().unwrap(); |
147 | let mut types: Vec<_> = trait_def |
148 | .items |
149 | .iter() |
150 | .filter_map(|item| match item { |
151 | TraitItem::Type(ty) => Some((false, Cow::Borrowed(&ty.ident))), |
152 | _ => None, |
153 | }) |
154 | .collect(); |
155 | |
156 | let supertraits_types = supertraits_types.into_iter(); |
157 | if supertraits_types.len() > 0 { |
158 | if let Some(TypeParamBound::Trait(_)) = trait_def.supertraits.iter().next() { |
159 | types.extend(supertraits_types.map(|ident| (true, Cow::Owned(ident)))); |
160 | } |
161 | } |
162 | |
163 | // https://github.com/taiki-e/derive_utils/issues/47 |
164 | let type_params = generics.type_params().map(|p| p.ident.to_string()).collect::<Vec<_>>(); |
165 | let has_method = trait_def.items.iter().any(|i| matches!(i, TraitItem::Fn(..))); |
166 | if !has_method || !type_params.is_empty() { |
167 | struct HasTypeParam<'a>(&'a [String]); |
168 | |
169 | impl HasTypeParam<'_> { |
170 | fn check_ident(&self, ident: &Ident) -> bool { |
171 | let ident = ident.to_string(); |
172 | self.0.contains(&ident) |
173 | } |
174 | |
175 | fn visit_type(&self, ty: &Type) -> bool { |
176 | if let Type::Path(node) = ty { |
177 | if node.qself.is_none() { |
178 | if let Some(ident) = node.path.get_ident() { |
179 | return self.check_ident(ident); |
180 | } |
181 | } |
182 | } |
183 | self.visit_token_stream(ty.to_token_stream()) |
184 | } |
185 | |
186 | fn visit_token_stream(&self, tokens: TokenStream) -> bool { |
187 | for tt in tokens { |
188 | match tt { |
189 | TokenTree::Ident(ident) => { |
190 | if self.check_ident(&ident) { |
191 | return true; |
192 | } |
193 | } |
194 | TokenTree::Group(group) => { |
195 | let content = group.stream(); |
196 | if self.visit_token_stream(content) { |
197 | return true; |
198 | } |
199 | } |
200 | _ => {} |
201 | } |
202 | } |
203 | false |
204 | } |
205 | } |
206 | |
207 | let visitor = HasTypeParam(&type_params); |
208 | let where_clause = &mut generics.make_where_clause().predicates; |
209 | if !has_method || visitor.visit_type(fst) { |
210 | where_clause.push(parse_quote!(#fst: #trait_)); |
211 | } |
212 | if data.field_types().len() > 1 { |
213 | let fst_tokens = fst.to_token_stream().to_string(); |
214 | where_clause.extend(data.field_types().skip(1).filter_map( |
215 | |variant| -> Option<WherePredicate> { |
216 | if has_method && !visitor.visit_type(variant) { |
217 | return None; |
218 | } |
219 | if variant.to_token_stream().to_string() == fst_tokens { |
220 | return None; |
221 | } |
222 | if types.is_empty() { |
223 | return Some(parse_quote!(#variant: #trait_)); |
224 | } |
225 | let types = types.iter().map(|(supertraits, ident)| { |
226 | match trait_def.supertraits.iter().next() { |
227 | Some(TypeParamBound::Trait(trait_)) if *supertraits => { |
228 | quote!(#ident = <#fst as #trait_>::#ident) |
229 | } |
230 | _ => quote!(#ident = <#fst as #trait_>::#ident), |
231 | } |
232 | }); |
233 | if trait_def.generics.params.is_empty() { |
234 | Some(parse_quote!(#variant: #trait_path<#(#types),*>)) |
235 | } else { |
236 | let generics = |
237 | trait_def.generics.params.iter().map(|param| match param { |
238 | GenericParam::Lifetime(def) => def.lifetime.to_token_stream(), |
239 | GenericParam::Type(param) => param.ident.to_token_stream(), |
240 | GenericParam::Const(param) => param.ident.to_token_stream(), |
241 | }); |
242 | Some(parse_quote!(#variant: #trait_path<#(#generics),*, #(#types),*>)) |
243 | } |
244 | }, |
245 | )); |
246 | } |
247 | } |
248 | |
249 | if !trait_def.generics.params.is_empty() { |
250 | generics.params.extend(mem::take(&mut trait_def.generics.params)); |
251 | } |
252 | |
253 | if let Some(old) = trait_def.generics.where_clause.as_mut() { |
254 | if !old.predicates.is_empty() { |
255 | generics.make_where_clause().predicates.extend(mem::take(&mut old.predicates)); |
256 | } |
257 | } |
258 | |
259 | let ident = &data.ident; |
260 | let ty_generics = data.generics.split_for_impl().1; |
261 | let mut impls = Self { |
262 | data, |
263 | defaultness: false, |
264 | unsafety: trait_def.unsafety.is_some(), |
265 | generics, |
266 | trait_: Some(trait_), |
267 | self_ty: Box::new(parse_quote!(#ident #ty_generics)), |
268 | items: Vec::with_capacity(trait_def.items.len()), |
269 | }; |
270 | impls.append_items_from_trait(trait_def); |
271 | impls |
272 | } |
273 | |
274 | pub fn set_trait(&mut self, path: Path) { |
275 | self.trait_ = Some(path); |
276 | } |
277 | |
278 | /// Appends a generic type parameter to the back of generics. |
279 | pub fn push_generic_param(&mut self, param: GenericParam) { |
280 | self.generics.params.push(param); |
281 | } |
282 | |
283 | /// Appends a predicate to the back of `where`-clause. |
284 | pub fn push_where_predicate(&mut self, predicate: WherePredicate) { |
285 | self.generics.make_where_clause().predicates.push(predicate); |
286 | } |
287 | |
288 | /// Appends an item to impl items. |
289 | pub fn push_item(&mut self, item: ImplItem) { |
290 | self.items.push(item); |
291 | } |
292 | |
293 | /// Appends a method to impl items. |
294 | /// |
295 | /// # Panics |
296 | /// |
297 | /// Panics if a trait method has a body, no receiver, or a receiver other |
298 | /// than the following: |
299 | /// |
300 | /// - `&self` |
301 | /// - `&mut self` |
302 | /// - `self` |
303 | pub fn push_method(&mut self, item: TraitItemFn) { |
304 | assert!(item.default.is_none(), "method ` {}` has a body" , item.sig.ident); |
305 | |
306 | let self_ty = ReceiverKind::new(&item.sig); |
307 | let mut args = Vec::with_capacity(item.sig.inputs.len()); |
308 | item.sig.inputs.iter().skip(1).for_each(|arg| match arg { |
309 | FnArg::Typed(arg) => args.push(&arg.pat), |
310 | FnArg::Receiver(_) => panic!( |
311 | "method ` {}` has a receiver in a position other than the first argument" , |
312 | item.sig.ident |
313 | ), |
314 | }); |
315 | |
316 | let method = &item.sig.ident; |
317 | let ident = &self.data.ident; |
318 | let method = match self_ty { |
319 | ReceiverKind::Normal => match &self.trait_ { |
320 | None => { |
321 | let arms = self.data.variant_idents().map(|v| { |
322 | quote! { |
323 | #ident::#v(x) => x.#method(#(#args),*), |
324 | } |
325 | }); |
326 | parse_quote!(match self { #(#arms)* }) |
327 | } |
328 | Some(trait_) => { |
329 | let arms = |
330 | self.data.variant_idents().zip(self.data.field_types()).map(|(v, ty)| { |
331 | quote! { |
332 | #ident::#v(x) => <#ty as #trait_>::#method(x #(,#args)*), |
333 | } |
334 | }); |
335 | parse_quote!(match self { #(#arms)* }) |
336 | } |
337 | }, |
338 | }; |
339 | |
340 | self.push_item(ImplItem::Fn(ImplItemFn { |
341 | attrs: item.attrs, |
342 | vis: Visibility::Inherited, |
343 | defaultness: None, |
344 | sig: item.sig, |
345 | block: Block { |
346 | brace_token: token::Brace::default(), |
347 | stmts: vec![Stmt::Expr(method, None)], |
348 | }, |
349 | })); |
350 | } |
351 | |
352 | /// Appends items from a trait definition to impl items. |
353 | /// |
354 | /// # Panics |
355 | /// |
356 | /// Panics if a trait method has a body, no receiver, or a receiver other |
357 | /// than the following: |
358 | /// |
359 | /// - `&self` |
360 | /// - `&mut self` |
361 | /// - `self` |
362 | pub fn append_items_from_trait(&mut self, trait_def: ItemTrait) { |
363 | let fst = self.data.field_types().next(); |
364 | trait_def.items.into_iter().for_each(|item| match item { |
365 | // The TraitItemType::generics field (Generic associated types (GAT)) are not supported |
366 | TraitItem::Type(TraitItemType { ident, .. }) => { |
367 | let trait_ = &self.trait_; |
368 | let ty = parse_quote!(type #ident = <#fst as #trait_>::#ident;); |
369 | self.push_item(ImplItem::Type(ty)); |
370 | } |
371 | TraitItem::Fn(method) => self.push_method(method), |
372 | _ => {} |
373 | }); |
374 | } |
375 | |
376 | pub fn build(self) -> TokenStream { |
377 | self.build_impl().to_token_stream() |
378 | } |
379 | |
380 | pub fn build_impl(self) -> ItemImpl { |
381 | ItemImpl { |
382 | attrs: vec![parse_quote!(#[automatically_derived])], |
383 | defaultness: if self.defaultness { Some(<Token![default]>::default()) } else { None }, |
384 | unsafety: if self.unsafety { Some(<Token![unsafe]>::default()) } else { None }, |
385 | impl_token: token::Impl::default(), |
386 | generics: self.generics, |
387 | trait_: self.trait_.map(|trait_| (None, trait_, <Token![for]>::default())), |
388 | self_ty: self.self_ty, |
389 | brace_token: token::Brace::default(), |
390 | items: self.items, |
391 | } |
392 | } |
393 | } |
394 | |
395 | enum ReceiverKind { |
396 | /// `&(mut) self`, `(mut) self`, `(mut) self: &(mut) Self`, or `(mut) self: Self` |
397 | Normal, |
398 | } |
399 | |
400 | impl ReceiverKind { |
401 | fn new(sig: &Signature) -> Self { |
402 | fn get_ty_path(ty: &Type) -> Option<&Path> { |
403 | if let Type::Path(TypePath { qself: None, path }) = ty { |
404 | Some(path) |
405 | } else { |
406 | None |
407 | } |
408 | } |
409 | |
410 | match sig.receiver() { |
411 | None => panic!("method ` {}` has no receiver" , sig.ident), |
412 | Some(receiver) => { |
413 | if receiver.colon_token.is_none() { |
414 | return ReceiverKind::Normal; |
415 | } |
416 | match &*receiver.ty { |
417 | Type::Path(TypePath { qself: None, path }) => { |
418 | // (mut) self: Self |
419 | if path.is_ident("Self" ) { |
420 | return ReceiverKind::Normal; |
421 | } |
422 | } |
423 | Type::Reference(ty) => { |
424 | // (mut) self: &(mut) Self |
425 | if get_ty_path(&ty.elem).map_or(false, |path| path.is_ident("Self" )) { |
426 | return ReceiverKind::Normal; |
427 | } |
428 | } |
429 | _ => {} |
430 | } |
431 | panic!( |
432 | "method ` {}` has unsupported receiver type: {}" , |
433 | sig.ident, |
434 | receiver.ty.to_token_stream() |
435 | ); |
436 | } |
437 | } |
438 | } |
439 | } |
440 | |