1//! Derive macros for [bytemuck](https://docs.rs/bytemuck) traits.
2
3extern crate proc_macro;
4
5mod traits;
6
7use proc_macro2::TokenStream;
8use quote::quote;
9use syn::{parse_macro_input, DeriveInput, Result};
10
11use crate::traits::{
12 bytemuck_crate_name, AnyBitPattern, CheckedBitPattern, Contiguous, Derivable,
13 NoUninit, Pod, TransparentWrapper, Zeroable,
14};
15
16/// Derive the `Pod` trait for a struct
17///
18/// The macro ensures that the struct follows all the the safety requirements
19/// for the `Pod` trait.
20///
21/// The following constraints need to be satisfied for the macro to succeed
22///
23/// - All fields in the struct must implement `Pod`
24/// - The struct must be `#[repr(C)]` or `#[repr(transparent)]`
25/// - The struct must not contain any padding bytes
26/// - The struct contains no generic parameters, if it is not
27/// `#[repr(transparent)]`
28///
29/// ## Examples
30///
31/// ```rust
32/// # use std::marker::PhantomData;
33/// # use bytemuck_derive::{Pod, Zeroable};
34/// #[derive(Copy, Clone, Pod, Zeroable)]
35/// #[repr(C)]
36/// struct Test {
37/// a: u16,
38/// b: u16,
39/// }
40///
41/// #[derive(Copy, Clone, Pod, Zeroable)]
42/// #[repr(transparent)]
43/// struct Generic<A, B> {
44/// a: A,
45/// b: PhantomData<B>,
46/// }
47/// ```
48///
49/// If the struct is generic, it must be `#[repr(transparent)]` also.
50///
51/// ```compile_fail
52/// # use bytemuck::{Pod, Zeroable};
53/// # use std::marker::PhantomData;
54/// #[derive(Copy, Clone, Pod, Zeroable)]
55/// #[repr(C)] // must be `#[repr(transparent)]`
56/// struct Generic<A> {
57/// a: A,
58/// }
59/// ```
60///
61/// If the struct is generic and `#[repr(transparent)]`, then it is only `Pod`
62/// when all of its generics are `Pod`, not just its fields.
63///
64/// ```
65/// # use bytemuck::{Pod, Zeroable};
66/// # use std::marker::PhantomData;
67/// #[derive(Copy, Clone, Pod, Zeroable)]
68/// #[repr(transparent)]
69/// struct Generic<A, B> {
70/// a: A,
71/// b: PhantomData<B>,
72/// }
73///
74/// let _: u32 = bytemuck::cast(Generic { a: 4u32, b: PhantomData::<u32> });
75/// ```
76///
77/// ```compile_fail
78/// # use bytemuck::{Pod, Zeroable};
79/// # use std::marker::PhantomData;
80/// # #[derive(Copy, Clone, Pod, Zeroable)]
81/// # #[repr(transparent)]
82/// # struct Generic<A, B> {
83/// # a: A,
84/// # b: PhantomData<B>,
85/// # }
86/// struct NotPod;
87///
88/// let _: u32 = bytemuck::cast(Generic { a: 4u32, b: PhantomData::<NotPod> });
89/// ```
90#[proc_macro_derive(Pod, attributes(bytemuck))]
91pub fn derive_pod(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
92 let expanded: TokenStream =
93 derive_marker_trait::<Pod>(parse_macro_input!(input as DeriveInput));
94
95 proc_macro::TokenStream::from(expanded)
96}
97
98/// Derive the `AnyBitPattern` trait for a struct
99///
100/// The macro ensures that the struct follows all the the safety requirements
101/// for the `AnyBitPattern` trait.
102///
103/// The following constraints need to be satisfied for the macro to succeed
104///
105/// - All fields in the struct must to implement `AnyBitPattern`
106#[proc_macro_derive(AnyBitPattern, attributes(bytemuck))]
107pub fn derive_anybitpattern(
108 input: proc_macro::TokenStream,
109) -> proc_macro::TokenStream {
110 let expanded: TokenStream = derive_marker_trait::<AnyBitPattern>(parse_macro_input!(
111 input as DeriveInput
112 ));
113
114 proc_macro::TokenStream::from(expanded)
115}
116
117/// Derive the `Zeroable` trait for a struct
118///
119/// The macro ensures that the struct follows all the the safety requirements
120/// for the `Zeroable` trait.
121///
122/// The following constraints need to be satisfied for the macro to succeed
123///
124/// - All fields in the struct must to implement `Zeroable`
125///
126/// ## Example
127///
128/// ```rust
129/// # use bytemuck_derive::{Zeroable};
130/// #[derive(Copy, Clone, Zeroable)]
131/// #[repr(C)]
132/// struct Test {
133/// a: u16,
134/// b: u16,
135/// }
136/// ```
137///
138/// # Custom bounds
139///
140/// Custom bounds for the derived `Zeroable` impl can be given using the
141/// `#[zeroable(bound = "")]` helper attribute.
142///
143/// Using this attribute additionally opts-in to "perfect derive" semantics,
144/// where instead of adding bounds for each generic type parameter, bounds are
145/// added for each field's type.
146///
147/// ## Examples
148///
149/// ```rust
150/// # use bytemuck::Zeroable;
151/// # use std::marker::PhantomData;
152/// #[derive(Clone, Zeroable)]
153/// #[zeroable(bound = "")]
154/// struct AlwaysZeroable<T> {
155/// a: PhantomData<T>,
156/// }
157///
158/// AlwaysZeroable::<std::num::NonZeroU8>::zeroed();
159/// ```
160///
161/// ```rust,compile_fail
162/// # use bytemuck::Zeroable;
163/// # use std::marker::PhantomData;
164/// #[derive(Clone, Zeroable)]
165/// #[zeroable(bound = "T: Copy")]
166/// struct ZeroableWhenTIsCopy<T> {
167/// a: PhantomData<T>,
168/// }
169///
170/// ZeroableWhenTIsCopy::<String>::zeroed();
171/// ```
172///
173/// The restriction that all fields must be Zeroable is still applied, and this
174/// is enforced using the mentioned "perfect derive" semantics.
175///
176/// ```rust
177/// # use bytemuck::Zeroable;
178/// #[derive(Clone, Zeroable)]
179/// #[zeroable(bound = "")]
180/// struct ZeroableWhenTIsZeroable<T> {
181/// a: T,
182/// }
183/// ZeroableWhenTIsZeroable::<u32>::zeroed();
184/// ```
185///
186/// ```rust,compile_fail
187/// # use bytemuck::Zeroable;
188/// # #[derive(Clone, Zeroable)]
189/// # #[zeroable(bound = "")]
190/// # struct ZeroableWhenTIsZeroable<T> {
191/// # a: T,
192/// # }
193/// ZeroableWhenTIsZeroable::<String>::zeroed();
194/// ```
195#[proc_macro_derive(Zeroable, attributes(bytemuck, zeroable))]
196pub fn derive_zeroable(
197 input: proc_macro::TokenStream,
198) -> proc_macro::TokenStream {
199 let expanded: TokenStream =
200 derive_marker_trait::<Zeroable>(parse_macro_input!(input as DeriveInput));
201
202 proc_macro::TokenStream::from(expanded)
203}
204
205/// Derive the `NoUninit` trait for a struct or enum
206///
207/// The macro ensures that the type follows all the the safety requirements
208/// for the `NoUninit` trait.
209///
210/// The following constraints need to be satisfied for the macro to succeed
211/// (the rest of the constraints are guaranteed by the `NoUninit` subtrait
212/// bounds, i.e. the type must be `Sized + Copy + 'static`):
213///
214/// If applied to a struct:
215/// - All fields in the struct must implement `NoUninit`
216/// - The struct must be `#[repr(C)]` or `#[repr(transparent)]`
217/// - The struct must not contain any padding bytes
218/// - The struct must contain no generic parameters
219///
220/// If applied to an enum:
221/// - The enum must be explicit `#[repr(Int)]`, `#[repr(C)]`, or both
222/// - All variants must be fieldless
223/// - The enum must contain no generic parameters
224#[proc_macro_derive(NoUninit)]
225pub fn derive_no_uninit(
226 input: proc_macro::TokenStream,
227) -> proc_macro::TokenStream {
228 let expanded: TokenStream =
229 derive_marker_trait::<NoUninit>(parse_macro_input!(input as DeriveInput));
230
231 proc_macro::TokenStream::from(expanded)
232}
233
234/// Derive the `CheckedBitPattern` trait for a struct or enum.
235///
236/// The macro ensures that the type follows all the the safety requirements
237/// for the `CheckedBitPattern` trait and derives the required `Bits` type
238/// definition and `is_valid_bit_pattern` method for the type automatically.
239///
240/// The following constraints need to be satisfied for the macro to succeed:
241///
242/// If applied to a struct:
243/// - All fields must implement `CheckedBitPattern`
244/// - The struct must be `#[repr(C)]` or `#[repr(transparent)]`
245/// - The struct must contain no generic parameters
246///
247/// If applied to an enum:
248/// - The enum must be explicit `#[repr(Int)]`
249/// - All fields in variants must implement `CheckedBitPattern`
250/// - The enum must contain no generic parameters
251#[proc_macro_derive(CheckedBitPattern)]
252pub fn derive_maybe_pod(
253 input: proc_macro::TokenStream,
254) -> proc_macro::TokenStream {
255 let expanded: TokenStream = derive_marker_trait::<CheckedBitPattern>(parse_macro_input!(
256 input as DeriveInput
257 ));
258
259 proc_macro::TokenStream::from(expanded)
260}
261
262/// Derive the `TransparentWrapper` trait for a struct
263///
264/// The macro ensures that the struct follows all the the safety requirements
265/// for the `TransparentWrapper` trait.
266///
267/// The following constraints need to be satisfied for the macro to succeed
268///
269/// - The struct must be `#[repr(transparent)]`
270/// - The struct must contain the `Wrapped` type
271/// - Any ZST fields must be [`Zeroable`][derive@Zeroable].
272///
273/// If the struct only contains a single field, the `Wrapped` type will
274/// automatically be determined. If there is more then one field in the struct,
275/// you need to specify the `Wrapped` type using `#[transparent(T)]`
276///
277/// ## Examples
278///
279/// ```rust
280/// # use bytemuck_derive::TransparentWrapper;
281/// # use std::marker::PhantomData;
282/// #[derive(Copy, Clone, TransparentWrapper)]
283/// #[repr(transparent)]
284/// #[transparent(u16)]
285/// struct Test<T> {
286/// inner: u16,
287/// extra: PhantomData<T>,
288/// }
289/// ```
290///
291/// If the struct contains more than one field, the `Wrapped` type must be
292/// explicitly specified.
293///
294/// ```rust,compile_fail
295/// # use bytemuck_derive::TransparentWrapper;
296/// # use std::marker::PhantomData;
297/// #[derive(Copy, Clone, TransparentWrapper)]
298/// #[repr(transparent)]
299/// // missing `#[transparent(u16)]`
300/// struct Test<T> {
301/// inner: u16,
302/// extra: PhantomData<T>,
303/// }
304/// ```
305///
306/// Any ZST fields must be `Zeroable`.
307///
308/// ```rust,compile_fail
309/// # use bytemuck_derive::TransparentWrapper;
310/// # use std::marker::PhantomData;
311/// struct NonTransparentSafeZST;
312///
313/// #[derive(TransparentWrapper)]
314/// #[repr(transparent)]
315/// #[transparent(u16)]
316/// struct Test<T> {
317/// inner: u16,
318/// extra: PhantomData<T>,
319/// another_extra: NonTransparentSafeZST, // not `Zeroable`
320/// }
321/// ```
322#[proc_macro_derive(TransparentWrapper, attributes(bytemuck, transparent))]
323pub fn derive_transparent(
324 input: proc_macro::TokenStream,
325) -> proc_macro::TokenStream {
326 let expanded: TokenStream = derive_marker_trait::<TransparentWrapper>(parse_macro_input!(
327 input as DeriveInput
328 ));
329
330 proc_macro::TokenStream::from(expanded)
331}
332
333/// Derive the `Contiguous` trait for an enum
334///
335/// The macro ensures that the enum follows all the the safety requirements
336/// for the `Contiguous` trait.
337///
338/// The following constraints need to be satisfied for the macro to succeed
339///
340/// - The enum must be `#[repr(Int)]`
341/// - The enum must be fieldless
342/// - The enum discriminants must form a contiguous range
343///
344/// ## Example
345///
346/// ```rust
347/// # use bytemuck_derive::{Contiguous};
348///
349/// #[derive(Copy, Clone, Contiguous)]
350/// #[repr(u8)]
351/// enum Test {
352/// A = 0,
353/// B = 1,
354/// C = 2,
355/// }
356/// ```
357#[proc_macro_derive(Contiguous)]
358pub fn derive_contiguous(
359 input: proc_macro::TokenStream,
360) -> proc_macro::TokenStream {
361 let expanded: TokenStream =
362 derive_marker_trait::<Contiguous>(parse_macro_input!(input as DeriveInput));
363
364 proc_macro::TokenStream::from(expanded)
365}
366
367/// Derive the `PartialEq` and `Eq` trait for a type
368///
369/// The macro implements `PartialEq` and `Eq` by casting both sides of the
370/// comparison to a byte slice and then compares those.
371///
372/// ## Warning
373///
374/// Since this implements a byte wise comparison, the behavior of floating point
375/// numbers does not match their usual comparison behavior. Additionally other
376/// custom comparison behaviors of the individual fields are also ignored. This
377/// also does not implement `StructuralPartialEq` / `StructuralEq` like
378/// `PartialEq` / `Eq` would. This means you can't pattern match on the values.
379///
380/// ## Example
381///
382/// ```rust
383/// # use bytemuck_derive::{ByteEq, NoUninit};
384/// #[derive(Copy, Clone, NoUninit, ByteEq)]
385/// #[repr(C)]
386/// struct Test {
387/// a: u32,
388/// b: char,
389/// c: f32,
390/// }
391/// ```
392#[proc_macro_derive(ByteEq)]
393pub fn derive_byte_eq(
394 input: proc_macro::TokenStream,
395) -> proc_macro::TokenStream {
396 let input: DeriveInput = parse_macro_input!(input as DeriveInput);
397 let crate_name: TokenStream = bytemuck_crate_name(&input);
398 let ident: Ident = input.ident;
399
400 proc_macro::TokenStream::from(quote! {
401 impl ::core::cmp::PartialEq for #ident {
402 #[inline]
403 #[must_use]
404 fn eq(&self, other: &Self) -> bool {
405 #crate_name::bytes_of(self) == #crate_name::bytes_of(other)
406 }
407 }
408 impl ::core::cmp::Eq for #ident { }
409 })
410}
411
412/// Derive the `Hash` trait for a type
413///
414/// The macro implements `Hash` by casting the value to a byte slice and hashing
415/// that.
416///
417/// ## Warning
418///
419/// The hash does not match the standard library's `Hash` derive.
420///
421/// ## Example
422///
423/// ```rust
424/// # use bytemuck_derive::{ByteHash, NoUninit};
425/// #[derive(Copy, Clone, NoUninit, ByteHash)]
426/// #[repr(C)]
427/// struct Test {
428/// a: u32,
429/// b: char,
430/// c: f32,
431/// }
432/// ```
433#[proc_macro_derive(ByteHash)]
434pub fn derive_byte_hash(
435 input: proc_macro::TokenStream,
436) -> proc_macro::TokenStream {
437 let input: DeriveInput = parse_macro_input!(input as DeriveInput);
438 let crate_name: TokenStream = bytemuck_crate_name(&input);
439 let ident: Ident = input.ident;
440
441 proc_macro::TokenStream::from(quote! {
442 impl ::core::hash::Hash for #ident {
443 #[inline]
444 fn hash<H: ::core::hash::Hasher>(&self, state: &mut H) {
445 ::core::hash::Hash::hash_slice(#crate_name::bytes_of(self), state)
446 }
447
448 #[inline]
449 fn hash_slice<H: ::core::hash::Hasher>(data: &[Self], state: &mut H) {
450 ::core::hash::Hash::hash_slice(#crate_name::cast_slice::<_, u8>(data), state)
451 }
452 }
453 })
454}
455
456/// Basic wrapper for error handling
457fn derive_marker_trait<Trait: Derivable>(input: DeriveInput) -> TokenStream {
458 derive_marker_trait_inner::<Trait>(input)
459 .unwrap_or_else(|err: Error| err.into_compile_error())
460}
461
462/// Find `#[name(key = "value")]` helper attributes on the struct, and return
463/// their `"value"`s parsed with `parser`.
464///
465/// Returns an error if any attributes with the given `name` do not match the
466/// expected format. Returns `Ok([])` if no attributes with `name` are found.
467fn find_and_parse_helper_attributes<P: syn::parse::Parser + Copy>(
468 attributes: &[syn::Attribute], name: &str, key: &str, parser: P,
469 example_value: &str, invalid_value_msg: &str,
470) -> Result<Vec<P::Output>> {
471 let invalid_format_msg =
472 format!("{name} attribute must be `{name}({key} = \"{example_value}\")`",);
473 let values_to_check = attributes.iter().filter_map(|attr| match &attr.meta {
474 // If a `Path` matches our `name`, return an error, else ignore it.
475 // e.g. `#[zeroable]`
476 syn::Meta::Path(path) => path
477 .is_ident(name)
478 .then(|| Err(syn::Error::new_spanned(path, &invalid_format_msg))),
479 // If a `NameValue` matches our `name`, return an error, else ignore it.
480 // e.g. `#[zeroable = "hello"]`
481 syn::Meta::NameValue(namevalue) => {
482 namevalue.path.is_ident(name).then(|| {
483 Err(syn::Error::new_spanned(&namevalue.path, &invalid_format_msg))
484 })
485 }
486 // If a `List` matches our `name`, match its contents to our format, else
487 // ignore it. If its contents match our format, return the value, else
488 // return an error.
489 syn::Meta::List(list) => list.path.is_ident(name).then(|| {
490 let namevalue: syn::MetaNameValue = syn::parse2(list.tokens.clone())
491 .map_err(|_| {
492 syn::Error::new_spanned(&list.tokens, &invalid_format_msg)
493 })?;
494 if namevalue.path.is_ident(key) {
495 match namevalue.value {
496 syn::Expr::Lit(syn::ExprLit {
497 lit: syn::Lit::Str(strlit), ..
498 }) => Ok(strlit),
499 _ => {
500 Err(syn::Error::new_spanned(&namevalue.path, &invalid_format_msg))
501 }
502 }
503 } else {
504 Err(syn::Error::new_spanned(&namevalue.path, &invalid_format_msg))
505 }
506 }),
507 });
508 // Parse each value found with the given parser, and return them if no errors
509 // occur.
510 values_to_check
511 .map(|lit| {
512 let lit = lit?;
513 lit.parse_with(parser).map_err(|err| {
514 syn::Error::new_spanned(&lit, format!("{invalid_value_msg}: {err}"))
515 })
516 })
517 .collect()
518}
519
520fn derive_marker_trait_inner<Trait: Derivable>(
521 mut input: DeriveInput,
522) -> Result<TokenStream> {
523 let crate_name = bytemuck_crate_name(&input);
524 let trait_ = Trait::ident(&input, &crate_name)?;
525 // If this trait allows explicit bounds, and any explicit bounds were given,
526 // then use those explicit bounds. Else, apply the default bounds (bound
527 // each generic type on this trait).
528 if let Some(name) = Trait::explicit_bounds_attribute_name() {
529 // See if any explicit bounds were given in attributes.
530 let explicit_bounds = find_and_parse_helper_attributes(
531 &input.attrs,
532 name,
533 "bound",
534 <syn::punctuated::Punctuated<syn::WherePredicate, syn::Token![,]>>::parse_terminated,
535 "Type: Trait",
536 "invalid where predicate",
537 )?;
538
539 if !explicit_bounds.is_empty() {
540 // Explicit bounds were given.
541 // Enforce explicitly given bounds, and emit "perfect derive" (i.e. add
542 // bounds for each field's type).
543 let explicit_bounds = explicit_bounds
544 .into_iter()
545 .flatten()
546 .collect::<Vec<syn::WherePredicate>>();
547
548 let predicates = &mut input.generics.make_where_clause().predicates;
549
550 predicates.extend(explicit_bounds);
551
552 let fields = match &input.data {
553 syn::Data::Struct(syn::DataStruct { fields, .. }) => fields.clone(),
554 syn::Data::Union(_) => {
555 return Err(syn::Error::new_spanned(
556 trait_,
557 &"perfect derive is not supported for unions",
558 ));
559 }
560 syn::Data::Enum(_) => {
561 return Err(syn::Error::new_spanned(
562 trait_,
563 &"perfect derive is not supported for enums",
564 ));
565 }
566 };
567
568 for field in fields {
569 let ty = field.ty;
570 predicates.push(syn::parse_quote!(
571 #ty: #trait_
572 ));
573 }
574 } else {
575 // No explicit bounds were given.
576 // Enforce trait bound on all type generics.
577 add_trait_marker(&mut input.generics, &trait_);
578 }
579 } else {
580 // This trait does not allow explicit bounds.
581 // Enforce trait bound on all type generics.
582 add_trait_marker(&mut input.generics, &trait_);
583 }
584
585 let name = &input.ident;
586
587 let (impl_generics, ty_generics, where_clause) =
588 input.generics.split_for_impl();
589
590 Trait::check_attributes(&input.data, &input.attrs)?;
591 let asserts = Trait::asserts(&input, &crate_name)?;
592 let (trait_impl_extras, trait_impl) = Trait::trait_impl(&input, &crate_name)?;
593
594 let implies_trait = if let Some(implies_trait) =
595 Trait::implies_trait(&crate_name)
596 {
597 quote!(unsafe impl #impl_generics #implies_trait for #name #ty_generics #where_clause {})
598 } else {
599 quote!()
600 };
601
602 let where_clause =
603 if Trait::requires_where_clause() { where_clause } else { None };
604
605 Ok(quote! {
606 #asserts
607
608 #trait_impl_extras
609
610 unsafe impl #impl_generics #trait_ for #name #ty_generics #where_clause {
611 #trait_impl
612 }
613
614 #implies_trait
615 })
616}
617
618/// Add a trait marker to the generics if it is not already present
619fn add_trait_marker(generics: &mut syn::Generics, trait_name: &syn::Path) {
620 // Get each generic type parameter.
621 let type_params: Vec = genericsimpl Iterator
622 .type_params()
623 .map(|param: &TypeParam| &param.ident)
624 .map(|param: &Ident| {
625 syn::parse_quote!(
626 #param: #trait_name
627 )
628 })
629 .collect::<Vec<syn::WherePredicate>>();
630
631 generics.make_where_clause().predicates.extend(iter:type_params);
632}
633