1 | // SPDX-License-Identifier: Apache-2.0 OR MIT |
2 | |
3 | use core::mem; |
4 | use std::borrow::Cow; |
5 | |
6 | use proc_macro2::{TokenStream, TokenTree}; |
7 | use quote::{quote, ToTokens}; |
8 | use syn::{ |
9 | parse_quote, token, Block, FnArg, GenericParam, Generics, Ident, ImplItem, ImplItemFn, |
10 | ItemImpl, ItemTrait, Path, Signature, Stmt, Token, TraitItem, TraitItemFn, TraitItemType, Type, |
11 | TypeParamBound, TypePath, Visibility, WherePredicate, |
12 | }; |
13 | |
14 | use crate::ast::EnumData; |
15 | |
16 | /// A function for creating `proc_macro_derive` like deriving trait to enum so |
17 | /// long as all variants are implemented that trait. |
18 | /// |
19 | /// # Examples |
20 | /// |
21 | /// ``` |
22 | /// # extern crate proc_macro; |
23 | /// # |
24 | /// use derive_utils::derive_trait; |
25 | /// use proc_macro::TokenStream; |
26 | /// use quote::format_ident; |
27 | /// use syn::{parse_macro_input, parse_quote}; |
28 | /// |
29 | /// # #[cfg (any())] |
30 | /// #[proc_macro_derive(Iterator)] |
31 | /// # pub fn _derive_iterator(_: TokenStream) -> TokenStream { unimplemented!() } |
32 | /// pub fn derive_iterator(input: TokenStream) -> TokenStream { |
33 | /// derive_trait( |
34 | /// &parse_macro_input!(input), |
35 | /// // trait path |
36 | /// &parse_quote!(std::iter::Iterator), |
37 | /// // super trait's associated types |
38 | /// None, |
39 | /// // trait definition |
40 | /// parse_quote! { |
41 | /// trait Iterator { |
42 | /// type Item; |
43 | /// fn next(&mut self) -> Option<Self::Item>; |
44 | /// fn size_hint(&self) -> (usize, Option<usize>); |
45 | /// } |
46 | /// }, |
47 | /// ) |
48 | /// .into() |
49 | /// } |
50 | /// |
51 | /// # #[cfg (any())] |
52 | /// #[proc_macro_derive(ExactSizeIterator)] |
53 | /// # pub fn _derive_exact_size_iterator(_: TokenStream) -> TokenStream { unimplemented!() } |
54 | /// pub fn derive_exact_size_iterator(input: TokenStream) -> TokenStream { |
55 | /// derive_trait( |
56 | /// &parse_macro_input!(input), |
57 | /// // trait path |
58 | /// &parse_quote!(std::iter::ExactSizeIterator), |
59 | /// // super trait's associated types |
60 | /// Some(format_ident!("Item" )), |
61 | /// // trait definition |
62 | /// parse_quote! { |
63 | /// trait ExactSizeIterator: Iterator { |
64 | /// fn len(&self) -> usize; |
65 | /// } |
66 | /// }, |
67 | /// ) |
68 | /// .into() |
69 | /// } |
70 | /// ``` |
71 | pub fn derive_trait<I>( |
72 | data: &EnumData, |
73 | trait_path: &Path, |
74 | supertraits_types: I, |
75 | trait_def: ItemTrait, |
76 | ) -> TokenStream |
77 | where |
78 | I: IntoIterator<Item = Ident>, |
79 | I::IntoIter: ExactSizeIterator, |
80 | { |
81 | EnumImpl::from_trait(data, trait_path, supertraits_types, trait_def).build() |
82 | } |
83 | |
84 | /// A builder for implementing a trait for enums. |
85 | pub struct EnumImpl<'a> { |
86 | data: &'a EnumData, |
87 | defaultness: bool, |
88 | unsafety: bool, |
89 | generics: Generics, |
90 | trait_: Option<Path>, |
91 | self_ty: Box<Type>, |
92 | items: Vec<ImplItem>, |
93 | } |
94 | |
95 | impl<'a> EnumImpl<'a> { |
96 | /// Creates a new `EnumImpl`. |
97 | pub fn new(data: &'a EnumData) -> Self { |
98 | let ident = &data.ident; |
99 | let ty_generics = data.generics.split_for_impl().1; |
100 | Self { |
101 | data, |
102 | defaultness: false, |
103 | unsafety: false, |
104 | generics: data.generics.clone(), |
105 | trait_: None, |
106 | self_ty: Box::new(parse_quote!(#ident #ty_generics)), |
107 | items: vec![], |
108 | } |
109 | } |
110 | |
111 | /// Creates a new `EnumImpl` from a trait definition. |
112 | /// |
113 | /// The following items are ignored: |
114 | /// - Generic associated types (GAT) ([`TraitItem::Type`] that has generics) |
115 | /// - [`TraitItem::Const`] |
116 | /// - [`TraitItem::Macro`] |
117 | /// - [`TraitItem::Verbatim`] |
118 | /// |
119 | /// # Panics |
120 | /// |
121 | /// Panics if a trait method has a body, no receiver, or a receiver other |
122 | /// than the following: |
123 | /// |
124 | /// - `&self` |
125 | /// - `&mut self` |
126 | /// - `self` |
127 | pub fn from_trait<I>( |
128 | data: &'a EnumData, |
129 | trait_path: &Path, |
130 | supertraits_types: I, |
131 | mut trait_def: ItemTrait, |
132 | ) -> Self |
133 | where |
134 | I: IntoIterator<Item = Ident>, |
135 | I::IntoIter: ExactSizeIterator, |
136 | { |
137 | let mut generics = data.generics.clone(); |
138 | let trait_ = { |
139 | if trait_def.generics.params.is_empty() { |
140 | trait_path.clone() |
141 | } else { |
142 | let ty_generics = trait_def.generics.split_for_impl().1; |
143 | parse_quote!(#trait_path #ty_generics) |
144 | } |
145 | }; |
146 | |
147 | let fst = data.field_types().next().unwrap(); |
148 | let mut types: Vec<_> = trait_def |
149 | .items |
150 | .iter() |
151 | .filter_map(|item| match item { |
152 | TraitItem::Type(ty) => Some((false, Cow::Borrowed(&ty.ident))), |
153 | _ => None, |
154 | }) |
155 | .collect(); |
156 | |
157 | let supertraits_types = supertraits_types.into_iter(); |
158 | if supertraits_types.len() > 0 { |
159 | if let Some(TypeParamBound::Trait(_)) = trait_def.supertraits.iter().next() { |
160 | types.extend(supertraits_types.map(|ident| (true, Cow::Owned(ident)))); |
161 | } |
162 | } |
163 | |
164 | // https://github.com/taiki-e/derive_utils/issues/47 |
165 | let type_params = generics.type_params().map(|p| p.ident.to_string()).collect::<Vec<_>>(); |
166 | let has_method = trait_def.items.iter().any(|i| matches!(i, TraitItem::Fn(..))); |
167 | if !has_method || !type_params.is_empty() { |
168 | struct HasTypeParam<'a>(&'a [String]); |
169 | |
170 | impl HasTypeParam<'_> { |
171 | fn check_ident(&self, ident: &Ident) -> bool { |
172 | let ident = ident.to_string(); |
173 | self.0.contains(&ident) |
174 | } |
175 | |
176 | fn visit_type(&self, ty: &Type) -> bool { |
177 | if let Type::Path(node) = ty { |
178 | if node.qself.is_none() { |
179 | if let Some(ident) = node.path.get_ident() { |
180 | return self.check_ident(ident); |
181 | } |
182 | } |
183 | } |
184 | self.visit_token_stream(ty.to_token_stream()) |
185 | } |
186 | |
187 | fn visit_token_stream(&self, tokens: TokenStream) -> bool { |
188 | for tt in tokens { |
189 | match tt { |
190 | TokenTree::Ident(ident) => { |
191 | if self.check_ident(&ident) { |
192 | return true; |
193 | } |
194 | } |
195 | TokenTree::Group(group) => { |
196 | let content = group.stream(); |
197 | if self.visit_token_stream(content) { |
198 | return true; |
199 | } |
200 | } |
201 | _ => {} |
202 | } |
203 | } |
204 | false |
205 | } |
206 | } |
207 | |
208 | let visitor = HasTypeParam(&type_params); |
209 | let where_clause = &mut generics.make_where_clause().predicates; |
210 | if !has_method || visitor.visit_type(fst) { |
211 | where_clause.push(parse_quote!(#fst: #trait_)); |
212 | } |
213 | if data.field_types().len() > 1 { |
214 | let fst_tokens = fst.to_token_stream().to_string(); |
215 | where_clause.extend(data.field_types().skip(1).filter_map( |
216 | |variant| -> Option<WherePredicate> { |
217 | if has_method && !visitor.visit_type(variant) { |
218 | return None; |
219 | } |
220 | if variant.to_token_stream().to_string() == fst_tokens { |
221 | return None; |
222 | } |
223 | if types.is_empty() { |
224 | return Some(parse_quote!(#variant: #trait_)); |
225 | } |
226 | let types = types.iter().map(|(supertraits, ident)| { |
227 | match trait_def.supertraits.iter().next() { |
228 | Some(TypeParamBound::Trait(trait_)) if *supertraits => { |
229 | quote!(#ident = <#fst as #trait_>::#ident) |
230 | } |
231 | _ => quote!(#ident = <#fst as #trait_>::#ident), |
232 | } |
233 | }); |
234 | if trait_def.generics.params.is_empty() { |
235 | Some(parse_quote!(#variant: #trait_path<#(#types),*>)) |
236 | } else { |
237 | let generics = |
238 | trait_def.generics.params.iter().map(|param| match param { |
239 | GenericParam::Lifetime(def) => def.lifetime.to_token_stream(), |
240 | GenericParam::Type(param) => param.ident.to_token_stream(), |
241 | GenericParam::Const(param) => param.ident.to_token_stream(), |
242 | }); |
243 | Some(parse_quote!(#variant: #trait_path<#(#generics),*, #(#types),*>)) |
244 | } |
245 | }, |
246 | )); |
247 | } |
248 | } |
249 | |
250 | if !trait_def.generics.params.is_empty() { |
251 | generics.params.extend(mem::take(&mut trait_def.generics.params)); |
252 | } |
253 | |
254 | if let Some(old) = trait_def.generics.where_clause.as_mut() { |
255 | if !old.predicates.is_empty() { |
256 | generics.make_where_clause().predicates.extend(mem::take(&mut old.predicates)); |
257 | } |
258 | } |
259 | |
260 | let ident = &data.ident; |
261 | let ty_generics = data.generics.split_for_impl().1; |
262 | let mut impls = Self { |
263 | data, |
264 | defaultness: false, |
265 | unsafety: trait_def.unsafety.is_some(), |
266 | generics, |
267 | trait_: Some(trait_), |
268 | self_ty: Box::new(parse_quote!(#ident #ty_generics)), |
269 | items: Vec::with_capacity(trait_def.items.len()), |
270 | }; |
271 | impls.append_items_from_trait(trait_def); |
272 | impls |
273 | } |
274 | |
275 | pub fn set_trait(&mut self, path: Path) { |
276 | self.trait_ = Some(path); |
277 | } |
278 | |
279 | /// Appends a generic type parameter to the back of generics. |
280 | pub fn push_generic_param(&mut self, param: GenericParam) { |
281 | self.generics.params.push(param); |
282 | } |
283 | |
284 | /// Appends a predicate to the back of `where`-clause. |
285 | pub fn push_where_predicate(&mut self, predicate: WherePredicate) { |
286 | self.generics.make_where_clause().predicates.push(predicate); |
287 | } |
288 | |
289 | /// Appends an item to impl items. |
290 | pub fn push_item(&mut self, item: ImplItem) { |
291 | self.items.push(item); |
292 | } |
293 | |
294 | /// Appends a method to impl items. |
295 | /// |
296 | /// # Panics |
297 | /// |
298 | /// Panics if a trait method has a body, no receiver, or a receiver other |
299 | /// than the following: |
300 | /// |
301 | /// - `&self` |
302 | /// - `&mut self` |
303 | /// - `self` |
304 | pub fn push_method(&mut self, item: TraitItemFn) { |
305 | assert!(item.default.is_none(), "method ` {}` has a body" , item.sig.ident); |
306 | |
307 | let self_ty = ReceiverKind::new(&item.sig); |
308 | let mut args = Vec::with_capacity(item.sig.inputs.len()); |
309 | item.sig.inputs.iter().skip(1).for_each(|arg| match arg { |
310 | FnArg::Typed(arg) => args.push(&arg.pat), |
311 | FnArg::Receiver(_) => panic!( |
312 | "method ` {}` has a receiver in a position other than the first argument" , |
313 | item.sig.ident |
314 | ), |
315 | }); |
316 | |
317 | let method = &item.sig.ident; |
318 | let ident = &self.data.ident; |
319 | let method = match self_ty { |
320 | ReceiverKind::Normal => match &self.trait_ { |
321 | None => { |
322 | let arms = self.data.variant_idents().map(|v| { |
323 | quote! { |
324 | #ident::#v(x) => x.#method(#(#args),*), |
325 | } |
326 | }); |
327 | parse_quote!(match self { #(#arms)* }) |
328 | } |
329 | Some(trait_) => { |
330 | let arms = |
331 | self.data.variant_idents().zip(self.data.field_types()).map(|(v, ty)| { |
332 | quote! { |
333 | #ident::#v(x) => <#ty as #trait_>::#method(x #(,#args)*), |
334 | } |
335 | }); |
336 | parse_quote!(match self { #(#arms)* }) |
337 | } |
338 | }, |
339 | }; |
340 | |
341 | self.push_item(ImplItem::Fn(ImplItemFn { |
342 | attrs: item.attrs, |
343 | vis: Visibility::Inherited, |
344 | defaultness: None, |
345 | sig: item.sig, |
346 | block: Block { |
347 | brace_token: token::Brace::default(), |
348 | stmts: vec![Stmt::Expr(method, None)], |
349 | }, |
350 | })); |
351 | } |
352 | |
353 | /// Appends items from a trait definition to impl items. |
354 | /// |
355 | /// # Panics |
356 | /// |
357 | /// Panics if a trait method has a body, no receiver, or a receiver other |
358 | /// than the following: |
359 | /// |
360 | /// - `&self` |
361 | /// - `&mut self` |
362 | /// - `self` |
363 | pub fn append_items_from_trait(&mut self, trait_def: ItemTrait) { |
364 | let fst = self.data.field_types().next(); |
365 | trait_def.items.into_iter().for_each(|item| match item { |
366 | // The TraitItemType::generics field (Generic associated types (GAT)) are not supported |
367 | TraitItem::Type(TraitItemType { ident, .. }) => { |
368 | let trait_ = &self.trait_; |
369 | let ty = parse_quote!(type #ident = <#fst as #trait_>::#ident;); |
370 | self.push_item(ImplItem::Type(ty)); |
371 | } |
372 | TraitItem::Fn(method) => self.push_method(method), |
373 | _ => {} |
374 | }); |
375 | } |
376 | |
377 | pub fn build(self) -> TokenStream { |
378 | self.build_impl().to_token_stream() |
379 | } |
380 | |
381 | pub fn build_impl(self) -> ItemImpl { |
382 | ItemImpl { |
383 | attrs: vec![], |
384 | defaultness: if self.defaultness { Some(<Token![default]>::default()) } else { None }, |
385 | unsafety: if self.unsafety { Some(<Token![unsafe]>::default()) } else { None }, |
386 | impl_token: token::Impl::default(), |
387 | generics: self.generics, |
388 | trait_: self.trait_.map(|trait_| (None, trait_, <Token![for]>::default())), |
389 | self_ty: self.self_ty, |
390 | brace_token: token::Brace::default(), |
391 | items: self.items, |
392 | } |
393 | } |
394 | } |
395 | |
396 | enum ReceiverKind { |
397 | /// `&(mut) self`, `(mut) self`, `(mut) self: &(mut) Self`, or `(mut) self: Self` |
398 | Normal, |
399 | } |
400 | |
401 | impl ReceiverKind { |
402 | fn new(sig: &Signature) -> Self { |
403 | fn get_ty_path(ty: &Type) -> Option<&Path> { |
404 | if let Type::Path(TypePath { qself: None, path }) = ty { |
405 | Some(path) |
406 | } else { |
407 | None |
408 | } |
409 | } |
410 | |
411 | match sig.receiver() { |
412 | None => panic!("method ` {}` has no receiver" , sig.ident), |
413 | Some(receiver) => { |
414 | if receiver.colon_token.is_none() { |
415 | return ReceiverKind::Normal; |
416 | } |
417 | match &*receiver.ty { |
418 | Type::Path(TypePath { qself: None, path }) => { |
419 | // (mut) self: Self |
420 | if path.is_ident("Self" ) { |
421 | return ReceiverKind::Normal; |
422 | } |
423 | } |
424 | Type::Reference(ty) => { |
425 | // (mut) self: &(mut) Self |
426 | if get_ty_path(&ty.elem).map_or(false, |path| path.is_ident("Self" )) { |
427 | return ReceiverKind::Normal; |
428 | } |
429 | } |
430 | _ => {} |
431 | } |
432 | panic!( |
433 | "method ` {}` has unsupported receiver type: {}" , |
434 | sig.ident, |
435 | receiver.ty.to_token_stream() |
436 | ); |
437 | } |
438 | } |
439 | } |
440 | } |
441 | |