1 | #![recursion_limit = "2048" ] |
2 | extern crate proc_macro; |
3 | #[macro_use ] |
4 | extern crate quote; |
5 | |
6 | use proc_macro2::{Span, TokenStream}; |
7 | use std::convert::TryFrom; |
8 | use syn::{ |
9 | parse::{Parse, ParseStream}, |
10 | parse_macro_input, |
11 | spanned::Spanned, |
12 | Expr, Ident, DeriveInput, Data, Token, Variant, |
13 | }; |
14 | |
15 | struct Flag<'a> { |
16 | name: Ident, |
17 | span: Span, |
18 | value: FlagValue<'a>, |
19 | } |
20 | |
21 | enum FlagValue<'a> { |
22 | Literal(u128), |
23 | Deferred, |
24 | Inferred(&'a mut Variant), |
25 | } |
26 | |
27 | impl FlagValue<'_> { |
28 | fn is_inferred(&self) -> bool { |
29 | matches!(self, FlagValue::Inferred(_)) |
30 | } |
31 | } |
32 | |
33 | struct Parameters { |
34 | default: Vec<Ident>, |
35 | } |
36 | |
37 | impl Parse for Parameters { |
38 | fn parse(input: ParseStream) -> syn::parse::Result<Self> { |
39 | if input.is_empty() { |
40 | return Ok(Parameters { default: vec![] }); |
41 | } |
42 | |
43 | input.parse::<Token![default]>()?; |
44 | input.parse::<Token![=]>()?; |
45 | let mut default: Vec = vec![input.parse()?]; |
46 | while !input.is_empty() { |
47 | input.parse::<Token![|]>()?; |
48 | default.push(input.parse()?); |
49 | } |
50 | |
51 | Ok(Parameters { default }) |
52 | } |
53 | } |
54 | |
55 | #[proc_macro_attribute ] |
56 | pub fn bitflags_internal ( |
57 | attr: proc_macro::TokenStream, |
58 | input: proc_macro::TokenStream, |
59 | ) -> proc_macro::TokenStream { |
60 | let Parameters { default: Vec } = parse_macro_input!(attr as Parameters); |
61 | let mut ast: DeriveInput = parse_macro_input!(input as DeriveInput); |
62 | let output: Result = gen_enumflags(&mut ast, default); |
63 | |
64 | outputTokenStream |
65 | .unwrap_or_else(|err: Error| { |
66 | let error: TokenStream = err.to_compile_error(); |
67 | quote! { |
68 | #ast |
69 | #error |
70 | } |
71 | }) |
72 | .into() |
73 | } |
74 | |
75 | /// Try to evaluate the expression given. |
76 | fn fold_expr(expr: &syn::Expr) -> Option<u128> { |
77 | match expr { |
78 | Expr::Lit(ref expr_lit: &ExprLit) => match expr_lit.lit { |
79 | syn::Lit::Int(ref lit_int: &LitInt) => lit_int.base10_parse().ok(), |
80 | _ => None, |
81 | }, |
82 | Expr::Binary(ref expr_binary: &ExprBinary) => { |
83 | let l: u128 = fold_expr(&expr_binary.left)?; |
84 | let r: u128 = fold_expr(&expr_binary.right)?; |
85 | match &expr_binary.op { |
86 | syn::BinOp::Shl(_) => u32::try_from(r).ok().and_then(|r: u32| l.checked_shl(r)), |
87 | _ => None, |
88 | } |
89 | } |
90 | Expr::Paren(syn::ExprParen { expr: &Box, .. }) | Expr::Group(syn::ExprGroup { expr: &Box, .. }) => { |
91 | fold_expr(expr) |
92 | } |
93 | _ => None, |
94 | } |
95 | } |
96 | |
97 | fn collect_flags<'a>( |
98 | variants: impl Iterator<Item = &'a mut Variant>, |
99 | ) -> Result<Vec<Flag<'a>>, syn::Error> { |
100 | variants |
101 | .map(|variant| { |
102 | if !matches!(variant.fields, syn::Fields::Unit) { |
103 | return Err(syn::Error::new_spanned( |
104 | &variant.fields, |
105 | "Bitflag variants cannot contain additional data" , |
106 | )); |
107 | } |
108 | |
109 | let name = variant.ident.clone(); |
110 | let span = variant.span(); |
111 | let value = if let Some(ref expr) = variant.discriminant { |
112 | if let Some(n) = fold_expr(&expr.1) { |
113 | FlagValue::Literal(n) |
114 | } else { |
115 | FlagValue::Deferred |
116 | } |
117 | } else { |
118 | FlagValue::Inferred(variant) |
119 | }; |
120 | |
121 | Ok(Flag { name, span, value }) |
122 | }) |
123 | .collect() |
124 | } |
125 | |
126 | fn inferred_value(type_name: &Ident, previous_variants: &[Ident], repr: &Ident) -> Expr { |
127 | let tokens: TokenStream = if previous_variants.is_empty() { |
128 | quote!(1) |
129 | } else { |
130 | quote!(::enumflags2::_internal::next_bit( |
131 | #(#type_name::#previous_variants as u128)|* |
132 | ) as #repr) |
133 | }; |
134 | |
135 | syn::parse2(tokens).expect(msg:"couldn't parse inferred value" ) |
136 | } |
137 | |
138 | fn infer_values(flags: &mut [Flag], type_name: &Ident, repr: &Ident) { |
139 | let mut previous_variants: Vec<Ident> = flagsimpl Iterator |
140 | .iter() |
141 | .filter(|flag: &&Flag<'_>| !flag.value.is_inferred()) |
142 | .map(|flag: &Flag<'_>| flag.name.clone()) |
143 | .collect(); |
144 | |
145 | for flag: &mut Flag<'_> in flags { |
146 | if let FlagValue::Inferred(ref mut variant: &mut &mut Variant) = flag.value { |
147 | variant.discriminant = Some(( |
148 | <Token![=]>::default(), |
149 | inferred_value(type_name, &previous_variants, repr), |
150 | )); |
151 | previous_variants.push(flag.name.clone()); |
152 | } |
153 | } |
154 | } |
155 | |
156 | /// Given a list of attributes, find the `repr`, if any, and return the integer |
157 | /// type specified. |
158 | fn extract_repr(attrs: &[syn::Attribute]) -> Result<Option<Ident>, syn::Error> { |
159 | let mut res: Option = None; |
160 | for attr: &Attribute in attrs { |
161 | if attr.path().is_ident("repr" ) { |
162 | attr.parse_nested_meta(|meta: ParseNestedMeta<'_>| { |
163 | if let Some(ident: &Ident) = meta.path.get_ident() { |
164 | res = Some(ident.clone()); |
165 | } |
166 | Ok(()) |
167 | })?; |
168 | } |
169 | } |
170 | Ok(res) |
171 | } |
172 | |
173 | /// Check the repr and return the number of bits available |
174 | fn type_bits(ty: &Ident) -> Result<u8, syn::Error> { |
175 | // This would be so much easier if we could just match on an Ident... |
176 | if ty == "usize" { |
177 | Err(syn::Error::new_spanned( |
178 | ty, |
179 | "#[repr(usize)] is not supported. Use u32 or u64 instead." , |
180 | )) |
181 | } else if ty == "i8" |
182 | || ty == "i16" |
183 | || ty == "i32" |
184 | || ty == "i64" |
185 | || ty == "i128" |
186 | || ty == "isize" |
187 | { |
188 | Err(syn::Error::new_spanned( |
189 | ty, |
190 | "Signed types in a repr are not supported." , |
191 | )) |
192 | } else if ty == "u8" { |
193 | Ok(8) |
194 | } else if ty == "u16" { |
195 | Ok(16) |
196 | } else if ty == "u32" { |
197 | Ok(32) |
198 | } else if ty == "u64" { |
199 | Ok(64) |
200 | } else if ty == "u128" { |
201 | Ok(128) |
202 | } else { |
203 | Err(syn::Error::new_spanned( |
204 | ty, |
205 | "repr must be an integer type for #[bitflags]." , |
206 | )) |
207 | } |
208 | } |
209 | |
210 | /// Returns deferred checks |
211 | fn check_flag(type_name: &Ident, flag: &Flag, bits: u8) -> Result<Option<TokenStream>, syn::Error> { |
212 | use FlagValue::*; |
213 | match flag.value { |
214 | Literal(n) => { |
215 | if !n.is_power_of_two() { |
216 | Err(syn::Error::new( |
217 | flag.span, |
218 | "Flags must have exactly one set bit" , |
219 | )) |
220 | } else if bits < 128 && n >= 1 << bits { |
221 | Err(syn::Error::new( |
222 | flag.span, |
223 | format!("Flag value out of range for u {}" , bits), |
224 | )) |
225 | } else { |
226 | Ok(None) |
227 | } |
228 | } |
229 | Inferred(_) => Ok(None), |
230 | Deferred => { |
231 | let variant_name = &flag.name; |
232 | Ok(Some(quote_spanned!(flag.span => |
233 | const _: |
234 | <<[(); ( |
235 | (#type_name::#variant_name as u128).is_power_of_two() |
236 | ) as usize] as ::enumflags2::_internal::AssertionHelper> |
237 | ::Status as ::enumflags2::_internal::ExactlyOneBitSet>::X |
238 | = (); |
239 | ))) |
240 | } |
241 | } |
242 | } |
243 | |
244 | fn gen_enumflags(ast: &mut DeriveInput, default: Vec<Ident>) -> Result<TokenStream, syn::Error> { |
245 | let ident = &ast.ident; |
246 | |
247 | let span = Span::call_site(); |
248 | |
249 | let ast_variants = match &mut ast.data { |
250 | Data::Enum(ref mut data) => &mut data.variants, |
251 | Data::Struct(data) => { |
252 | return Err(syn::Error::new_spanned(&data.struct_token, |
253 | "expected enum for #[bitflags], found struct" )); |
254 | } |
255 | Data::Union(data) => { |
256 | return Err(syn::Error::new_spanned(&data.union_token, |
257 | "expected enum for #[bitflags], found union" )); |
258 | } |
259 | }; |
260 | |
261 | if ast.generics.lt_token.is_some() || ast.generics.where_clause.is_some() { |
262 | return Err(syn::Error::new_spanned(&ast.generics, |
263 | "bitflags cannot be generic" )); |
264 | } |
265 | |
266 | let repr = extract_repr(&ast.attrs)? |
267 | .ok_or_else(|| syn::Error::new_spanned(ident, |
268 | "repr attribute missing. Add #[repr(u64)] or a similar attribute to specify the size of the bitfield." ))?; |
269 | let bits = type_bits(&repr)?; |
270 | |
271 | let mut variants = collect_flags(ast_variants.iter_mut())?; |
272 | let deferred = variants |
273 | .iter() |
274 | .flat_map(|variant| check_flag(ident, variant, bits).transpose()) |
275 | .collect::<Result<Vec<_>, _>>()?; |
276 | |
277 | infer_values(&mut variants, ident, &repr); |
278 | |
279 | if (bits as usize) < variants.len() { |
280 | return Err(syn::Error::new_spanned( |
281 | &repr, |
282 | format!("Not enough bits for {} flags" , variants.len()), |
283 | )); |
284 | } |
285 | |
286 | let std = quote_spanned!(span => ::enumflags2::_internal::core); |
287 | let ast_variants = match &ast.data { |
288 | Data::Enum(ref data) => &data.variants, |
289 | _ => unreachable!(), |
290 | }; |
291 | |
292 | let variant_names = ast_variants.iter().map(|v| &v.ident).collect::<Vec<_>>(); |
293 | |
294 | Ok(quote_spanned! { |
295 | span => |
296 | #ast |
297 | #(#deferred)* |
298 | impl #std::ops::Not for #ident { |
299 | type Output = ::enumflags2::BitFlags<Self>; |
300 | #[inline(always)] |
301 | fn not(self) -> Self::Output { |
302 | use ::enumflags2::BitFlags; |
303 | BitFlags::from_flag(self).not() |
304 | } |
305 | } |
306 | |
307 | impl #std::ops::BitOr for #ident { |
308 | type Output = ::enumflags2::BitFlags<Self>; |
309 | #[inline(always)] |
310 | fn bitor(self, other: Self) -> Self::Output { |
311 | use ::enumflags2::BitFlags; |
312 | BitFlags::from_flag(self) | other |
313 | } |
314 | } |
315 | |
316 | impl #std::ops::BitAnd for #ident { |
317 | type Output = ::enumflags2::BitFlags<Self>; |
318 | #[inline(always)] |
319 | fn bitand(self, other: Self) -> Self::Output { |
320 | use ::enumflags2::BitFlags; |
321 | BitFlags::from_flag(self) & other |
322 | } |
323 | } |
324 | |
325 | impl #std::ops::BitXor for #ident { |
326 | type Output = ::enumflags2::BitFlags<Self>; |
327 | #[inline(always)] |
328 | fn bitxor(self, other: Self) -> Self::Output { |
329 | use ::enumflags2::BitFlags; |
330 | BitFlags::from_flag(self) ^ other |
331 | } |
332 | } |
333 | |
334 | unsafe impl ::enumflags2::_internal::RawBitFlags for #ident { |
335 | type Numeric = #repr; |
336 | |
337 | const EMPTY: Self::Numeric = 0; |
338 | |
339 | const DEFAULT: Self::Numeric = |
340 | 0 #(| (Self::#default as #repr))*; |
341 | |
342 | const ALL_BITS: Self::Numeric = |
343 | 0 #(| (Self::#variant_names as #repr))*; |
344 | |
345 | const BITFLAGS_TYPE_NAME : &'static str = |
346 | concat!("BitFlags<" , stringify!(#ident), ">" ); |
347 | |
348 | fn bits(self) -> Self::Numeric { |
349 | self as #repr |
350 | } |
351 | } |
352 | |
353 | impl ::enumflags2::BitFlag for #ident {} |
354 | }) |
355 | } |
356 | |