| 1 | #![recursion_limit = "2048" ] |
| 2 | extern crate proc_macro; |
| 3 | #[macro_use ] |
| 4 | extern crate quote; |
| 5 | |
| 6 | use proc_macro2::{Span, TokenStream}; |
| 7 | use std::convert::TryFrom; |
| 8 | use syn::{ |
| 9 | parse::{Parse, ParseStream}, |
| 10 | parse_macro_input, |
| 11 | spanned::Spanned, |
| 12 | Expr, Ident, DeriveInput, Data, Token, Variant, |
| 13 | }; |
| 14 | |
| 15 | struct Flag<'a> { |
| 16 | name: Ident, |
| 17 | span: Span, |
| 18 | value: FlagValue<'a>, |
| 19 | } |
| 20 | |
| 21 | enum FlagValue<'a> { |
| 22 | Literal(u128), |
| 23 | Deferred, |
| 24 | Inferred(&'a mut Variant), |
| 25 | } |
| 26 | |
| 27 | impl FlagValue<'_> { |
| 28 | fn is_inferred(&self) -> bool { |
| 29 | matches!(self, FlagValue::Inferred(_)) |
| 30 | } |
| 31 | } |
| 32 | |
| 33 | struct Parameters { |
| 34 | default: Vec<Ident>, |
| 35 | } |
| 36 | |
| 37 | impl Parse for Parameters { |
| 38 | fn parse(input: ParseStream) -> syn::parse::Result<Self> { |
| 39 | if input.is_empty() { |
| 40 | return Ok(Parameters { default: vec![] }); |
| 41 | } |
| 42 | |
| 43 | input.parse::<Token![default]>()?; |
| 44 | input.parse::<Token![=]>()?; |
| 45 | let mut default: Vec = vec![input.parse()?]; |
| 46 | while !input.is_empty() { |
| 47 | input.parse::<Token![|]>()?; |
| 48 | default.push(input.parse()?); |
| 49 | } |
| 50 | |
| 51 | Ok(Parameters { default }) |
| 52 | } |
| 53 | } |
| 54 | |
| 55 | #[proc_macro_attribute ] |
| 56 | pub fn bitflags_internal ( |
| 57 | attr: proc_macro::TokenStream, |
| 58 | input: proc_macro::TokenStream, |
| 59 | ) -> proc_macro::TokenStream { |
| 60 | let Parameters { default: Vec } = parse_macro_input!(attr as Parameters); |
| 61 | let mut ast: DeriveInput = parse_macro_input!(input as DeriveInput); |
| 62 | let output: Result = gen_enumflags(&mut ast, default); |
| 63 | |
| 64 | outputTokenStream |
| 65 | .unwrap_or_else(|err: Error| { |
| 66 | let error: TokenStream = err.to_compile_error(); |
| 67 | quote! { |
| 68 | #ast |
| 69 | #error |
| 70 | } |
| 71 | }) |
| 72 | .into() |
| 73 | } |
| 74 | |
| 75 | /// Try to evaluate the expression given. |
| 76 | fn fold_expr(expr: &syn::Expr) -> Option<u128> { |
| 77 | match expr { |
| 78 | Expr::Lit(ref expr_lit: &ExprLit) => match expr_lit.lit { |
| 79 | syn::Lit::Int(ref lit_int: &LitInt) => lit_int.base10_parse().ok(), |
| 80 | _ => None, |
| 81 | }, |
| 82 | Expr::Binary(ref expr_binary: &ExprBinary) => { |
| 83 | let l: u128 = fold_expr(&expr_binary.left)?; |
| 84 | let r: u128 = fold_expr(&expr_binary.right)?; |
| 85 | match &expr_binary.op { |
| 86 | syn::BinOp::Shl(_) => u32::try_from(r).ok().and_then(|r: u32| l.checked_shl(r)), |
| 87 | _ => None, |
| 88 | } |
| 89 | } |
| 90 | Expr::Paren(syn::ExprParen { expr: &Box, .. }) | Expr::Group(syn::ExprGroup { expr: &Box, .. }) => { |
| 91 | fold_expr(expr) |
| 92 | } |
| 93 | _ => None, |
| 94 | } |
| 95 | } |
| 96 | |
| 97 | fn collect_flags<'a>( |
| 98 | variants: impl Iterator<Item = &'a mut Variant>, |
| 99 | ) -> Result<Vec<Flag<'a>>, syn::Error> { |
| 100 | variants |
| 101 | .map(|variant| { |
| 102 | if !matches!(variant.fields, syn::Fields::Unit) { |
| 103 | return Err(syn::Error::new_spanned( |
| 104 | &variant.fields, |
| 105 | "Bitflag variants cannot contain additional data" , |
| 106 | )); |
| 107 | } |
| 108 | |
| 109 | let name = variant.ident.clone(); |
| 110 | let span = variant.span(); |
| 111 | let value = if let Some(ref expr) = variant.discriminant { |
| 112 | if let Some(n) = fold_expr(&expr.1) { |
| 113 | FlagValue::Literal(n) |
| 114 | } else { |
| 115 | FlagValue::Deferred |
| 116 | } |
| 117 | } else { |
| 118 | FlagValue::Inferred(variant) |
| 119 | }; |
| 120 | |
| 121 | Ok(Flag { name, span, value }) |
| 122 | }) |
| 123 | .collect() |
| 124 | } |
| 125 | |
| 126 | fn inferred_value(type_name: &Ident, previous_variants: &[Ident], repr: &Ident) -> Expr { |
| 127 | let tokens: TokenStream = if previous_variants.is_empty() { |
| 128 | quote!(1) |
| 129 | } else { |
| 130 | quote!(::enumflags2::_internal::next_bit( |
| 131 | #(#type_name::#previous_variants as u128)|* |
| 132 | ) as #repr) |
| 133 | }; |
| 134 | |
| 135 | syn::parse2(tokens).expect(msg:"couldn't parse inferred value" ) |
| 136 | } |
| 137 | |
| 138 | fn infer_values(flags: &mut [Flag], type_name: &Ident, repr: &Ident) { |
| 139 | let mut previous_variants: Vec<Ident> = flagsimpl Iterator |
| 140 | .iter() |
| 141 | .filter(|flag: &&Flag<'_>| !flag.value.is_inferred()) |
| 142 | .map(|flag: &Flag<'_>| flag.name.clone()) |
| 143 | .collect(); |
| 144 | |
| 145 | for flag: &mut Flag<'_> in flags { |
| 146 | if let FlagValue::Inferred(ref mut variant: &mut &mut Variant) = flag.value { |
| 147 | variant.discriminant = Some(( |
| 148 | <Token![=]>::default(), |
| 149 | inferred_value(type_name, &previous_variants, repr), |
| 150 | )); |
| 151 | previous_variants.push(flag.name.clone()); |
| 152 | } |
| 153 | } |
| 154 | } |
| 155 | |
| 156 | /// Given a list of attributes, find the `repr`, if any, and return the integer |
| 157 | /// type specified. |
| 158 | fn extract_repr(attrs: &[syn::Attribute]) -> Result<Option<Ident>, syn::Error> { |
| 159 | let mut res: Option = None; |
| 160 | for attr: &Attribute in attrs { |
| 161 | if attr.path().is_ident("repr" ) { |
| 162 | attr.parse_nested_meta(|meta: ParseNestedMeta<'_>| { |
| 163 | if let Some(ident: &Ident) = meta.path.get_ident() { |
| 164 | res = Some(ident.clone()); |
| 165 | } |
| 166 | Ok(()) |
| 167 | })?; |
| 168 | } |
| 169 | } |
| 170 | Ok(res) |
| 171 | } |
| 172 | |
| 173 | /// Check the repr and return the number of bits available |
| 174 | fn type_bits(ty: &Ident) -> Result<u8, syn::Error> { |
| 175 | // This would be so much easier if we could just match on an Ident... |
| 176 | if ty == "usize" { |
| 177 | Err(syn::Error::new_spanned( |
| 178 | ty, |
| 179 | "#[repr(usize)] is not supported. Use u32 or u64 instead." , |
| 180 | )) |
| 181 | } else if ty == "i8" |
| 182 | || ty == "i16" |
| 183 | || ty == "i32" |
| 184 | || ty == "i64" |
| 185 | || ty == "i128" |
| 186 | || ty == "isize" |
| 187 | { |
| 188 | Err(syn::Error::new_spanned( |
| 189 | ty, |
| 190 | "Signed types in a repr are not supported." , |
| 191 | )) |
| 192 | } else if ty == "u8" { |
| 193 | Ok(8) |
| 194 | } else if ty == "u16" { |
| 195 | Ok(16) |
| 196 | } else if ty == "u32" { |
| 197 | Ok(32) |
| 198 | } else if ty == "u64" { |
| 199 | Ok(64) |
| 200 | } else if ty == "u128" { |
| 201 | Ok(128) |
| 202 | } else { |
| 203 | Err(syn::Error::new_spanned( |
| 204 | ty, |
| 205 | "repr must be an integer type for #[bitflags]." , |
| 206 | )) |
| 207 | } |
| 208 | } |
| 209 | |
| 210 | /// Returns deferred checks |
| 211 | fn check_flag(type_name: &Ident, flag: &Flag, bits: u8) -> Result<Option<TokenStream>, syn::Error> { |
| 212 | use FlagValue::*; |
| 213 | match flag.value { |
| 214 | Literal(n) => { |
| 215 | if !n.is_power_of_two() { |
| 216 | Err(syn::Error::new( |
| 217 | flag.span, |
| 218 | "Flags must have exactly one set bit" , |
| 219 | )) |
| 220 | } else if bits < 128 && n >= 1 << bits { |
| 221 | Err(syn::Error::new( |
| 222 | flag.span, |
| 223 | format!("Flag value out of range for u {}" , bits), |
| 224 | )) |
| 225 | } else { |
| 226 | Ok(None) |
| 227 | } |
| 228 | } |
| 229 | Inferred(_) => Ok(None), |
| 230 | Deferred => { |
| 231 | let variant_name = &flag.name; |
| 232 | Ok(Some(quote_spanned!(flag.span => |
| 233 | const _: |
| 234 | <<[(); ( |
| 235 | (#type_name::#variant_name as u128).is_power_of_two() |
| 236 | ) as usize] as ::enumflags2::_internal::AssertionHelper> |
| 237 | ::Status as ::enumflags2::_internal::ExactlyOneBitSet>::X |
| 238 | = (); |
| 239 | ))) |
| 240 | } |
| 241 | } |
| 242 | } |
| 243 | |
| 244 | fn gen_enumflags(ast: &mut DeriveInput, default: Vec<Ident>) -> Result<TokenStream, syn::Error> { |
| 245 | let ident = &ast.ident; |
| 246 | |
| 247 | let span = Span::call_site(); |
| 248 | |
| 249 | let ast_variants = match &mut ast.data { |
| 250 | Data::Enum(ref mut data) => &mut data.variants, |
| 251 | Data::Struct(data) => { |
| 252 | return Err(syn::Error::new_spanned(&data.struct_token, |
| 253 | "expected enum for #[bitflags], found struct" )); |
| 254 | } |
| 255 | Data::Union(data) => { |
| 256 | return Err(syn::Error::new_spanned(&data.union_token, |
| 257 | "expected enum for #[bitflags], found union" )); |
| 258 | } |
| 259 | }; |
| 260 | |
| 261 | if ast.generics.lt_token.is_some() || ast.generics.where_clause.is_some() { |
| 262 | return Err(syn::Error::new_spanned(&ast.generics, |
| 263 | "bitflags cannot be generic" )); |
| 264 | } |
| 265 | |
| 266 | let repr = extract_repr(&ast.attrs)? |
| 267 | .ok_or_else(|| syn::Error::new_spanned(ident, |
| 268 | "repr attribute missing. Add #[repr(u64)] or a similar attribute to specify the size of the bitfield." ))?; |
| 269 | let bits = type_bits(&repr)?; |
| 270 | |
| 271 | let mut variants = collect_flags(ast_variants.iter_mut())?; |
| 272 | let deferred = variants |
| 273 | .iter() |
| 274 | .flat_map(|variant| check_flag(ident, variant, bits).transpose()) |
| 275 | .collect::<Result<Vec<_>, _>>()?; |
| 276 | |
| 277 | infer_values(&mut variants, ident, &repr); |
| 278 | |
| 279 | if (bits as usize) < variants.len() { |
| 280 | return Err(syn::Error::new_spanned( |
| 281 | &repr, |
| 282 | format!("Not enough bits for {} flags" , variants.len()), |
| 283 | )); |
| 284 | } |
| 285 | |
| 286 | let std = quote_spanned!(span => ::enumflags2::_internal::core); |
| 287 | let ast_variants = match &ast.data { |
| 288 | Data::Enum(ref data) => &data.variants, |
| 289 | _ => unreachable!(), |
| 290 | }; |
| 291 | |
| 292 | let variant_names = ast_variants.iter().map(|v| &v.ident).collect::<Vec<_>>(); |
| 293 | |
| 294 | Ok(quote_spanned! { |
| 295 | span => |
| 296 | #ast |
| 297 | #(#deferred)* |
| 298 | impl #std::ops::Not for #ident { |
| 299 | type Output = ::enumflags2::BitFlags<Self>; |
| 300 | #[inline(always)] |
| 301 | fn not(self) -> Self::Output { |
| 302 | use ::enumflags2::BitFlags; |
| 303 | BitFlags::from_flag(self).not() |
| 304 | } |
| 305 | } |
| 306 | |
| 307 | impl #std::ops::BitOr for #ident { |
| 308 | type Output = ::enumflags2::BitFlags<Self>; |
| 309 | #[inline(always)] |
| 310 | fn bitor(self, other: Self) -> Self::Output { |
| 311 | use ::enumflags2::BitFlags; |
| 312 | BitFlags::from_flag(self) | other |
| 313 | } |
| 314 | } |
| 315 | |
| 316 | impl #std::ops::BitAnd for #ident { |
| 317 | type Output = ::enumflags2::BitFlags<Self>; |
| 318 | #[inline(always)] |
| 319 | fn bitand(self, other: Self) -> Self::Output { |
| 320 | use ::enumflags2::BitFlags; |
| 321 | BitFlags::from_flag(self) & other |
| 322 | } |
| 323 | } |
| 324 | |
| 325 | impl #std::ops::BitXor for #ident { |
| 326 | type Output = ::enumflags2::BitFlags<Self>; |
| 327 | #[inline(always)] |
| 328 | fn bitxor(self, other: Self) -> Self::Output { |
| 329 | use ::enumflags2::BitFlags; |
| 330 | BitFlags::from_flag(self) ^ other |
| 331 | } |
| 332 | } |
| 333 | |
| 334 | unsafe impl ::enumflags2::_internal::RawBitFlags for #ident { |
| 335 | type Numeric = #repr; |
| 336 | |
| 337 | const EMPTY: Self::Numeric = 0; |
| 338 | |
| 339 | const DEFAULT: Self::Numeric = |
| 340 | 0 #(| (Self::#default as #repr))*; |
| 341 | |
| 342 | const ALL_BITS: Self::Numeric = |
| 343 | 0 #(| (Self::#variant_names as #repr))*; |
| 344 | |
| 345 | const BITFLAGS_TYPE_NAME : &'static str = |
| 346 | concat!("BitFlags<" , stringify!(#ident), ">" ); |
| 347 | |
| 348 | fn bits(self) -> Self::Numeric { |
| 349 | self as #repr |
| 350 | } |
| 351 | } |
| 352 | |
| 353 | impl ::enumflags2::BitFlag for #ident {} |
| 354 | }) |
| 355 | } |
| 356 | |