1// SPDX-License-Identifier: Apache-2.0 OR MIT
2
3use std::cell::Cell;
4
5use derive_utils::EnumData as Data;
6use proc_macro2::TokenStream;
7use quote::{quote, ToTokens};
8use syn::{
9 parse::{Parse, ParseStream},
10 parse_quote, Error, ItemEnum, Path, Result, Token,
11};
12
13pub(crate) fn attribute(args: TokenStream, input: TokenStream) -> TokenStream {
14 expand(args, input).unwrap_or_else(op:Error::into_compile_error)
15}
16
/// State shared by all derive implementations run for a single `#[enum_derive]` expansion.
#[derive(Default)]
pub(crate) struct DeriveContext {
    /// Set when any generated impl needs pin projection. `expand` then additionally emits a
    /// conditional `Unpin` impl and a check that the enum does not implement `Drop`.
    needs_pin_projection: Cell<bool>,
}

impl DeriveContext {
    /// Records that the generated code needs pin projection.
    ///
    /// Takes `&self` (the flag lives in a `Cell`) so it can be called through the shared
    /// `&DeriveContext` that derive functions receive. Calling it more than once is harmless.
    pub(crate) fn needs_pin_projection(&self) {
        self.needs_pin_projection.set(true);
    }
}
27
28type DeriveFn = fn(&'_ DeriveContext, &'_ Data) -> Result<TokenStream>;
29
30fn get_derive(s: &str) -> Option<DeriveFn> {
31 macro_rules! match_derive {
32 ($($(#[$meta:meta])* $($arm:ident)::*,)*) => {$(
33 $(#[$meta])*
34 {
35 if crate::derive::$($arm)::*::NAME.iter().any(|name| *name == s) {
36 return Some(crate::derive::$($arm)::*::derive)
37 }
38 }
39 )*};
40 }
41
42 match_derive! {
43 // core
44 #[cfg(feature = "convert")]
45 core::convert::as_mut,
46 #[cfg(feature = "convert")]
47 core::convert::as_ref,
48 core::fmt::debug,
49 core::fmt::display,
50 #[cfg(feature = "fmt")]
51 core::fmt::pointer,
52 #[cfg(feature = "fmt")]
53 core::fmt::binary,
54 #[cfg(feature = "fmt")]
55 core::fmt::octal,
56 #[cfg(feature = "fmt")]
57 core::fmt::upper_hex,
58 #[cfg(feature = "fmt")]
59 core::fmt::lower_hex,
60 #[cfg(feature = "fmt")]
61 core::fmt::upper_exp,
62 #[cfg(feature = "fmt")]
63 core::fmt::lower_exp,
64 core::fmt::write,
65 core::iter::iterator,
66 core::iter::double_ended_iterator,
67 core::iter::exact_size_iterator,
68 core::iter::fused_iterator,
69 #[cfg(feature = "trusted_len")]
70 core::iter::trusted_len,
71 core::iter::extend,
72 #[cfg(feature = "ops")]
73 core::ops::deref,
74 #[cfg(feature = "ops")]
75 core::ops::deref_mut,
76 #[cfg(feature = "ops")]
77 core::ops::index,
78 #[cfg(feature = "ops")]
79 core::ops::index_mut,
80 #[cfg(feature = "ops")]
81 core::ops::range_bounds,
82 #[cfg(feature = "fn_traits")]
83 core::ops::fn_,
84 #[cfg(feature = "fn_traits")]
85 core::ops::fn_mut,
86 #[cfg(feature = "fn_traits")]
87 core::ops::fn_once,
88 #[cfg(feature = "coroutine_trait")]
89 core::ops::coroutine,
90 core::future,
91 // std
92 #[cfg(feature = "std")]
93 std::io::read,
94 #[cfg(feature = "std")]
95 std::io::buf_read,
96 #[cfg(feature = "std")]
97 std::io::seek,
98 #[cfg(feature = "std")]
99 std::io::write,
100 #[cfg(feature = "std")]
101 std::error,
102 // type impls
103 #[cfg(feature = "transpose_methods")]
104 ty_impls::transpose,
105 // futures03
106 #[cfg(feature = "futures03")]
107 external::futures03::stream,
108 #[cfg(feature = "futures03")]
109 external::futures03::sink,
110 #[cfg(feature = "futures03")]
111 external::futures03::async_read,
112 #[cfg(feature = "futures03")]
113 external::futures03::async_write,
114 #[cfg(feature = "futures03")]
115 external::futures03::async_seek,
116 #[cfg(feature = "futures03")]
117 external::futures03::async_buf_read,
118 // futures01
119 #[cfg(feature = "futures01")]
120 external::futures01::future,
121 #[cfg(feature = "futures01")]
122 external::futures01::stream,
123 #[cfg(feature = "futures01")]
124 external::futures01::sink,
125 // rayon
126 #[cfg(feature = "rayon")]
127 external::rayon::par_iter,
128 #[cfg(feature = "rayon")]
129 external::rayon::indexed_par_iter,
130 #[cfg(feature = "rayon")]
131 external::rayon::par_extend,
132 // serde
133 #[cfg(feature = "serde")]
134 external::serde::serialize,
135 // tokio1
136 #[cfg(feature = "tokio1")]
137 external::tokio1::async_read,
138 #[cfg(feature = "tokio1")]
139 external::tokio1::async_write,
140 #[cfg(feature = "tokio1")]
141 external::tokio1::async_seek,
142 #[cfg(feature = "tokio1")]
143 external::tokio1::async_buf_read,
144 // tokio03
145 #[cfg(feature = "tokio03")]
146 external::tokio03::async_read,
147 #[cfg(feature = "tokio03")]
148 external::tokio03::async_write,
149 #[cfg(feature = "tokio03")]
150 external::tokio03::async_seek,
151 #[cfg(feature = "tokio03")]
152 external::tokio03::async_buf_read,
153 // tokio02
154 #[cfg(feature = "tokio02")]
155 external::tokio02::async_read,
156 #[cfg(feature = "tokio02")]
157 external::tokio02::async_write,
158 #[cfg(feature = "tokio02")]
159 external::tokio02::async_seek,
160 #[cfg(feature = "tokio02")]
161 external::tokio02::async_buf_read,
162 // tokio01
163 #[cfg(feature = "tokio01")]
164 external::tokio01::async_read,
165 #[cfg(feature = "tokio01")]
166 external::tokio01::async_write,
167 // http_body1
168 #[cfg(feature = "http_body1")]
169 external::http_body1::body,
170 }
171
172 None
173}
174
175struct Args {
176 inner: Vec<(String, Path)>,
177}
178
179impl Parse for Args {
180 fn parse(input: ParseStream<'_>) -> Result<Self> {
181 fn to_trimmed_string(p: &Path) -> String {
182 p.to_token_stream().to_string().replace(from:' ', to:"")
183 }
184
185 let mut inner: Vec<(String, Path)> = vec![];
186 while !input.is_empty() {
187 let path: Path = input.parse()?;
188 inner.push((to_trimmed_string(&path), path));
189
190 if input.is_empty() {
191 break;
192 }
193 let _: Token![,] = input.parse()?;
194 }
195
196 Ok(Self { inner })
197 }
198}
199
/// Returns the traits that trait `s` depends on, or `None` when the table below lists no
/// dependencies for `s`.
///
/// Entries gated by `#[cfg(feature = ...)]` only exist when that feature is enabled.
fn get_trait_deps(s: &str) -> Option<&'static [&'static str]> {
    let deps: &'static [&'static str] = match s {
        // comparison / marker traits
        "Copy" => &["Clone"],
        "Eq" => &["PartialEq"],
        "PartialOrd" => &["PartialEq"],
        "Ord" => &["PartialOrd", "Eq", "PartialEq"],
        // operators
        #[cfg(feature = "ops")]
        "DerefMut" => &["Deref"],
        #[cfg(feature = "ops")]
        "IndexMut" => &["Index"],
        // closures
        #[cfg(feature = "fn_traits")]
        "Fn" => &["FnMut", "FnOnce"],
        #[cfg(feature = "fn_traits")]
        "FnMut" => &["FnOnce"],
        // iterators
        "DoubleEndedIterator" => &["Iterator"],
        "ExactSizeIterator" => &["Iterator"],
        "FusedIterator" => &["Iterator"],
        #[cfg(feature = "trusted_len")]
        "TrustedLen" => &["Iterator"],
        // std
        #[cfg(feature = "std")]
        "BufRead" => &["Read"],
        #[cfg(feature = "std")]
        "io::BufRead" => &["Read"],
        #[cfg(feature = "std")]
        "Error" => &["Display", "Debug"],
        // external crates
        #[cfg(feature = "rayon")]
        "rayon::IndexedParallelIterator" => &["rayon::ParallelIterator"],
        _ => return None,
    };
    Some(deps)
}
225
226fn exists_alias(s: &str, v: &[(&str, Option<&Path>)]) -> bool {
227 fn get_alias(s: &str) -> Option<&'static str> {
228 macro_rules! match_alias {
229 ($($(#[$meta:meta])* $($arm:ident)::*,)*) => {$(
230 $(#[$meta])*
231 {
232 if s == crate::derive::$($arm)::*::NAME[0] {
233 return Some(crate::derive::$($arm)::*::NAME[1]);
234 } else if s == crate::derive::$($arm)::*::NAME[1] {
235 return Some(crate::derive::$($arm)::*::NAME[0]);
236 }
237 }
238 )*};
239 }
240
241 match_alias! {
242 // core
243 core::fmt::debug,
244 core::fmt::display,
245 // std
246 #[cfg(feature = "std")]
247 std::io::read,
248 #[cfg(feature = "std")]
249 std::io::buf_read,
250 #[cfg(feature = "std")]
251 std::io::seek,
252 #[cfg(feature = "std")]
253 std::io::write,
254 }
255
256 None
257 }
258
259 get_alias(s).map_or(false, |x| v.iter().any(|(s, _)| *s == x))
260}
261
262fn expand(args: TokenStream, input: TokenStream) -> Result<TokenStream> {
263 let data = syn::parse2::<Data>(input)?;
264 let args = syn::parse2::<Args>(args)?.inner;
265 let args = args.iter().fold(vec![], |mut v, (s, arg)| {
266 if let Some(traits) = get_trait_deps(s) {
267 for s in traits.iter().filter(|&x| !args.iter().any(|(s, _)| s == x)) {
268 if !exists_alias(s, &v) {
269 v.push((s, None));
270 }
271 }
272 }
273 if !exists_alias(s, &v) {
274 v.push((s, Some(arg)));
275 }
276 v
277 });
278
279 let mut derive = vec![];
280 let mut items = TokenStream::new();
281 let cx = DeriveContext::default();
282 for (s, arg) in args {
283 match (get_derive(s), arg) {
284 (Some(f), _) => {
285 items.extend(
286 f(&cx, &data).map_err(|e| format_err!(data, "`enum_derive({})` {}", s, e))?,
287 );
288 }
289 (_, Some(arg)) => derive.push(arg),
290 _ => {}
291 }
292 }
293
294 let mut item = if cx.needs_pin_projection.get() {
295 // If a user creates their own Unpin or Drop implementation, trait implementations with
296 // `Pin<&mut self>` receiver can cause unsoundness.
297 //
298 // This was not a problem in #[auto_enum] attribute where enums are anonymized,
299 // but it becomes a problem when users have access to enums (i.e., when using #[enum_derive]).
300 //
301 // So, we ensure safety here by an Unpin implementation that implements Unpin
302 // only if all fields are Unpin (this also forbids custom Unpin implementation),
303 // and a hack that forbids custom Drop implementation. (Both are what pin-project does by default.)
304 // The repr(packed) check is not needed since repr(packed) is not available on enum.
305
306 // Automatically create the appropriate conditional `Unpin` implementation.
307 // https://github.com/taiki-e/pin-project/blob/v1.0.10/examples/struct-default-expanded.rs#L89
308 // TODO: use https://github.com/taiki-e/pin-project/issues/102#issuecomment-540472282's trick.
309 items.extend(derive_utils::derive_trait(
310 &data,
311 &parse_quote!(::core::marker::Unpin),
312 None,
313 parse_quote! {
314 trait Unpin {}
315 },
316 ));
317
318 let item: ItemEnum = data.into();
319 let name = &item.ident;
320 let (impl_generics, ty_generics, where_clause) = item.generics.split_for_impl();
321 // Ensure that enum does not implement `Drop`.
322 // https://github.com/taiki-e/pin-project/blob/v1.0.10/examples/struct-default-expanded.rs#L138
323 items.extend(quote! {
324 const _: () = {
325 trait MustNotImplDrop {}
326 #[allow(clippy::drop_bounds, drop_bounds)]
327 impl<T: ::core::ops::Drop> MustNotImplDrop for T {}
328 impl #impl_generics MustNotImplDrop for #name #ty_generics #where_clause {}
329 };
330 });
331 item
332 } else {
333 data.into()
334 };
335
336 if !derive.is_empty() {
337 item.attrs.push(parse_quote!(#[derive(#(#derive),*)]));
338 }
339
340 let mut item = item.into_token_stream();
341 item.extend(items);
342 Ok(item)
343}
344