1#![recursion_limit = "2048"]
2extern crate proc_macro;
3#[macro_use]
4extern crate quote;
5
6use proc_macro2::{Span, TokenStream};
7use std::convert::TryFrom;
8use syn::{
9 parse::{Parse, ParseStream},
10 parse_macro_input,
11 spanned::Spanned,
12 Expr, Ident, Item, ItemEnum, Token, Variant,
13};
14
15struct Flag<'a> {
16 name: Ident,
17 span: Span,
18 value: FlagValue<'a>,
19}
20
21enum FlagValue<'a> {
22 Literal(u128),
23 Deferred,
24 Inferred(&'a mut Variant),
25}
26
27impl FlagValue<'_> {
28 fn is_inferred(&self) -> bool {
29 matches!(self, FlagValue::Inferred(_))
30 }
31}
32
33struct Parameters {
34 default: Vec<Ident>,
35}
36
37impl Parse for Parameters {
38 fn parse(input: ParseStream) -> syn::parse::Result<Self> {
39 if input.is_empty() {
40 return Ok(Parameters { default: vec![] });
41 }
42
43 input.parse::<Token![default]>()?;
44 input.parse::<Token![=]>()?;
45 let mut default: Vec = vec![input.parse()?];
46 while !input.is_empty() {
47 input.parse::<Token![|]>()?;
48 default.push(input.parse()?);
49 }
50
51 Ok(Parameters { default })
52 }
53}
54
55#[proc_macro_attribute]
56pub fn bitflags_internal(
57 attr: proc_macro::TokenStream,
58 input: proc_macro::TokenStream,
59) -> proc_macro::TokenStream {
60 let Parameters { default: Vec } = parse_macro_input!(attr as Parameters);
61 let mut ast: Item = parse_macro_input!(input as Item);
62 let output: Result = match ast {
63 Item::Enum(ref mut item_enum: &mut ItemEnum) => gen_enumflags(ast:item_enum, default),
64 _ => Err(syn::Error::new_spanned(
65 &ast,
66 message:"#[bitflags] requires an enum",
67 )),
68 };
69
70 outputTokenStream
71 .unwrap_or_else(|err: Error| {
72 let error: TokenStream = err.to_compile_error();
73 quote! {
74 #ast
75 #error
76 }
77 })
78 .into()
79}
80
81/// Try to evaluate the expression given.
82fn fold_expr(expr: &syn::Expr) -> Option<u128> {
83 match expr {
84 Expr::Lit(ref expr_lit: &ExprLit) => match expr_lit.lit {
85 syn::Lit::Int(ref lit_int: &LitInt) => lit_int.base10_parse().ok(),
86 _ => None,
87 },
88 Expr::Binary(ref expr_binary: &ExprBinary) => {
89 let l: u128 = fold_expr(&expr_binary.left)?;
90 let r: u128 = fold_expr(&expr_binary.right)?;
91 match &expr_binary.op {
92 syn::BinOp::Shl(_) => u32::try_from(r).ok().and_then(|r: u32| l.checked_shl(r)),
93 _ => None,
94 }
95 }
96 Expr::Paren(syn::ExprParen { expr: &Box, .. }) | Expr::Group(syn::ExprGroup { expr: &Box, .. }) => {
97 fold_expr(expr)
98 }
99 _ => None,
100 }
101}
102
103fn collect_flags<'a>(
104 variants: impl Iterator<Item = &'a mut Variant>,
105) -> Result<Vec<Flag<'a>>, syn::Error> {
106 variants
107 .map(|variant| {
108 if !matches!(variant.fields, syn::Fields::Unit) {
109 return Err(syn::Error::new_spanned(
110 &variant.fields,
111 "Bitflag variants cannot contain additional data",
112 ));
113 }
114
115 let name = variant.ident.clone();
116 let span = variant.span();
117 let value = if let Some(ref expr) = variant.discriminant {
118 if let Some(n) = fold_expr(&expr.1) {
119 FlagValue::Literal(n)
120 } else {
121 FlagValue::Deferred
122 }
123 } else {
124 FlagValue::Inferred(variant)
125 };
126
127 Ok(Flag { name, span, value })
128 })
129 .collect()
130}
131
132fn inferred_value(type_name: &Ident, previous_variants: &[Ident], repr: &Ident) -> Expr {
133 let tokens: TokenStream = if previous_variants.is_empty() {
134 quote!(1)
135 } else {
136 quote!(::enumflags2::_internal::next_bit(
137 #(#type_name::#previous_variants as u128)|*
138 ) as #repr)
139 };
140
141 syn::parse2(tokens).expect(msg:"couldn't parse inferred value")
142}
143
144fn infer_values(flags: &mut [Flag], type_name: &Ident, repr: &Ident) {
145 let mut previous_variants: Vec<Ident> = flagsimpl Iterator
146 .iter()
147 .filter(|flag: &&Flag<'_>| !flag.value.is_inferred())
148 .map(|flag: &Flag<'_>| flag.name.clone())
149 .collect();
150
151 for flag: &mut Flag<'_> in flags {
152 if let FlagValue::Inferred(ref mut variant: &mut &mut Variant) = flag.value {
153 variant.discriminant = Some((
154 <Token![=]>::default(),
155 inferred_value(type_name, &previous_variants, repr),
156 ));
157 previous_variants.push(flag.name.clone());
158 }
159 }
160}
161
162/// Given a list of attributes, find the `repr`, if any, and return the integer
163/// type specified.
164fn extract_repr(attrs: &[syn::Attribute]) -> Result<Option<Ident>, syn::Error> {
165 let mut res: Option = None;
166 for attr: &Attribute in attrs {
167 if attr.path().is_ident("repr") {
168 attr.parse_nested_meta(|meta: ParseNestedMeta<'_>| {
169 if let Some(ident: &Ident) = meta.path.get_ident() {
170 res = Some(ident.clone());
171 }
172 Ok(())
173 })?;
174 }
175 }
176 Ok(res)
177}
178
179/// Check the repr and return the number of bits available
180fn type_bits(ty: &Ident) -> Result<u8, syn::Error> {
181 // This would be so much easier if we could just match on an Ident...
182 if ty == "usize" {
183 Err(syn::Error::new_spanned(
184 ty,
185 "#[repr(usize)] is not supported. Use u32 or u64 instead.",
186 ))
187 } else if ty == "i8"
188 || ty == "i16"
189 || ty == "i32"
190 || ty == "i64"
191 || ty == "i128"
192 || ty == "isize"
193 {
194 Err(syn::Error::new_spanned(
195 ty,
196 "Signed types in a repr are not supported.",
197 ))
198 } else if ty == "u8" {
199 Ok(8)
200 } else if ty == "u16" {
201 Ok(16)
202 } else if ty == "u32" {
203 Ok(32)
204 } else if ty == "u64" {
205 Ok(64)
206 } else if ty == "u128" {
207 Ok(128)
208 } else {
209 Err(syn::Error::new_spanned(
210 ty,
211 "repr must be an integer type for #[bitflags].",
212 ))
213 }
214}
215
216/// Returns deferred checks
217fn check_flag(type_name: &Ident, flag: &Flag, bits: u8) -> Result<Option<TokenStream>, syn::Error> {
218 use FlagValue::*;
219 match flag.value {
220 Literal(n) => {
221 if !n.is_power_of_two() {
222 Err(syn::Error::new(
223 flag.span,
224 "Flags must have exactly one set bit",
225 ))
226 } else if bits < 128 && n >= 1 << bits {
227 Err(syn::Error::new(
228 flag.span,
229 format!("Flag value out of range for u{}", bits),
230 ))
231 } else {
232 Ok(None)
233 }
234 }
235 Inferred(_) => Ok(None),
236 Deferred => {
237 let variant_name = &flag.name;
238 Ok(Some(quote_spanned!(flag.span =>
239 const _:
240 <<[(); (
241 (#type_name::#variant_name as u128).is_power_of_two()
242 ) as usize] as enumflags2::_internal::AssertionHelper>
243 ::Status as enumflags2::_internal::ExactlyOneBitSet>::X
244 = ();
245 )))
246 }
247 }
248}
249
250fn gen_enumflags(ast: &mut ItemEnum, default: Vec<Ident>) -> Result<TokenStream, syn::Error> {
251 let ident = &ast.ident;
252
253 let span = Span::call_site();
254
255 let repr = extract_repr(&ast.attrs)?
256 .ok_or_else(|| syn::Error::new_spanned(ident,
257 "repr attribute missing. Add #[repr(u64)] or a similar attribute to specify the size of the bitfield."))?;
258 let bits = type_bits(&repr)?;
259
260 let mut variants = collect_flags(ast.variants.iter_mut())?;
261 let deferred = variants
262 .iter()
263 .flat_map(|variant| check_flag(ident, variant, bits).transpose())
264 .collect::<Result<Vec<_>, _>>()?;
265
266 infer_values(&mut variants, ident, &repr);
267
268 if (bits as usize) < variants.len() {
269 return Err(syn::Error::new_spanned(
270 &repr,
271 format!("Not enough bits for {} flags", variants.len()),
272 ));
273 }
274
275 let std = quote_spanned!(span => ::enumflags2::_internal::core);
276 let variant_names = ast.variants.iter().map(|v| &v.ident).collect::<Vec<_>>();
277
278 Ok(quote_spanned! {
279 span =>
280 #ast
281 #(#deferred)*
282 impl #std::ops::Not for #ident {
283 type Output = ::enumflags2::BitFlags<Self>;
284 #[inline(always)]
285 fn not(self) -> Self::Output {
286 use ::enumflags2::BitFlags;
287 BitFlags::from_flag(self).not()
288 }
289 }
290
291 impl #std::ops::BitOr for #ident {
292 type Output = ::enumflags2::BitFlags<Self>;
293 #[inline(always)]
294 fn bitor(self, other: Self) -> Self::Output {
295 use ::enumflags2::BitFlags;
296 BitFlags::from_flag(self) | other
297 }
298 }
299
300 impl #std::ops::BitAnd for #ident {
301 type Output = ::enumflags2::BitFlags<Self>;
302 #[inline(always)]
303 fn bitand(self, other: Self) -> Self::Output {
304 use ::enumflags2::BitFlags;
305 BitFlags::from_flag(self) & other
306 }
307 }
308
309 impl #std::ops::BitXor for #ident {
310 type Output = ::enumflags2::BitFlags<Self>;
311 #[inline(always)]
312 fn bitxor(self, other: Self) -> Self::Output {
313 use ::enumflags2::BitFlags;
314 BitFlags::from_flag(self) ^ other
315 }
316 }
317
318 unsafe impl ::enumflags2::_internal::RawBitFlags for #ident {
319 type Numeric = #repr;
320
321 const EMPTY: Self::Numeric = 0;
322
323 const DEFAULT: Self::Numeric =
324 0 #(| (Self::#default as #repr))*;
325
326 const ALL_BITS: Self::Numeric =
327 0 #(| (Self::#variant_names as #repr))*;
328
329 const BITFLAGS_TYPE_NAME : &'static str =
330 concat!("BitFlags<", stringify!(#ident), ">");
331
332 fn bits(self) -> Self::Numeric {
333 self as #repr
334 }
335 }
336
337 impl ::enumflags2::BitFlag for #ident {}
338 })
339}
340