1#![recursion_limit = "128"]
2
3extern crate proc_macro;
4
5use proc_macro::TokenStream;
6
7use proc_macro2::{TokenStream as TokenStream2, TokenTree};
8use quote::{quote, quote_spanned, ToTokens, TokenStreamExt};
9use syn::{
10 parse_macro_input, parse_quote, spanned::Spanned, Error, Fields, FnArg, Ident, ItemFn,
11 ItemStruct, LitStr, Pat, Visibility,
12};
13
/// Build a `syn::Error` located at `$span` and turn it into a compile-error
/// token stream.
///
/// `$span` is anything with a `.span()` method (e.g. a syntax node or a
/// `Span`). With extra arguments, `$message` is used as a `format!` string.
/// Both forms accept an optional trailing comma.
macro_rules! err {
    ($span:expr, $message:expr $(,)?) => {
        Error::new($span.span(), $message).to_compile_error()
    };
    ($span:expr, $message:expr, $($args:expr),* $(,)?) => {
        Error::new($span.span(), format!($message, $($args),*)).to_compile_error()
    };
}
22
23/// Attribute macro for marking structs as UEFI protocols.
24///
25/// The macro takes one argument, a GUID string.
26///
27/// The macro can only be applied to a struct, and the struct must have
28/// named fields (i.e. not a unit or tuple struct). It implements the
29/// [`Protocol`] trait and the `unsafe` [`Identify`] trait for the
30/// struct. It also adds a hidden field that causes the struct to be
31/// marked as [`!Send` and `!Sync`][send-and-sync].
32///
33/// # Safety
34///
35/// The caller must ensure that the correct GUID is attached to the
36/// type. An incorrect GUID could lead to invalid casts and other
37/// unsound behavior.
38///
39/// # Example
40///
41/// ```
42/// use uefi::{Identify, guid};
43/// use uefi::proto::unsafe_protocol;
44///
45/// #[unsafe_protocol("12345678-9abc-def0-1234-56789abcdef0")]
46/// struct ExampleProtocol {}
47///
48/// assert_eq!(ExampleProtocol::GUID, guid!("12345678-9abc-def0-1234-56789abcdef0"));
49/// ```
50///
51/// [`Identify`]: https://docs.rs/uefi/latest/uefi/trait.Identify.html
52/// [`Protocol`]: https://docs.rs/uefi/latest/uefi/proto/trait.Protocol.html
53/// [send-and-sync]: https://doc.rust-lang.org/nomicon/send-and-sync.html
54#[proc_macro_attribute]
55pub fn unsafe_protocol(args: TokenStream, input: TokenStream) -> TokenStream {
56 // Parse `args` as a GUID string.
57 let (time_low, time_mid, time_high_and_version, clock_seq_and_variant, node) =
58 match parse_guid(parse_macro_input!(args as LitStr)) {
59 Ok(data) => data,
60 Err(tokens) => return tokens.into(),
61 };
62
63 let item_struct = parse_macro_input!(input as ItemStruct);
64
65 let ident = &item_struct.ident;
66 let struct_attrs = &item_struct.attrs;
67 let struct_vis = &item_struct.vis;
68 let struct_fields = if let Fields::Named(struct_fields) = &item_struct.fields {
69 &struct_fields.named
70 } else {
71 return err!(item_struct, "Protocol struct must used named fields").into();
72 };
73 let struct_generics = &item_struct.generics;
74 let (impl_generics, ty_generics, where_clause) = item_struct.generics.split_for_impl();
75
76 quote! {
77 #(#struct_attrs)*
78 #struct_vis struct #ident #struct_generics {
79 // Add a hidden field with `PhantomData` of a raw
80 // pointer. This has the implicit side effect of making the
81 // struct !Send and !Sync.
82 _no_send_or_sync: ::core::marker::PhantomData<*const u8>,
83 #struct_fields
84 }
85
86 unsafe impl #impl_generics ::uefi::Identify for #ident #ty_generics #where_clause {
87 const GUID: ::uefi::Guid = ::uefi::Guid::from_values(
88 #time_low,
89 #time_mid,
90 #time_high_and_version,
91 #clock_seq_and_variant,
92 #node,
93 );
94 }
95
96 impl #impl_generics ::uefi::proto::Protocol for #ident #ty_generics #where_clause {}
97 }
98 .into()
99}
100
101/// Create a `Guid` at compile time.
102///
103/// # Example
104///
105/// ```
106/// use uefi::{guid, Guid};
107/// const EXAMPLE_GUID: Guid = guid!("12345678-9abc-def0-1234-56789abcdef0");
108/// ```
109#[proc_macro]
110pub fn guid(args: TokenStream) -> TokenStream {
111 let (time_low: u32, time_mid: u16, time_high_and_version: u16, clock_seq_and_variant: u16, node: u64) =
112 match parse_guid(guid_lit:parse_macro_input!(args as LitStr)) {
113 Ok(data: (u32, u16, u16, u16, u64)) => data,
114 Err(tokens: TokenStream) => return tokens.into(),
115 };
116
117 quoteTokenStream!({
118 const g: ::uefi::Guid = ::uefi::Guid::from_values(
119 #time_low,
120 #time_mid,
121 #time_high_and_version,
122 #clock_seq_and_variant,
123 #node,
124 );
125 g
126 })
127 .into()
128}
129
130fn parse_guid(guid_lit: LitStr) -> Result<(u32, u16, u16, u16, u64), TokenStream2> {
131 let guid_str = guid_lit.value();
132
133 // We expect a canonical GUID string, such as "12345678-9abc-def0-fedc-ba9876543210"
134 if guid_str.len() != 36 {
135 return Err(err!(
136 guid_lit,
137 "\"{}\" is not a canonical GUID string (expected 36 bytes, found {})",
138 guid_str,
139 guid_str.len()
140 ));
141 }
142 let mut offset = 1; // 1 is for the starting quote
143 let mut guid_hex_iter = guid_str.split('-');
144 let mut next_guid_int = |len: usize| -> Result<u64, TokenStream2> {
145 let guid_hex_component = guid_hex_iter.next().unwrap();
146
147 // convert syn::LitStr to proc_macro2::Literal..
148 let lit = match guid_lit.to_token_stream().into_iter().next().unwrap() {
149 TokenTree::Literal(lit) => lit,
150 _ => unreachable!(),
151 };
152 // ..so that we can call subspan and nightly users (us) will get the fancy span
153 let span = lit
154 .subspan(offset..offset + guid_hex_component.len())
155 .unwrap_or_else(|| lit.span());
156
157 if guid_hex_component.len() != len * 2 {
158 return Err(err!(
159 span,
160 "GUID component \"{}\" is not a {}-bit hexadecimal string",
161 guid_hex_component,
162 len * 8
163 ));
164 }
165 offset += guid_hex_component.len() + 1; // + 1 for the dash
166 u64::from_str_radix(guid_hex_component, 16).map_err(|_| {
167 err!(
168 span,
169 "GUID component \"{}\" is not a hexadecimal number",
170 guid_hex_component
171 )
172 })
173 };
174
175 // The GUID string is composed of a 32-bit integer, three 16-bit ones, and a 48-bit one
176 Ok((
177 next_guid_int(4)? as u32,
178 next_guid_int(2)? as u16,
179 next_guid_int(2)? as u16,
180 next_guid_int(2)? as u16,
181 next_guid_int(6)?,
182 ))
183}
184
185/// Get the name of a function's argument at `arg_index`.
186fn get_function_arg_name(f: &ItemFn, arg_index: usize, errors: &mut TokenStream2) -> Option<Ident> {
187 if let Some(FnArg::Typed(arg: &PatType)) = f.sig.inputs.iter().nth(arg_index) {
188 if let Pat::Ident(pat_ident: &PatIdent) = &*arg.pat {
189 // The argument has a valid name such as `handle` or `_handle`.
190 Some(pat_ident.ident.clone())
191 } else {
192 // The argument is unnamed, i.e. `_`.
193 errors.append_all(iter:err!(
194 arg.pat.span(),
195 "Entry method's arguments must be named"
196 ));
197 None
198 }
199 } else {
200 // Either there are too few arguments, or it's the wrong kind of
201 // argument (e.g. `self`).
202 //
203 // Don't append an error in this case. The error will be caught
204 // by the typecheck later on, which will give a better error
205 // message.
206 None
207 }
208}
209
210/// Custom attribute for a UEFI executable entry point.
211///
212/// This attribute modifies a function to mark it as the entry point for
213/// a UEFI executable. The function must have two parameters, [`Handle`]
214/// and [`SystemTable<Boot>`], and return a [`Status`]. The function can
215/// optionally be `unsafe`.
216///
217/// Due to internal implementation details the parameters must both be
218/// named, so `arg` or `_arg` are allowed, but not `_`.
219///
220/// The [`BootServices::set_image_handle`] function will be called
221/// automatically with the image [`Handle`] argument.
222///
223/// # Examples
224///
225/// ```no_run
226/// #![no_main]
227///
228/// use uefi::prelude::*;
229///
230/// #[entry]
231/// fn main(image: Handle, st: SystemTable<Boot>) -> Status {
232/// Status::SUCCESS
233/// }
234/// ```
235///
236/// [`Handle`]: https://docs.rs/uefi/latest/uefi/data_types/struct.Handle.html
237/// [`SystemTable<Boot>`]: https://docs.rs/uefi/latest/uefi/table/struct.SystemTable.html
238/// [`Status`]: https://docs.rs/uefi/latest/uefi/struct.Status.html
239/// [`BootServices::set_image_handle`]: https://docs.rs/uefi/latest/uefi/table/boot/struct.BootServices.html#method.set_image_handle
240#[proc_macro_attribute]
241pub fn entry(args: TokenStream, input: TokenStream) -> TokenStream {
242 // This code is inspired by the approach in this embedded Rust crate:
243 // https://github.com/rust-embedded/cortex-m-rt/blob/965bf1e3291571e7e3b34834864117dc020fb391/macros/src/lib.rs#L85
244
245 let mut errors = TokenStream2::new();
246
247 if !args.is_empty() {
248 errors.append_all(err!(
249 TokenStream2::from(args),
250 "Entry attribute accepts no arguments"
251 ));
252 }
253
254 let mut f = parse_macro_input!(input as ItemFn);
255
256 if let Some(ref abi) = f.sig.abi {
257 errors.append_all(err!(abi, "Entry method must have no ABI modifier"));
258 }
259 if let Some(asyncness) = f.sig.asyncness {
260 errors.append_all(err!(asyncness, "Entry method should not be async"));
261 }
262 if let Some(constness) = f.sig.constness {
263 errors.append_all(err!(constness, "Entry method should not be const"));
264 }
265 if !f.sig.generics.params.is_empty() {
266 errors.append_all(err!(
267 f.sig.generics.params,
268 "Entry method should not be generic"
269 ));
270 }
271
272 let image_handle_ident = get_function_arg_name(&f, 0, &mut errors);
273 let system_table_ident = get_function_arg_name(&f, 1, &mut errors);
274
275 // show most errors at once instead of one by one
276 if !errors.is_empty() {
277 return errors.into();
278 }
279
280 // allow the entry function to be unsafe (by moving the keyword around so that it actually works)
281 let unsafety = f.sig.unsafety.take();
282 // strip any visibility modifiers
283 f.vis = Visibility::Inherited;
284 // Set the global image handle. If `image_handle_ident` is `None`
285 // then the typecheck is going to fail anyway.
286 if let Some(image_handle_ident) = image_handle_ident {
287 f.block.stmts.insert(
288 0,
289 parse_quote! {
290 unsafe {
291 #system_table_ident.boot_services().set_image_handle(#image_handle_ident);
292 }
293 },
294 );
295 }
296
297 let fn_ident = &f.sig.ident;
298 // Get an iterator of the function inputs types. This is needed instead of
299 // directly using `sig.inputs` because patterns you can use in fn items like
300 // `mut <arg>` aren't valid in fn pointers.
301 let fn_inputs = f.sig.inputs.iter().map(|arg| match arg {
302 FnArg::Receiver(arg) => quote!(#arg),
303 FnArg::Typed(arg) => {
304 let ty = &arg.ty;
305 quote!(#ty)
306 }
307 });
308 let fn_output = &f.sig.output;
309 let signature_span = f.sig.span();
310
311 let fn_type_check = quote_spanned! {signature_span=>
312 // Cast from the function type to a function pointer with the same
313 // signature first, then try to assign that to an unnamed constant with
314 // the desired function pointer type.
315 //
316 // The cast is used to avoid an "expected fn pointer, found fn item"
317 // error if the signature is wrong, since that's not what we are
318 // interested in here. Instead we want to tell the user what
319 // specifically in the function signature is incorrect.
320 const _:
321 // The expected fn pointer type.
322 #unsafety extern "efiapi" fn(::uefi::Handle, ::uefi::table::SystemTable<::uefi::table::Boot>) -> ::uefi::Status =
323 // Cast from a fn item to a function pointer.
324 #fn_ident as #unsafety extern "efiapi" fn(#(#fn_inputs),*) #fn_output;
325 };
326
327 let result = quote! {
328 #fn_type_check
329
330 #[export_name = "efi_main"]
331 #unsafety extern "efiapi" #f
332
333 };
334 result.into()
335}
336
337/// Builds a `CStr8` literal at compile time from a string literal.
338///
339/// This will throw a compile error if an invalid character is in the passed string.
340///
341/// # Example
342/// ```
343/// # use uefi_macros::cstr8;
344/// assert_eq!(cstr8!("test").to_bytes_with_nul(), [116, 101, 115, 116, 0]);
345/// ```
346#[proc_macro]
347pub fn cstr8(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
348 let input: LitStr = parse_macro_input!(input);
349 let input: String = input.value();
350 match inputimpl Iterator>
351 .chars()
352 .map(u8::try_from)
353 .collect::<Result<Vec<u8>, _>>()
354 {
355 Ok(c: Vec) => {
356 quote!(unsafe { ::uefi::CStr8::from_bytes_with_nul_unchecked(&[ #(#c),* , 0 ]) }).into()
357 }
358 Err(_) => synTokenStream::Error::new_spanned(tokens:input, message:"invalid character in string")
359 .into_compile_error()
360 .into(),
361 }
362}
363
364/// Builds a `CStr16` literal at compile time from a string literal.
365///
366/// This will throw a compile error if an invalid character is in the passed string.
367///
368/// # Example
369/// ```
370/// # use uefi_macros::cstr16;
371/// assert_eq!(cstr16!("test €").to_u16_slice_with_nul(), [116, 101, 115, 116, 32, 8364, 0]);
372/// ```
373#[proc_macro]
374pub fn cstr16(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
375 let input: LitStr = parse_macro_input!(input);
376 let input: String = input.value();
377 match inputimpl Iterator>
378 .chars()
379 .map(|c: char| u16::try_from(c as u32))
380 .collect::<Result<Vec<u16>, _>>()
381 {
382 Ok(c: Vec) => {
383 quote!(unsafe { ::uefi::CStr16::from_u16_with_nul_unchecked(&[ #(#c),* , 0 ]) }).into()
384 }
385 Err(_) => synTokenStream::Error::new_spanned(tokens:input, message:"invalid character in string")
386 .into_compile_error()
387 .into(),
388 }
389}
390