1// SPDX-License-Identifier: Apache-2.0 OR MIT
2
3use core::mem;
4use std::borrow::Cow;
5
6use proc_macro2::{TokenStream, TokenTree};
7use quote::{quote, ToTokens as _};
8use syn::{
9 parse_quote, token, Block, FnArg, GenericParam, Generics, Ident, ImplItem, ImplItemFn,
10 ItemImpl, ItemTrait, Path, Signature, Stmt, Token, TraitItem, TraitItemFn, TraitItemType, Type,
11 TypeParamBound, TypePath, Visibility, WherePredicate,
12};
13
14use crate::ast::EnumData;
15
16/// A function for creating `proc_macro_derive` like deriving trait to enum so
17/// long as all variants are implemented that trait.
18///
19/// # Examples
20///
21/// ```
22/// # extern crate proc_macro;
23/// use derive_utils::derive_trait;
24/// use proc_macro::TokenStream;
25/// use quote::format_ident;
26/// use syn::{parse_macro_input, parse_quote};
27///
28/// # #[cfg(any())]
29/// #[proc_macro_derive(Iterator)]
30/// # pub fn _derive_iterator(_: TokenStream) -> TokenStream { unimplemented!() }
31/// pub fn derive_iterator(input: TokenStream) -> TokenStream {
32/// derive_trait(
33/// &parse_macro_input!(input),
34/// // trait path
35/// &parse_quote!(std::iter::Iterator),
36/// // super trait's associated types
37/// None,
38/// // trait definition
39/// parse_quote! {
40/// trait Iterator {
41/// type Item;
42/// fn next(&mut self) -> Option<Self::Item>;
43/// fn size_hint(&self) -> (usize, Option<usize>);
44/// }
45/// },
46/// )
47/// .into()
48/// }
49///
50/// # #[cfg(any())]
51/// #[proc_macro_derive(ExactSizeIterator)]
52/// # pub fn _derive_exact_size_iterator(_: TokenStream) -> TokenStream { unimplemented!() }
53/// pub fn derive_exact_size_iterator(input: TokenStream) -> TokenStream {
54/// derive_trait(
55/// &parse_macro_input!(input),
56/// // trait path
57/// &parse_quote!(std::iter::ExactSizeIterator),
58/// // super trait's associated types
59/// Some(format_ident!("Item")),
60/// // trait definition
61/// parse_quote! {
62/// trait ExactSizeIterator: Iterator {
63/// fn len(&self) -> usize;
64/// }
65/// },
66/// )
67/// .into()
68/// }
69/// ```
70pub fn derive_trait<I>(
71 data: &EnumData,
72 trait_path: &Path,
73 supertraits_types: I,
74 trait_def: ItemTrait,
75) -> TokenStream
76where
77 I: IntoIterator<Item = Ident>,
78 I::IntoIter: ExactSizeIterator,
79{
80 EnumImpl::from_trait(data, trait_path, supertraits_types, trait_def).build()
81}
82
83/// A builder for implementing a trait for enums.
84pub struct EnumImpl<'a> {
85 data: &'a EnumData,
86 defaultness: bool,
87 unsafety: bool,
88 generics: Generics,
89 trait_: Option<Path>,
90 self_ty: Box<Type>,
91 items: Vec<ImplItem>,
92}
93
94impl<'a> EnumImpl<'a> {
95 /// Creates a new `EnumImpl`.
96 pub fn new(data: &'a EnumData) -> Self {
97 let ident = &data.ident;
98 let ty_generics = data.generics.split_for_impl().1;
99 Self {
100 data,
101 defaultness: false,
102 unsafety: false,
103 generics: data.generics.clone(),
104 trait_: None,
105 self_ty: Box::new(parse_quote!(#ident #ty_generics)),
106 items: vec![],
107 }
108 }
109
110 /// Creates a new `EnumImpl` from a trait definition.
111 ///
112 /// The following items are ignored:
113 /// - Generic associated types (GAT) ([`TraitItem::Type`] that has generics)
114 /// - [`TraitItem::Const`]
115 /// - [`TraitItem::Macro`]
116 /// - [`TraitItem::Verbatim`]
117 ///
118 /// # Panics
119 ///
120 /// Panics if a trait method has a body, no receiver, or a receiver other
121 /// than the following:
122 ///
123 /// - `&self`
124 /// - `&mut self`
125 /// - `self`
126 pub fn from_trait<I>(
127 data: &'a EnumData,
128 trait_path: &Path,
129 supertraits_types: I,
130 mut trait_def: ItemTrait,
131 ) -> Self
132 where
133 I: IntoIterator<Item = Ident>,
134 I::IntoIter: ExactSizeIterator,
135 {
136 let mut generics = data.generics.clone();
137 let trait_ = {
138 if trait_def.generics.params.is_empty() {
139 trait_path.clone()
140 } else {
141 let ty_generics = trait_def.generics.split_for_impl().1;
142 parse_quote!(#trait_path #ty_generics)
143 }
144 };
145
146 let fst = data.field_types().next().unwrap();
147 let mut types: Vec<_> = trait_def
148 .items
149 .iter()
150 .filter_map(|item| match item {
151 TraitItem::Type(ty) => Some((false, Cow::Borrowed(&ty.ident))),
152 _ => None,
153 })
154 .collect();
155
156 let supertraits_types = supertraits_types.into_iter();
157 if supertraits_types.len() > 0 {
158 if let Some(TypeParamBound::Trait(_)) = trait_def.supertraits.iter().next() {
159 types.extend(supertraits_types.map(|ident| (true, Cow::Owned(ident))));
160 }
161 }
162
163 // https://github.com/taiki-e/derive_utils/issues/47
164 let type_params = generics.type_params().map(|p| p.ident.to_string()).collect::<Vec<_>>();
165 let has_method = trait_def.items.iter().any(|i| matches!(i, TraitItem::Fn(..)));
166 if !has_method || !type_params.is_empty() {
167 struct HasTypeParam<'a>(&'a [String]);
168
169 impl HasTypeParam<'_> {
170 fn check_ident(&self, ident: &Ident) -> bool {
171 let ident = ident.to_string();
172 self.0.contains(&ident)
173 }
174
175 fn visit_type(&self, ty: &Type) -> bool {
176 if let Type::Path(node) = ty {
177 if node.qself.is_none() {
178 if let Some(ident) = node.path.get_ident() {
179 return self.check_ident(ident);
180 }
181 }
182 }
183 self.visit_token_stream(ty.to_token_stream())
184 }
185
186 fn visit_token_stream(&self, tokens: TokenStream) -> bool {
187 for tt in tokens {
188 match tt {
189 TokenTree::Ident(ident) => {
190 if self.check_ident(&ident) {
191 return true;
192 }
193 }
194 TokenTree::Group(group) => {
195 let content = group.stream();
196 if self.visit_token_stream(content) {
197 return true;
198 }
199 }
200 _ => {}
201 }
202 }
203 false
204 }
205 }
206
207 let visitor = HasTypeParam(&type_params);
208 let where_clause = &mut generics.make_where_clause().predicates;
209 if !has_method || visitor.visit_type(fst) {
210 where_clause.push(parse_quote!(#fst: #trait_));
211 }
212 if data.field_types().len() > 1 {
213 let fst_tokens = fst.to_token_stream().to_string();
214 where_clause.extend(data.field_types().skip(1).filter_map(
215 |variant| -> Option<WherePredicate> {
216 if has_method && !visitor.visit_type(variant) {
217 return None;
218 }
219 if variant.to_token_stream().to_string() == fst_tokens {
220 return None;
221 }
222 if types.is_empty() {
223 return Some(parse_quote!(#variant: #trait_));
224 }
225 let types = types.iter().map(|(supertraits, ident)| {
226 match trait_def.supertraits.iter().next() {
227 Some(TypeParamBound::Trait(trait_)) if *supertraits => {
228 quote!(#ident = <#fst as #trait_>::#ident)
229 }
230 _ => quote!(#ident = <#fst as #trait_>::#ident),
231 }
232 });
233 if trait_def.generics.params.is_empty() {
234 Some(parse_quote!(#variant: #trait_path<#(#types),*>))
235 } else {
236 let generics =
237 trait_def.generics.params.iter().map(|param| match param {
238 GenericParam::Lifetime(def) => def.lifetime.to_token_stream(),
239 GenericParam::Type(param) => param.ident.to_token_stream(),
240 GenericParam::Const(param) => param.ident.to_token_stream(),
241 });
242 Some(parse_quote!(#variant: #trait_path<#(#generics),*, #(#types),*>))
243 }
244 },
245 ));
246 }
247 }
248
249 if !trait_def.generics.params.is_empty() {
250 generics.params.extend(mem::take(&mut trait_def.generics.params));
251 }
252
253 if let Some(old) = trait_def.generics.where_clause.as_mut() {
254 if !old.predicates.is_empty() {
255 generics.make_where_clause().predicates.extend(mem::take(&mut old.predicates));
256 }
257 }
258
259 let ident = &data.ident;
260 let ty_generics = data.generics.split_for_impl().1;
261 let mut impls = Self {
262 data,
263 defaultness: false,
264 unsafety: trait_def.unsafety.is_some(),
265 generics,
266 trait_: Some(trait_),
267 self_ty: Box::new(parse_quote!(#ident #ty_generics)),
268 items: Vec::with_capacity(trait_def.items.len()),
269 };
270 impls.append_items_from_trait(trait_def);
271 impls
272 }
273
274 pub fn set_trait(&mut self, path: Path) {
275 self.trait_ = Some(path);
276 }
277
278 /// Appends a generic type parameter to the back of generics.
279 pub fn push_generic_param(&mut self, param: GenericParam) {
280 self.generics.params.push(param);
281 }
282
283 /// Appends a predicate to the back of `where`-clause.
284 pub fn push_where_predicate(&mut self, predicate: WherePredicate) {
285 self.generics.make_where_clause().predicates.push(predicate);
286 }
287
288 /// Appends an item to impl items.
289 pub fn push_item(&mut self, item: ImplItem) {
290 self.items.push(item);
291 }
292
293 /// Appends a method to impl items.
294 ///
295 /// # Panics
296 ///
297 /// Panics if a trait method has a body, no receiver, or a receiver other
298 /// than the following:
299 ///
300 /// - `&self`
301 /// - `&mut self`
302 /// - `self`
303 pub fn push_method(&mut self, item: TraitItemFn) {
304 assert!(item.default.is_none(), "method `{}` has a body", item.sig.ident);
305
306 let self_ty = ReceiverKind::new(&item.sig);
307 let mut args = Vec::with_capacity(item.sig.inputs.len());
308 item.sig.inputs.iter().skip(1).for_each(|arg| match arg {
309 FnArg::Typed(arg) => args.push(&arg.pat),
310 FnArg::Receiver(_) => panic!(
311 "method `{}` has a receiver in a position other than the first argument",
312 item.sig.ident
313 ),
314 });
315
316 let method = &item.sig.ident;
317 let ident = &self.data.ident;
318 let method = match self_ty {
319 ReceiverKind::Normal => match &self.trait_ {
320 None => {
321 let arms = self.data.variant_idents().map(|v| {
322 quote! {
323 #ident::#v(x) => x.#method(#(#args),*),
324 }
325 });
326 parse_quote!(match self { #(#arms)* })
327 }
328 Some(trait_) => {
329 let arms =
330 self.data.variant_idents().zip(self.data.field_types()).map(|(v, ty)| {
331 quote! {
332 #ident::#v(x) => <#ty as #trait_>::#method(x #(,#args)*),
333 }
334 });
335 parse_quote!(match self { #(#arms)* })
336 }
337 },
338 };
339
340 self.push_item(ImplItem::Fn(ImplItemFn {
341 attrs: item.attrs,
342 vis: Visibility::Inherited,
343 defaultness: None,
344 sig: item.sig,
345 block: Block {
346 brace_token: token::Brace::default(),
347 stmts: vec![Stmt::Expr(method, None)],
348 },
349 }));
350 }
351
352 /// Appends items from a trait definition to impl items.
353 ///
354 /// # Panics
355 ///
356 /// Panics if a trait method has a body, no receiver, or a receiver other
357 /// than the following:
358 ///
359 /// - `&self`
360 /// - `&mut self`
361 /// - `self`
362 pub fn append_items_from_trait(&mut self, trait_def: ItemTrait) {
363 let fst = self.data.field_types().next();
364 trait_def.items.into_iter().for_each(|item| match item {
365 // The TraitItemType::generics field (Generic associated types (GAT)) are not supported
366 TraitItem::Type(TraitItemType { ident, .. }) => {
367 let trait_ = &self.trait_;
368 let ty = parse_quote!(type #ident = <#fst as #trait_>::#ident;);
369 self.push_item(ImplItem::Type(ty));
370 }
371 TraitItem::Fn(method) => self.push_method(method),
372 _ => {}
373 });
374 }
375
376 pub fn build(self) -> TokenStream {
377 self.build_impl().to_token_stream()
378 }
379
380 pub fn build_impl(self) -> ItemImpl {
381 ItemImpl {
382 attrs: vec![parse_quote!(#[automatically_derived])],
383 defaultness: if self.defaultness { Some(<Token![default]>::default()) } else { None },
384 unsafety: if self.unsafety { Some(<Token![unsafe]>::default()) } else { None },
385 impl_token: token::Impl::default(),
386 generics: self.generics,
387 trait_: self.trait_.map(|trait_| (None, trait_, <Token![for]>::default())),
388 self_ty: self.self_ty,
389 brace_token: token::Brace::default(),
390 items: self.items,
391 }
392 }
393}
394
395enum ReceiverKind {
396 /// `&(mut) self`, `(mut) self`, `(mut) self: &(mut) Self`, or `(mut) self: Self`
397 Normal,
398}
399
400impl ReceiverKind {
401 fn new(sig: &Signature) -> Self {
402 fn get_ty_path(ty: &Type) -> Option<&Path> {
403 if let Type::Path(TypePath { qself: None, path }) = ty {
404 Some(path)
405 } else {
406 None
407 }
408 }
409
410 match sig.receiver() {
411 None => panic!("method `{}` has no receiver", sig.ident),
412 Some(receiver) => {
413 if receiver.colon_token.is_none() {
414 return ReceiverKind::Normal;
415 }
416 match &*receiver.ty {
417 Type::Path(TypePath { qself: None, path }) => {
418 // (mut) self: Self
419 if path.is_ident("Self") {
420 return ReceiverKind::Normal;
421 }
422 }
423 Type::Reference(ty) => {
424 // (mut) self: &(mut) Self
425 if get_ty_path(&ty.elem).map_or(false, |path| path.is_ident("Self")) {
426 return ReceiverKind::Normal;
427 }
428 }
429 _ => {}
430 }
431 panic!(
432 "method `{}` has unsupported receiver type: {}",
433 sig.ident,
434 receiver.ty.to_token_stream()
435 );
436 }
437 }
438 }
439}
440