1/* This Source Code Form is subject to the terms of the Mozilla Public
2 * License, v. 2.0. If a copy of the MPL was not distributed with this
3 * file, You can obtain one at http://mozilla.org/MPL/2.0/. */
4
5use proc_macro2::TokenStream;
6use std::collections::HashSet;
7use syn::fold::Fold;
8use syn::parse::{Parse, ParseStream, Parser, Result as ParseResult};
9
/// Parsed `struct` form of a `bitflags!` invocation, which declares a new flags type.
///
/// It mirrors this `macro_rules!` pattern (the `Parse` impl for [`Bitflags`] additionally
/// expects a literal `struct` keyword between the visibility and the name):
///
/// ```text
/// $(#[$outer:meta])*
/// ($($vis:tt)*) $BitFlags:ident: $T:ty {
///     $(
///         $(#[$inner:ident $($args:tt)*])*
///         const $Flag:ident = $value:expr;
///     )+
/// }
/// ```
#[derive(Debug)]
pub struct BitflagsStruct {
    /// Outer attributes written before the struct; re-emitted on the generated struct.
    attrs: Vec<syn::Attribute>,
    /// Visibility of the flags type; copied onto the generated struct.
    vis: syn::Visibility,
    // Consumed while parsing only and never read afterwards, hence `dead_code` is allowed.
    #[allow(dead_code)]
    struct_token: Token![struct],
    /// Name of the flags type.
    name: syn::Ident,
    // The `:` between the name and the repr; consumed while parsing only, never read.
    #[allow(dead_code)]
    colon_token: Token![:],
    /// Representation type (`$T`); becomes the type of the generated `bits` field.
    repr: syn::Type,
    /// The `const` flag declarations inside the braces.
    flags: Flags,
}
29
/// Parsed out-of-line (`impl`) form of a `bitflags!` invocation, which only adds flag
/// constants to a type that is defined outside the macro. It mirrors:
///
/// ```text
/// impl $BitFlags:ident: $T:ty {
///     $(
///         $(#[$inner:ident $($args:tt)*])*
///         const $Flag:ident = $value:expr;
///     )+
/// }
/// ```
#[derive(Debug)]
pub struct BitflagsImpl {
    // Consumed while parsing only and never read afterwards, hence `dead_code` is allowed.
    #[allow(dead_code)]
    impl_token: Token![impl],
    /// Name of the existing type the constants are attached to.
    name: syn::Ident,
    // The `:` between the name and the repr; consumed while parsing only, never read.
    #[allow(dead_code)]
    colon_token: Token![:],
    /// Representation type (`$T`) that each flag value is cast to.
    repr: syn::Type,
    /// The `const` flag declarations inside the braces.
    flags: Flags,
}
46
/// A parsed `bitflags!` invocation, in either of its two supported forms.
#[derive(Debug)]
pub enum Bitflags {
    /// `struct Name: Repr { ... }` — the invocation also declares the type.
    Struct(BitflagsStruct),
    /// `impl Name: Repr { ... }` — the type itself is declared outside the invocation.
    Impl(BitflagsImpl),
}
52
53impl Bitflags {
54 pub fn expand(&self) -> (Option<syn::ItemStruct>, syn::ItemImpl) {
55 match self {
56 Bitflags::Struct(BitflagsStruct {
57 attrs,
58 vis,
59 name,
60 repr,
61 flags,
62 ..
63 }) => {
64 let struct_ = parse_quote! {
65 #(#attrs)*
66 #vis struct #name {
67 bits: #repr,
68 }
69 };
70
71 let consts = flags.expand(name, repr, false);
72 let impl_ = parse_quote! {
73 impl #name {
74 #consts
75 }
76 };
77
78 (Some(struct_), impl_)
79 }
80 Bitflags::Impl(BitflagsImpl {
81 name, repr, flags, ..
82 }) => {
83 let consts = flags.expand(name, repr, true);
84 let impl_: syn::ItemImpl = parse_quote! {
85 impl #name {
86 #consts
87 }
88 };
89 (None, impl_)
90 }
91 }
92 }
93}
94
95impl Parse for Bitflags {
96 fn parse(input: ParseStream) -> ParseResult<Self> {
97 Ok(if input.peek(Token![impl]) {
98 Self::Impl(BitflagsImpl {
99 impl_token: input.parse()?,
100 name: input.parse()?,
101 colon_token: input.parse()?,
102 repr: input.parse()?,
103 flags: input.parse()?,
104 })
105 } else {
106 Self::Struct(BitflagsStruct {
107 attrs: input.call(function:syn::Attribute::parse_outer)?,
108 vis: input.parse()?,
109 struct_token: input.parse()?,
110 name: input.parse()?,
111 colon_token: input.parse()?,
112 repr: input.parse()?,
113 flags: input.parse()?,
114 })
115 })
116 }
117}
118
/// A single flag declaration inside a `bitflags!` body:
///
/// ```text
/// $(#[$inner:ident $($args:tt)*])*
/// const $Flag:ident = $value:expr;
/// ```
#[derive(Debug)]
struct Flag {
    /// Attributes on the flag; copied onto the generated associated constant.
    attrs: Vec<syn::Attribute>,
    // Consumed while parsing only and never read afterwards, hence `dead_code` is allowed.
    #[allow(dead_code)]
    const_token: Token![const],
    /// Flag name; also the name of the generated associated constant.
    name: syn::Ident,
    // Consumed while parsing only and never read afterwards.
    #[allow(dead_code)]
    equals_token: Token![=],
    /// Value expression; may refer to sibling flags as `Name::FLAG.bits()`.
    value: syn::Expr,
    // Consumed while parsing only and never read afterwards.
    #[allow(dead_code)]
    semicolon_token: Token![;],
}
133
/// `syn` folder that rewrites `Name::FLAG.bits()` calls inside a flag's value expression
/// into direct field accesses (`.bits`, or `.0` for the out-of-line form).
struct FlagValueFold<'a> {
    /// Name of the flags type; `Self` is also recognised as referring to it.
    struct_name: &'a syn::Ident,
    /// Names of every flag declared in the same invocation.
    flag_names: &'a HashSet<String>,
    /// `true` for the `impl` (newtype) form, which rewrites to `.0` instead of `.bits`.
    out_of_line: bool,
}
139
140impl<'a> FlagValueFold<'a> {
141 fn is_self(&self, ident: &syn::Ident) -> bool {
142 ident == self.struct_name || ident == "Self"
143 }
144}
145
146impl<'a> Fold for FlagValueFold<'a> {
147 fn fold_expr(&mut self, node: syn::Expr) -> syn::Expr {
148 // bitflags 2 doesn't expose `bits` publically anymore, and the documented way to
149 // combine flags is using the `bits` method, e.g.
150 // ```
151 // bitflags! {
152 // struct Flags: u8 {
153 // const A = 1;
154 // const B = 1 << 1;
155 // const AB = Flags::A.bits() | Flags::B.bits();
156 // }
157 // }
158 // ```
159 // As we're transforming the struct definition into `struct StructName { bits: T }`
160 // as far as our bindings generation is concerned, `bits` is available as a field,
161 // so by replacing `StructName::FLAG.bits()` with `StructName::FLAG.bits`, we make
162 // e.g. `Flags::AB` available in the generated bindings.
163 // For out-of-line definitions of the struct(*), where the struct is defined as a
164 // newtype, we replace it with `StructName::FLAGS.0`.
165 // * definitions like:
166 // ```
167 // struct Flags(u8);
168 // bitflags! {
169 // impl Flags: u8 {
170 // const A = 1;
171 // const B = 1 << 1;
172 // const AB = Flags::A.bits() | Flags::B.bits();
173 // }
174 // }
175 // ```
176 match node {
177 syn::Expr::MethodCall(syn::ExprMethodCall {
178 attrs,
179 receiver,
180 dot_token,
181 method,
182 args,
183 ..
184 }) if method == "bits"
185 && args.is_empty()
186 && matches!(&*receiver,
187 syn::Expr::Path(syn::ExprPath { path, .. })
188 if path.segments.len() == 2
189 && self.is_self(&path.segments.first().unwrap().ident)
190 && self
191 .flag_names
192 .contains(&path.segments.last().unwrap().ident.to_string())) =>
193 {
194 return syn::Expr::Field(syn::ExprField {
195 attrs,
196 base: receiver,
197 dot_token,
198 member: if self.out_of_line {
199 syn::Member::Unnamed(parse_quote! {0})
200 } else {
201 syn::Member::Named(method)
202 },
203 });
204 }
205 _ => {}
206 }
207 syn::fold::fold_expr(self, node)
208 }
209}
210
211impl Flag {
212 fn expand(
213 &self,
214 struct_name: &syn::Ident,
215 repr: &syn::Type,
216 flag_names: &HashSet<String>,
217 out_of_line: bool,
218 ) -> TokenStream {
219 let Flag {
220 ref attrs,
221 ref name,
222 ref value,
223 ..
224 } = *self;
225 let folded_value = FlagValueFold {
226 struct_name,
227 flag_names,
228 out_of_line,
229 }
230 .fold_expr(value.clone());
231 let value = if out_of_line {
232 quote! { ((#folded_value) as #repr) }
233 } else {
234 quote! { { bits: (#folded_value) as #repr } }
235 };
236 quote! {
237 #(#attrs)*
238 pub const #name : #struct_name = #struct_name #value;
239 }
240 }
241}
242
243impl Parse for Flag {
244 fn parse(input: ParseStream) -> ParseResult<Self> {
245 Ok(Self {
246 attrs: input.call(function:syn::Attribute::parse_outer)?,
247 const_token: input.parse()?,
248 name: input.parse()?,
249 equals_token: input.parse()?,
250 value: input.parse()?,
251 semicolon_token: input.parse()?,
252 })
253 }
254}
255
/// The brace-delimited list of flag declarations, kept in source order.
#[derive(Debug)]
struct Flags(Vec<Flag>);
258
259impl Parse for Flags {
260 fn parse(input: ParseStream) -> ParseResult<Self> {
261 let content: ParseBuffer<'_>;
262 let _ = braced!(content in input);
263 let mut flags: Vec = vec![];
264 while !content.is_empty() {
265 flags.push(content.parse()?);
266 }
267 Ok(Flags(flags))
268 }
269}
270
271impl Flags {
272 fn expand(&self, struct_name: &syn::Ident, repr: &syn::Type, out_of_line: bool) -> TokenStream {
273 let mut ts: TokenStream = quote! {};
274 let flag_names: HashSet = self
275 .0
276 .iter()
277 .map(|flag: &Flag| flag.name.to_string())
278 .collect::<HashSet<_>>();
279 for flag: &Flag in &self.0 {
280 ts.extend(iter:flag.expand(struct_name, repr, &flag_names, out_of_line));
281 }
282 ts
283 }
284}
285
286pub fn parse(tokens: TokenStream) -> ParseResult<Bitflags> {
287 let parser: fn parse(&ParseBuffer<'_>) -> … = Bitflags::parse;
288 parser.parse2(tokens)
289}
290