1 | // SPDX-License-Identifier: Apache-2.0 OR MIT |
2 | |
3 | use std::cell::Cell; |
4 | |
5 | use derive_utils::EnumData as Data; |
6 | use proc_macro2::TokenStream; |
7 | use quote::{quote, ToTokens}; |
8 | use syn::{ |
9 | parse::{Parse, ParseStream}, |
10 | parse_quote, Error, ItemEnum, Path, Result, Token, |
11 | }; |
12 | |
13 | pub(crate) fn attribute(args: TokenStream, input: TokenStream) -> TokenStream { |
14 | expand(args, input).unwrap_or_else(op:Error::into_compile_error) |
15 | } |
16 | |
/// State shared by the individual derive implementations during a single
/// `#[enum_derive]` expansion.
#[derive(Default)]
pub(crate) struct DeriveContext {
    /// Starts `false`; set once any derive reports that its generated code needs
    /// pin projection. `expand` reads it to decide whether to emit the
    /// `Unpin` / no-`Drop` safety guards.
    needs_pin_projection: Cell<bool>,
}

impl DeriveContext {
    /// Records that the generated code needs pin projection.
    ///
    /// Takes `&self` because the flag lives in a `Cell`, so derives holding only
    /// a shared reference to the context can still set it. Calling it more than
    /// once has no further effect.
    pub(crate) fn needs_pin_projection(&self) {
        self.needs_pin_projection.set(true);
    }
}
27 | |
28 | type DeriveFn = fn(&'_ DeriveContext, &'_ Data) -> Result<TokenStream>; |
29 | |
30 | fn get_derive(s: &str) -> Option<DeriveFn> { |
31 | macro_rules! match_derive { |
32 | ($($(#[$meta:meta])* $($arm:ident)::*,)*) => {$( |
33 | $(#[$meta])* |
34 | { |
35 | if crate::derive::$($arm)::*::NAME.iter().any(|name| *name == s) { |
36 | return Some(crate::derive::$($arm)::*::derive) |
37 | } |
38 | } |
39 | )*}; |
40 | } |
41 | |
42 | match_derive! { |
43 | // core |
44 | #[cfg (feature = "convert" )] |
45 | core::convert::as_mut, |
46 | #[cfg (feature = "convert" )] |
47 | core::convert::as_ref, |
48 | core::fmt::debug, |
49 | core::fmt::display, |
50 | #[cfg (feature = "fmt" )] |
51 | core::fmt::pointer, |
52 | #[cfg (feature = "fmt" )] |
53 | core::fmt::binary, |
54 | #[cfg (feature = "fmt" )] |
55 | core::fmt::octal, |
56 | #[cfg (feature = "fmt" )] |
57 | core::fmt::upper_hex, |
58 | #[cfg (feature = "fmt" )] |
59 | core::fmt::lower_hex, |
60 | #[cfg (feature = "fmt" )] |
61 | core::fmt::upper_exp, |
62 | #[cfg (feature = "fmt" )] |
63 | core::fmt::lower_exp, |
64 | core::fmt::write, |
65 | core::iter::iterator, |
66 | core::iter::double_ended_iterator, |
67 | core::iter::exact_size_iterator, |
68 | core::iter::fused_iterator, |
69 | #[cfg (feature = "trusted_len" )] |
70 | core::iter::trusted_len, |
71 | core::iter::extend, |
72 | #[cfg (feature = "ops" )] |
73 | core::ops::deref, |
74 | #[cfg (feature = "ops" )] |
75 | core::ops::deref_mut, |
76 | #[cfg (feature = "ops" )] |
77 | core::ops::index, |
78 | #[cfg (feature = "ops" )] |
79 | core::ops::index_mut, |
80 | #[cfg (feature = "ops" )] |
81 | core::ops::range_bounds, |
82 | #[cfg (feature = "fn_traits" )] |
83 | core::ops::fn_, |
84 | #[cfg (feature = "fn_traits" )] |
85 | core::ops::fn_mut, |
86 | #[cfg (feature = "fn_traits" )] |
87 | core::ops::fn_once, |
88 | #[cfg (feature = "coroutine_trait" )] |
89 | core::ops::coroutine, |
90 | core::future, |
91 | // std |
92 | #[cfg (feature = "std" )] |
93 | std::io::read, |
94 | #[cfg (feature = "std" )] |
95 | std::io::buf_read, |
96 | #[cfg (feature = "std" )] |
97 | std::io::seek, |
98 | #[cfg (feature = "std" )] |
99 | std::io::write, |
100 | #[cfg (feature = "std" )] |
101 | std::error, |
102 | // type impls |
103 | #[cfg (feature = "transpose_methods" )] |
104 | ty_impls::transpose, |
105 | // futures03 |
106 | #[cfg (feature = "futures03" )] |
107 | external::futures03::stream, |
108 | #[cfg (feature = "futures03" )] |
109 | external::futures03::sink, |
110 | #[cfg (feature = "futures03" )] |
111 | external::futures03::async_read, |
112 | #[cfg (feature = "futures03" )] |
113 | external::futures03::async_write, |
114 | #[cfg (feature = "futures03" )] |
115 | external::futures03::async_seek, |
116 | #[cfg (feature = "futures03" )] |
117 | external::futures03::async_buf_read, |
118 | // futures01 |
119 | #[cfg (feature = "futures01" )] |
120 | external::futures01::future, |
121 | #[cfg (feature = "futures01" )] |
122 | external::futures01::stream, |
123 | #[cfg (feature = "futures01" )] |
124 | external::futures01::sink, |
125 | // rayon |
126 | #[cfg (feature = "rayon" )] |
127 | external::rayon::par_iter, |
128 | #[cfg (feature = "rayon" )] |
129 | external::rayon::indexed_par_iter, |
130 | #[cfg (feature = "rayon" )] |
131 | external::rayon::par_extend, |
132 | // serde |
133 | #[cfg (feature = "serde" )] |
134 | external::serde::serialize, |
135 | // tokio1 |
136 | #[cfg (feature = "tokio1" )] |
137 | external::tokio1::async_read, |
138 | #[cfg (feature = "tokio1" )] |
139 | external::tokio1::async_write, |
140 | #[cfg (feature = "tokio1" )] |
141 | external::tokio1::async_seek, |
142 | #[cfg (feature = "tokio1" )] |
143 | external::tokio1::async_buf_read, |
144 | // tokio03 |
145 | #[cfg (feature = "tokio03" )] |
146 | external::tokio03::async_read, |
147 | #[cfg (feature = "tokio03" )] |
148 | external::tokio03::async_write, |
149 | #[cfg (feature = "tokio03" )] |
150 | external::tokio03::async_seek, |
151 | #[cfg (feature = "tokio03" )] |
152 | external::tokio03::async_buf_read, |
153 | // tokio02 |
154 | #[cfg (feature = "tokio02" )] |
155 | external::tokio02::async_read, |
156 | #[cfg (feature = "tokio02" )] |
157 | external::tokio02::async_write, |
158 | #[cfg (feature = "tokio02" )] |
159 | external::tokio02::async_seek, |
160 | #[cfg (feature = "tokio02" )] |
161 | external::tokio02::async_buf_read, |
162 | // tokio01 |
163 | #[cfg (feature = "tokio01" )] |
164 | external::tokio01::async_read, |
165 | #[cfg (feature = "tokio01" )] |
166 | external::tokio01::async_write, |
167 | // http_body1 |
168 | #[cfg (feature = "http_body1" )] |
169 | external::http_body1::body, |
170 | } |
171 | |
172 | None |
173 | } |
174 | |
175 | struct Args { |
176 | inner: Vec<(String, Path)>, |
177 | } |
178 | |
179 | impl Parse for Args { |
180 | fn parse(input: ParseStream<'_>) -> Result<Self> { |
181 | fn to_trimmed_string(p: &Path) -> String { |
182 | p.to_token_stream().to_string().replace(from:' ' , to:"" ) |
183 | } |
184 | |
185 | let mut inner: Vec<(String, Path)> = vec![]; |
186 | while !input.is_empty() { |
187 | let path: Path = input.parse()?; |
188 | inner.push((to_trimmed_string(&path), path)); |
189 | |
190 | if input.is_empty() { |
191 | break; |
192 | } |
193 | let _: Token![,] = input.parse()?; |
194 | } |
195 | |
196 | Ok(Self { inner }) |
197 | } |
198 | } |
199 | |
/// Returns the traits that deriving `s` additionally requires (its supertraits),
/// or `None` when `s` has no such dependencies listed here.
///
/// Names are matched exactly against the space-free spellings produced by
/// `Args`. Entries gated on a disabled feature are compiled out.
fn get_trait_deps(s: &str) -> Option<&'static [&'static str]> {
    let deps: &'static [&'static str] = match s {
        "Copy" => &["Clone"],
        "Eq" | "PartialOrd" => &["PartialEq"],
        "Ord" => &["PartialOrd", "Eq", "PartialEq"],
        "DoubleEndedIterator" | "ExactSizeIterator" | "FusedIterator" => &["Iterator"],
        #[cfg(feature = "ops")]
        "DerefMut" => &["Deref"],
        #[cfg(feature = "ops")]
        "IndexMut" => &["Index"],
        #[cfg(feature = "fn_traits")]
        "Fn" => &["FnMut", "FnOnce"],
        #[cfg(feature = "fn_traits")]
        "FnMut" => &["FnOnce"],
        #[cfg(feature = "trusted_len")]
        "TrustedLen" => &["Iterator"],
        #[cfg(feature = "std")]
        "BufRead" | "io::BufRead" => &["Read"],
        #[cfg(feature = "std")]
        "Error" => &["Display", "Debug"],
        #[cfg(feature = "rayon")]
        "rayon::IndexedParallelIterator" => &["rayon::ParallelIterator"],
        _ => return None,
    };
    Some(deps)
}
225 | |
226 | fn exists_alias(s: &str, v: &[(&str, Option<&Path>)]) -> bool { |
227 | fn get_alias(s: &str) -> Option<&'static str> { |
228 | macro_rules! match_alias { |
229 | ($($(#[$meta:meta])* $($arm:ident)::*,)*) => {$( |
230 | $(#[$meta])* |
231 | { |
232 | if s == crate::derive::$($arm)::*::NAME[0] { |
233 | return Some(crate::derive::$($arm)::*::NAME[1]); |
234 | } else if s == crate::derive::$($arm)::*::NAME[1] { |
235 | return Some(crate::derive::$($arm)::*::NAME[0]); |
236 | } |
237 | } |
238 | )*}; |
239 | } |
240 | |
241 | match_alias! { |
242 | // core |
243 | core::fmt::debug, |
244 | core::fmt::display, |
245 | // std |
246 | #[cfg (feature = "std" )] |
247 | std::io::read, |
248 | #[cfg (feature = "std" )] |
249 | std::io::buf_read, |
250 | #[cfg (feature = "std" )] |
251 | std::io::seek, |
252 | #[cfg (feature = "std" )] |
253 | std::io::write, |
254 | } |
255 | |
256 | None |
257 | } |
258 | |
259 | get_alias(s).map_or(false, |x| v.iter().any(|(s, _)| *s == x)) |
260 | } |
261 | |
262 | fn expand(args: TokenStream, input: TokenStream) -> Result<TokenStream> { |
263 | let data = syn::parse2::<Data>(input)?; |
264 | let args = syn::parse2::<Args>(args)?.inner; |
265 | let args = args.iter().fold(vec![], |mut v, (s, arg)| { |
266 | if let Some(traits) = get_trait_deps(s) { |
267 | for s in traits.iter().filter(|&x| !args.iter().any(|(s, _)| s == x)) { |
268 | if !exists_alias(s, &v) { |
269 | v.push((s, None)); |
270 | } |
271 | } |
272 | } |
273 | if !exists_alias(s, &v) { |
274 | v.push((s, Some(arg))); |
275 | } |
276 | v |
277 | }); |
278 | |
279 | let mut derive = vec![]; |
280 | let mut items = TokenStream::new(); |
281 | let cx = DeriveContext::default(); |
282 | for (s, arg) in args { |
283 | match (get_derive(s), arg) { |
284 | (Some(f), _) => { |
285 | items.extend( |
286 | f(&cx, &data).map_err(|e| format_err!(data, "`enum_derive( {})` {}" , s, e))?, |
287 | ); |
288 | } |
289 | (_, Some(arg)) => derive.push(arg), |
290 | _ => {} |
291 | } |
292 | } |
293 | |
294 | let mut item = if cx.needs_pin_projection.get() { |
295 | // If a user creates their own Unpin or Drop implementation, trait implementations with |
296 | // `Pin<&mut self>` receiver can cause unsoundness. |
297 | // |
298 | // This was not a problem in #[auto_enum] attribute where enums are anonymized, |
299 | // but it becomes a problem when users have access to enums (i.e., when using #[enum_derive]). |
300 | // |
301 | // So, we ensure safety here by an Unpin implementation that implements Unpin |
302 | // only if all fields are Unpin (this also forbids custom Unpin implementation), |
303 | // and a hack that forbids custom Drop implementation. (Both are what pin-project does by default.) |
304 | // The repr(packed) check is not needed since repr(packed) is not available on enum. |
305 | |
306 | // Automatically create the appropriate conditional `Unpin` implementation. |
307 | // https://github.com/taiki-e/pin-project/blob/v1.0.10/examples/struct-default-expanded.rs#L89 |
308 | // TODO: use https://github.com/taiki-e/pin-project/issues/102#issuecomment-540472282's trick. |
309 | items.extend(derive_utils::derive_trait( |
310 | &data, |
311 | &parse_quote!(::core::marker::Unpin), |
312 | None, |
313 | parse_quote! { |
314 | trait Unpin {} |
315 | }, |
316 | )); |
317 | |
318 | let item: ItemEnum = data.into(); |
319 | let name = &item.ident; |
320 | let (impl_generics, ty_generics, where_clause) = item.generics.split_for_impl(); |
321 | // Ensure that enum does not implement `Drop`. |
322 | // https://github.com/taiki-e/pin-project/blob/v1.0.10/examples/struct-default-expanded.rs#L138 |
323 | items.extend(quote! { |
324 | const _: () = { |
325 | trait MustNotImplDrop {} |
326 | #[allow(clippy::drop_bounds, drop_bounds)] |
327 | impl<T: ::core::ops::Drop> MustNotImplDrop for T {} |
328 | impl #impl_generics MustNotImplDrop for #name #ty_generics #where_clause {} |
329 | }; |
330 | }); |
331 | item |
332 | } else { |
333 | data.into() |
334 | }; |
335 | |
336 | if !derive.is_empty() { |
337 | item.attrs.push(parse_quote!(#[derive(#(#derive),*)])); |
338 | } |
339 | |
340 | let mut item = item.into_token_stream(); |
341 | item.extend(items); |
342 | Ok(item) |
343 | } |
344 | |