1 | #![recursion_limit = "2048" ] |
2 | extern crate proc_macro; |
3 | #[macro_use ] |
4 | extern crate quote; |
5 | |
6 | use proc_macro2::{Span, TokenStream}; |
7 | use std::convert::TryFrom; |
8 | use syn::{ |
9 | parse::{Parse, ParseStream}, |
10 | parse_macro_input, |
11 | spanned::Spanned, |
12 | Expr, Ident, Item, ItemEnum, Token, Variant, |
13 | }; |
14 | |
15 | struct Flag<'a> { |
16 | name: Ident, |
17 | span: Span, |
18 | value: FlagValue<'a>, |
19 | } |
20 | |
21 | enum FlagValue<'a> { |
22 | Literal(u128), |
23 | Deferred, |
24 | Inferred(&'a mut Variant), |
25 | } |
26 | |
27 | impl FlagValue<'_> { |
28 | fn is_inferred(&self) -> bool { |
29 | matches!(self, FlagValue::Inferred(_)) |
30 | } |
31 | } |
32 | |
33 | struct Parameters { |
34 | default: Vec<Ident>, |
35 | } |
36 | |
37 | impl Parse for Parameters { |
38 | fn parse(input: ParseStream) -> syn::parse::Result<Self> { |
39 | if input.is_empty() { |
40 | return Ok(Parameters { default: vec![] }); |
41 | } |
42 | |
43 | input.parse::<Token![default]>()?; |
44 | input.parse::<Token![=]>()?; |
45 | let mut default: Vec = vec![input.parse()?]; |
46 | while !input.is_empty() { |
47 | input.parse::<Token![|]>()?; |
48 | default.push(input.parse()?); |
49 | } |
50 | |
51 | Ok(Parameters { default }) |
52 | } |
53 | } |
54 | |
55 | #[proc_macro_attribute ] |
56 | pub fn bitflags_internal ( |
57 | attr: proc_macro::TokenStream, |
58 | input: proc_macro::TokenStream, |
59 | ) -> proc_macro::TokenStream { |
60 | let Parameters { default: Vec } = parse_macro_input!(attr as Parameters); |
61 | let mut ast: Item = parse_macro_input!(input as Item); |
62 | let output: Result = match ast { |
63 | Item::Enum(ref mut item_enum: &mut ItemEnum) => gen_enumflags(ast:item_enum, default), |
64 | _ => Err(syn::Error::new_spanned( |
65 | &ast, |
66 | message:"#[bitflags] requires an enum" , |
67 | )), |
68 | }; |
69 | |
70 | outputTokenStream |
71 | .unwrap_or_else(|err: Error| { |
72 | let error: TokenStream = err.to_compile_error(); |
73 | quote! { |
74 | #ast |
75 | #error |
76 | } |
77 | }) |
78 | .into() |
79 | } |
80 | |
81 | /// Try to evaluate the expression given. |
82 | fn fold_expr(expr: &syn::Expr) -> Option<u128> { |
83 | match expr { |
84 | Expr::Lit(ref expr_lit: &ExprLit) => match expr_lit.lit { |
85 | syn::Lit::Int(ref lit_int: &LitInt) => lit_int.base10_parse().ok(), |
86 | _ => None, |
87 | }, |
88 | Expr::Binary(ref expr_binary: &ExprBinary) => { |
89 | let l: u128 = fold_expr(&expr_binary.left)?; |
90 | let r: u128 = fold_expr(&expr_binary.right)?; |
91 | match &expr_binary.op { |
92 | syn::BinOp::Shl(_) => u32::try_from(r).ok().and_then(|r: u32| l.checked_shl(r)), |
93 | _ => None, |
94 | } |
95 | } |
96 | Expr::Paren(syn::ExprParen { expr: &Box, .. }) | Expr::Group(syn::ExprGroup { expr: &Box, .. }) => { |
97 | fold_expr(expr) |
98 | } |
99 | _ => None, |
100 | } |
101 | } |
102 | |
103 | fn collect_flags<'a>( |
104 | variants: impl Iterator<Item = &'a mut Variant>, |
105 | ) -> Result<Vec<Flag<'a>>, syn::Error> { |
106 | variants |
107 | .map(|variant| { |
108 | if !matches!(variant.fields, syn::Fields::Unit) { |
109 | return Err(syn::Error::new_spanned( |
110 | &variant.fields, |
111 | "Bitflag variants cannot contain additional data" , |
112 | )); |
113 | } |
114 | |
115 | let name = variant.ident.clone(); |
116 | let span = variant.span(); |
117 | let value = if let Some(ref expr) = variant.discriminant { |
118 | if let Some(n) = fold_expr(&expr.1) { |
119 | FlagValue::Literal(n) |
120 | } else { |
121 | FlagValue::Deferred |
122 | } |
123 | } else { |
124 | FlagValue::Inferred(variant) |
125 | }; |
126 | |
127 | Ok(Flag { name, span, value }) |
128 | }) |
129 | .collect() |
130 | } |
131 | |
132 | fn inferred_value(type_name: &Ident, previous_variants: &[Ident], repr: &Ident) -> Expr { |
133 | let tokens: TokenStream = if previous_variants.is_empty() { |
134 | quote!(1) |
135 | } else { |
136 | quote!(::enumflags2::_internal::next_bit( |
137 | #(#type_name::#previous_variants as u128)|* |
138 | ) as #repr) |
139 | }; |
140 | |
141 | syn::parse2(tokens).expect(msg:"couldn't parse inferred value" ) |
142 | } |
143 | |
144 | fn infer_values(flags: &mut [Flag], type_name: &Ident, repr: &Ident) { |
145 | let mut previous_variants: Vec<Ident> = flagsimpl Iterator |
146 | .iter() |
147 | .filter(|flag: &&Flag<'_>| !flag.value.is_inferred()) |
148 | .map(|flag: &Flag<'_>| flag.name.clone()) |
149 | .collect(); |
150 | |
151 | for flag: &mut Flag<'_> in flags { |
152 | if let FlagValue::Inferred(ref mut variant: &mut &mut Variant) = flag.value { |
153 | variant.discriminant = Some(( |
154 | <Token![=]>::default(), |
155 | inferred_value(type_name, &previous_variants, repr), |
156 | )); |
157 | previous_variants.push(flag.name.clone()); |
158 | } |
159 | } |
160 | } |
161 | |
162 | /// Given a list of attributes, find the `repr`, if any, and return the integer |
163 | /// type specified. |
164 | fn extract_repr(attrs: &[syn::Attribute]) -> Result<Option<Ident>, syn::Error> { |
165 | let mut res: Option = None; |
166 | for attr: &Attribute in attrs { |
167 | if attr.path().is_ident("repr" ) { |
168 | attr.parse_nested_meta(|meta: ParseNestedMeta<'_>| { |
169 | if let Some(ident: &Ident) = meta.path.get_ident() { |
170 | res = Some(ident.clone()); |
171 | } |
172 | Ok(()) |
173 | })?; |
174 | } |
175 | } |
176 | Ok(res) |
177 | } |
178 | |
179 | /// Check the repr and return the number of bits available |
180 | fn type_bits(ty: &Ident) -> Result<u8, syn::Error> { |
181 | // This would be so much easier if we could just match on an Ident... |
182 | if ty == "usize" { |
183 | Err(syn::Error::new_spanned( |
184 | ty, |
185 | "#[repr(usize)] is not supported. Use u32 or u64 instead." , |
186 | )) |
187 | } else if ty == "i8" |
188 | || ty == "i16" |
189 | || ty == "i32" |
190 | || ty == "i64" |
191 | || ty == "i128" |
192 | || ty == "isize" |
193 | { |
194 | Err(syn::Error::new_spanned( |
195 | ty, |
196 | "Signed types in a repr are not supported." , |
197 | )) |
198 | } else if ty == "u8" { |
199 | Ok(8) |
200 | } else if ty == "u16" { |
201 | Ok(16) |
202 | } else if ty == "u32" { |
203 | Ok(32) |
204 | } else if ty == "u64" { |
205 | Ok(64) |
206 | } else if ty == "u128" { |
207 | Ok(128) |
208 | } else { |
209 | Err(syn::Error::new_spanned( |
210 | ty, |
211 | "repr must be an integer type for #[bitflags]." , |
212 | )) |
213 | } |
214 | } |
215 | |
216 | /// Returns deferred checks |
217 | fn check_flag(type_name: &Ident, flag: &Flag, bits: u8) -> Result<Option<TokenStream>, syn::Error> { |
218 | use FlagValue::*; |
219 | match flag.value { |
220 | Literal(n) => { |
221 | if !n.is_power_of_two() { |
222 | Err(syn::Error::new( |
223 | flag.span, |
224 | "Flags must have exactly one set bit" , |
225 | )) |
226 | } else if bits < 128 && n >= 1 << bits { |
227 | Err(syn::Error::new( |
228 | flag.span, |
229 | format!("Flag value out of range for u {}" , bits), |
230 | )) |
231 | } else { |
232 | Ok(None) |
233 | } |
234 | } |
235 | Inferred(_) => Ok(None), |
236 | Deferred => { |
237 | let variant_name = &flag.name; |
238 | Ok(Some(quote_spanned!(flag.span => |
239 | const _: |
240 | <<[(); ( |
241 | (#type_name::#variant_name as u128).is_power_of_two() |
242 | ) as usize] as enumflags2::_internal::AssertionHelper> |
243 | ::Status as enumflags2::_internal::ExactlyOneBitSet>::X |
244 | = (); |
245 | ))) |
246 | } |
247 | } |
248 | } |
249 | |
250 | fn gen_enumflags(ast: &mut ItemEnum, default: Vec<Ident>) -> Result<TokenStream, syn::Error> { |
251 | let ident = &ast.ident; |
252 | |
253 | let span = Span::call_site(); |
254 | |
255 | let repr = extract_repr(&ast.attrs)? |
256 | .ok_or_else(|| syn::Error::new_spanned(ident, |
257 | "repr attribute missing. Add #[repr(u64)] or a similar attribute to specify the size of the bitfield." ))?; |
258 | let bits = type_bits(&repr)?; |
259 | |
260 | let mut variants = collect_flags(ast.variants.iter_mut())?; |
261 | let deferred = variants |
262 | .iter() |
263 | .flat_map(|variant| check_flag(ident, variant, bits).transpose()) |
264 | .collect::<Result<Vec<_>, _>>()?; |
265 | |
266 | infer_values(&mut variants, ident, &repr); |
267 | |
268 | if (bits as usize) < variants.len() { |
269 | return Err(syn::Error::new_spanned( |
270 | &repr, |
271 | format!("Not enough bits for {} flags" , variants.len()), |
272 | )); |
273 | } |
274 | |
275 | let std = quote_spanned!(span => ::enumflags2::_internal::core); |
276 | let variant_names = ast.variants.iter().map(|v| &v.ident).collect::<Vec<_>>(); |
277 | |
278 | Ok(quote_spanned! { |
279 | span => |
280 | #ast |
281 | #(#deferred)* |
282 | impl #std::ops::Not for #ident { |
283 | type Output = ::enumflags2::BitFlags<Self>; |
284 | #[inline(always)] |
285 | fn not(self) -> Self::Output { |
286 | use ::enumflags2::BitFlags; |
287 | BitFlags::from_flag(self).not() |
288 | } |
289 | } |
290 | |
291 | impl #std::ops::BitOr for #ident { |
292 | type Output = ::enumflags2::BitFlags<Self>; |
293 | #[inline(always)] |
294 | fn bitor(self, other: Self) -> Self::Output { |
295 | use ::enumflags2::BitFlags; |
296 | BitFlags::from_flag(self) | other |
297 | } |
298 | } |
299 | |
300 | impl #std::ops::BitAnd for #ident { |
301 | type Output = ::enumflags2::BitFlags<Self>; |
302 | #[inline(always)] |
303 | fn bitand(self, other: Self) -> Self::Output { |
304 | use ::enumflags2::BitFlags; |
305 | BitFlags::from_flag(self) & other |
306 | } |
307 | } |
308 | |
309 | impl #std::ops::BitXor for #ident { |
310 | type Output = ::enumflags2::BitFlags<Self>; |
311 | #[inline(always)] |
312 | fn bitxor(self, other: Self) -> Self::Output { |
313 | use ::enumflags2::BitFlags; |
314 | BitFlags::from_flag(self) ^ other |
315 | } |
316 | } |
317 | |
318 | unsafe impl ::enumflags2::_internal::RawBitFlags for #ident { |
319 | type Numeric = #repr; |
320 | |
321 | const EMPTY: Self::Numeric = 0; |
322 | |
323 | const DEFAULT: Self::Numeric = |
324 | 0 #(| (Self::#default as #repr))*; |
325 | |
326 | const ALL_BITS: Self::Numeric = |
327 | 0 #(| (Self::#variant_names as #repr))*; |
328 | |
329 | const BITFLAGS_TYPE_NAME : &'static str = |
330 | concat!("BitFlags<" , stringify!(#ident), ">" ); |
331 | |
332 | fn bits(self) -> Self::Numeric { |
333 | self as #repr |
334 | } |
335 | } |
336 | |
337 | impl ::enumflags2::BitFlag for #ident {} |
338 | }) |
339 | } |
340 | |