1// SPDX-License-Identifier: Apache-2.0 OR MIT
2
3use core::mem;
4use std::borrow::Cow;
5
6use proc_macro2::{TokenStream, TokenTree};
7use quote::{quote, ToTokens};
8use syn::{
9 parse_quote, token, Block, FnArg, GenericParam, Generics, Ident, ImplItem, ImplItemFn,
10 ItemImpl, ItemTrait, Path, Signature, Stmt, Token, TraitItem, TraitItemFn, TraitItemType, Type,
11 TypeParamBound, TypePath, Visibility, WherePredicate,
12};
13
14use crate::ast::EnumData;
15
16/// A function for creating `proc_macro_derive` like deriving trait to enum so
17/// long as all variants are implemented that trait.
18///
19/// # Examples
20///
21/// ```
22/// # extern crate proc_macro;
23/// #
24/// use derive_utils::derive_trait;
25/// use proc_macro::TokenStream;
26/// use quote::format_ident;
27/// use syn::{parse_macro_input, parse_quote};
28///
29/// # #[cfg(any())]
30/// #[proc_macro_derive(Iterator)]
31/// # pub fn _derive_iterator(_: TokenStream) -> TokenStream { unimplemented!() }
32/// pub fn derive_iterator(input: TokenStream) -> TokenStream {
33/// derive_trait(
34/// &parse_macro_input!(input),
35/// // trait path
36/// &parse_quote!(std::iter::Iterator),
37/// // super trait's associated types
38/// None,
39/// // trait definition
40/// parse_quote! {
41/// trait Iterator {
42/// type Item;
43/// fn next(&mut self) -> Option<Self::Item>;
44/// fn size_hint(&self) -> (usize, Option<usize>);
45/// }
46/// },
47/// )
48/// .into()
49/// }
50///
51/// # #[cfg(any())]
52/// #[proc_macro_derive(ExactSizeIterator)]
53/// # pub fn _derive_exact_size_iterator(_: TokenStream) -> TokenStream { unimplemented!() }
54/// pub fn derive_exact_size_iterator(input: TokenStream) -> TokenStream {
55/// derive_trait(
56/// &parse_macro_input!(input),
57/// // trait path
58/// &parse_quote!(std::iter::ExactSizeIterator),
59/// // super trait's associated types
60/// Some(format_ident!("Item")),
61/// // trait definition
62/// parse_quote! {
63/// trait ExactSizeIterator: Iterator {
64/// fn len(&self) -> usize;
65/// }
66/// },
67/// )
68/// .into()
69/// }
70/// ```
71pub fn derive_trait<I>(
72 data: &EnumData,
73 trait_path: &Path,
74 supertraits_types: I,
75 trait_def: ItemTrait,
76) -> TokenStream
77where
78 I: IntoIterator<Item = Ident>,
79 I::IntoIter: ExactSizeIterator,
80{
81 EnumImpl::from_trait(data, trait_path, supertraits_types, trait_def).build()
82}
83
84/// A builder for implementing a trait for enums.
85pub struct EnumImpl<'a> {
86 data: &'a EnumData,
87 defaultness: bool,
88 unsafety: bool,
89 generics: Generics,
90 trait_: Option<Path>,
91 self_ty: Box<Type>,
92 items: Vec<ImplItem>,
93}
94
95impl<'a> EnumImpl<'a> {
96 /// Creates a new `EnumImpl`.
97 pub fn new(data: &'a EnumData) -> Self {
98 let ident = &data.ident;
99 let ty_generics = data.generics.split_for_impl().1;
100 Self {
101 data,
102 defaultness: false,
103 unsafety: false,
104 generics: data.generics.clone(),
105 trait_: None,
106 self_ty: Box::new(parse_quote!(#ident #ty_generics)),
107 items: vec![],
108 }
109 }
110
111 /// Creates a new `EnumImpl` from a trait definition.
112 ///
113 /// The following items are ignored:
114 /// - Generic associated types (GAT) ([`TraitItem::Type`] that has generics)
115 /// - [`TraitItem::Const`]
116 /// - [`TraitItem::Macro`]
117 /// - [`TraitItem::Verbatim`]
118 ///
119 /// # Panics
120 ///
121 /// Panics if a trait method has a body, no receiver, or a receiver other
122 /// than the following:
123 ///
124 /// - `&self`
125 /// - `&mut self`
126 /// - `self`
127 pub fn from_trait<I>(
128 data: &'a EnumData,
129 trait_path: &Path,
130 supertraits_types: I,
131 mut trait_def: ItemTrait,
132 ) -> Self
133 where
134 I: IntoIterator<Item = Ident>,
135 I::IntoIter: ExactSizeIterator,
136 {
137 let mut generics = data.generics.clone();
138 let trait_ = {
139 if trait_def.generics.params.is_empty() {
140 trait_path.clone()
141 } else {
142 let ty_generics = trait_def.generics.split_for_impl().1;
143 parse_quote!(#trait_path #ty_generics)
144 }
145 };
146
147 let fst = data.field_types().next().unwrap();
148 let mut types: Vec<_> = trait_def
149 .items
150 .iter()
151 .filter_map(|item| match item {
152 TraitItem::Type(ty) => Some((false, Cow::Borrowed(&ty.ident))),
153 _ => None,
154 })
155 .collect();
156
157 let supertraits_types = supertraits_types.into_iter();
158 if supertraits_types.len() > 0 {
159 if let Some(TypeParamBound::Trait(_)) = trait_def.supertraits.iter().next() {
160 types.extend(supertraits_types.map(|ident| (true, Cow::Owned(ident))));
161 }
162 }
163
164 // https://github.com/taiki-e/derive_utils/issues/47
165 let type_params = generics.type_params().map(|p| p.ident.to_string()).collect::<Vec<_>>();
166 let has_method = trait_def.items.iter().any(|i| matches!(i, TraitItem::Fn(..)));
167 if !has_method || !type_params.is_empty() {
168 struct HasTypeParam<'a>(&'a [String]);
169
170 impl HasTypeParam<'_> {
171 fn check_ident(&self, ident: &Ident) -> bool {
172 let ident = ident.to_string();
173 self.0.contains(&ident)
174 }
175
176 fn visit_type(&self, ty: &Type) -> bool {
177 if let Type::Path(node) = ty {
178 if node.qself.is_none() {
179 if let Some(ident) = node.path.get_ident() {
180 return self.check_ident(ident);
181 }
182 }
183 }
184 self.visit_token_stream(ty.to_token_stream())
185 }
186
187 fn visit_token_stream(&self, tokens: TokenStream) -> bool {
188 for tt in tokens {
189 match tt {
190 TokenTree::Ident(ident) => {
191 if self.check_ident(&ident) {
192 return true;
193 }
194 }
195 TokenTree::Group(group) => {
196 let content = group.stream();
197 if self.visit_token_stream(content) {
198 return true;
199 }
200 }
201 _ => {}
202 }
203 }
204 false
205 }
206 }
207
208 let visitor = HasTypeParam(&type_params);
209 let where_clause = &mut generics.make_where_clause().predicates;
210 if !has_method || visitor.visit_type(fst) {
211 where_clause.push(parse_quote!(#fst: #trait_));
212 }
213 if data.field_types().len() > 1 {
214 let fst_tokens = fst.to_token_stream().to_string();
215 where_clause.extend(data.field_types().skip(1).filter_map(
216 |variant| -> Option<WherePredicate> {
217 if has_method && !visitor.visit_type(variant) {
218 return None;
219 }
220 if variant.to_token_stream().to_string() == fst_tokens {
221 return None;
222 }
223 if types.is_empty() {
224 return Some(parse_quote!(#variant: #trait_));
225 }
226 let types = types.iter().map(|(supertraits, ident)| {
227 match trait_def.supertraits.iter().next() {
228 Some(TypeParamBound::Trait(trait_)) if *supertraits => {
229 quote!(#ident = <#fst as #trait_>::#ident)
230 }
231 _ => quote!(#ident = <#fst as #trait_>::#ident),
232 }
233 });
234 if trait_def.generics.params.is_empty() {
235 Some(parse_quote!(#variant: #trait_path<#(#types),*>))
236 } else {
237 let generics =
238 trait_def.generics.params.iter().map(|param| match param {
239 GenericParam::Lifetime(def) => def.lifetime.to_token_stream(),
240 GenericParam::Type(param) => param.ident.to_token_stream(),
241 GenericParam::Const(param) => param.ident.to_token_stream(),
242 });
243 Some(parse_quote!(#variant: #trait_path<#(#generics),*, #(#types),*>))
244 }
245 },
246 ));
247 }
248 }
249
250 if !trait_def.generics.params.is_empty() {
251 generics.params.extend(mem::take(&mut trait_def.generics.params));
252 }
253
254 if let Some(old) = trait_def.generics.where_clause.as_mut() {
255 if !old.predicates.is_empty() {
256 generics.make_where_clause().predicates.extend(mem::take(&mut old.predicates));
257 }
258 }
259
260 let ident = &data.ident;
261 let ty_generics = data.generics.split_for_impl().1;
262 let mut impls = Self {
263 data,
264 defaultness: false,
265 unsafety: trait_def.unsafety.is_some(),
266 generics,
267 trait_: Some(trait_),
268 self_ty: Box::new(parse_quote!(#ident #ty_generics)),
269 items: Vec::with_capacity(trait_def.items.len()),
270 };
271 impls.append_items_from_trait(trait_def);
272 impls
273 }
274
275 pub fn set_trait(&mut self, path: Path) {
276 self.trait_ = Some(path);
277 }
278
279 /// Appends a generic type parameter to the back of generics.
280 pub fn push_generic_param(&mut self, param: GenericParam) {
281 self.generics.params.push(param);
282 }
283
284 /// Appends a predicate to the back of `where`-clause.
285 pub fn push_where_predicate(&mut self, predicate: WherePredicate) {
286 self.generics.make_where_clause().predicates.push(predicate);
287 }
288
289 /// Appends an item to impl items.
290 pub fn push_item(&mut self, item: ImplItem) {
291 self.items.push(item);
292 }
293
294 /// Appends a method to impl items.
295 ///
296 /// # Panics
297 ///
298 /// Panics if a trait method has a body, no receiver, or a receiver other
299 /// than the following:
300 ///
301 /// - `&self`
302 /// - `&mut self`
303 /// - `self`
304 pub fn push_method(&mut self, item: TraitItemFn) {
305 assert!(item.default.is_none(), "method `{}` has a body", item.sig.ident);
306
307 let self_ty = ReceiverKind::new(&item.sig);
308 let mut args = Vec::with_capacity(item.sig.inputs.len());
309 item.sig.inputs.iter().skip(1).for_each(|arg| match arg {
310 FnArg::Typed(arg) => args.push(&arg.pat),
311 FnArg::Receiver(_) => panic!(
312 "method `{}` has a receiver in a position other than the first argument",
313 item.sig.ident
314 ),
315 });
316
317 let method = &item.sig.ident;
318 let ident = &self.data.ident;
319 let method = match self_ty {
320 ReceiverKind::Normal => match &self.trait_ {
321 None => {
322 let arms = self.data.variant_idents().map(|v| {
323 quote! {
324 #ident::#v(x) => x.#method(#(#args),*),
325 }
326 });
327 parse_quote!(match self { #(#arms)* })
328 }
329 Some(trait_) => {
330 let arms =
331 self.data.variant_idents().zip(self.data.field_types()).map(|(v, ty)| {
332 quote! {
333 #ident::#v(x) => <#ty as #trait_>::#method(x #(,#args)*),
334 }
335 });
336 parse_quote!(match self { #(#arms)* })
337 }
338 },
339 };
340
341 self.push_item(ImplItem::Fn(ImplItemFn {
342 attrs: item.attrs,
343 vis: Visibility::Inherited,
344 defaultness: None,
345 sig: item.sig,
346 block: Block {
347 brace_token: token::Brace::default(),
348 stmts: vec![Stmt::Expr(method, None)],
349 },
350 }));
351 }
352
353 /// Appends items from a trait definition to impl items.
354 ///
355 /// # Panics
356 ///
357 /// Panics if a trait method has a body, no receiver, or a receiver other
358 /// than the following:
359 ///
360 /// - `&self`
361 /// - `&mut self`
362 /// - `self`
363 pub fn append_items_from_trait(&mut self, trait_def: ItemTrait) {
364 let fst = self.data.field_types().next();
365 trait_def.items.into_iter().for_each(|item| match item {
366 // The TraitItemType::generics field (Generic associated types (GAT)) are not supported
367 TraitItem::Type(TraitItemType { ident, .. }) => {
368 let trait_ = &self.trait_;
369 let ty = parse_quote!(type #ident = <#fst as #trait_>::#ident;);
370 self.push_item(ImplItem::Type(ty));
371 }
372 TraitItem::Fn(method) => self.push_method(method),
373 _ => {}
374 });
375 }
376
377 pub fn build(self) -> TokenStream {
378 self.build_impl().to_token_stream()
379 }
380
381 pub fn build_impl(self) -> ItemImpl {
382 ItemImpl {
383 attrs: vec![],
384 defaultness: if self.defaultness { Some(<Token![default]>::default()) } else { None },
385 unsafety: if self.unsafety { Some(<Token![unsafe]>::default()) } else { None },
386 impl_token: token::Impl::default(),
387 generics: self.generics,
388 trait_: self.trait_.map(|trait_| (None, trait_, <Token![for]>::default())),
389 self_ty: self.self_ty,
390 brace_token: token::Brace::default(),
391 items: self.items,
392 }
393 }
394}
395
/// The kinds of method receiver for which a forwarding body can be generated.
enum ReceiverKind {
    /// `&(mut) self`, `(mut) self`, `(mut) self: &(mut) Self`, or `(mut) self: Self`
    Normal,
}
400
401impl ReceiverKind {
402 fn new(sig: &Signature) -> Self {
403 fn get_ty_path(ty: &Type) -> Option<&Path> {
404 if let Type::Path(TypePath { qself: None, path }) = ty {
405 Some(path)
406 } else {
407 None
408 }
409 }
410
411 match sig.receiver() {
412 None => panic!("method `{}` has no receiver", sig.ident),
413 Some(receiver) => {
414 if receiver.colon_token.is_none() {
415 return ReceiverKind::Normal;
416 }
417 match &*receiver.ty {
418 Type::Path(TypePath { qself: None, path }) => {
419 // (mut) self: Self
420 if path.is_ident("Self") {
421 return ReceiverKind::Normal;
422 }
423 }
424 Type::Reference(ty) => {
425 // (mut) self: &(mut) Self
426 if get_ty_path(&ty.elem).map_or(false, |path| path.is_ident("Self")) {
427 return ReceiverKind::Normal;
428 }
429 }
430 _ => {}
431 }
432 panic!(
433 "method `{}` has unsupported receiver type: {}",
434 sig.ident,
435 receiver.ty.to_token_stream()
436 );
437 }
438 }
439 }
440}
441