1#![allow(unused_imports)]
2use std::{cmp, convert::TryFrom};
3
4use proc_macro2::{Ident, Span, TokenStream, TokenTree};
5use quote::{quote, quote_spanned, ToTokens};
6use syn::{
7 parse::{Parse, ParseStream, Parser},
8 punctuated::Punctuated,
9 spanned::Spanned,
10 Result, *,
11};
12
/// Early-return an `Err(syn::Error)` from the enclosing function.
///
/// * `bail!(msg => tokens)` — the error is spanned over `tokens`
///   (via `Error::new_spanned`).
/// * `bail!(msg)` — the error is attached to `Span::call_site()`.
///
/// A trailing comma is accepted in both forms.
macro_rules! bail {
    ( $msg:expr => $span_to_blame:expr $(,)? ) => {
        return Err(Error::new_spanned(&$span_to_blame, $msg))
    };

    ( $msg:expr $(,)? ) => {
        return Err(Error::new(Span::call_site(), &$msg[..]))
    };
}
22
23pub trait Derivable {
24 fn ident(input: &DeriveInput, crate_name: &TokenStream) -> Result<syn::Path>;
25 fn implies_trait(_crate_name: &TokenStream) -> Option<TokenStream> {
26 None
27 }
28 fn asserts(_input: &DeriveInput, _crate_name: &TokenStream) -> Result<TokenStream> {
29 Ok(quote!())
30 }
31 fn check_attributes(_ty: &Data, _attributes: &[Attribute]) -> Result<()> {
32 Ok(())
33 }
34 fn trait_impl(_input: &DeriveInput, _crate_name: &TokenStream) -> Result<(TokenStream, TokenStream)> {
35 Ok((quote!(), quote!()))
36 }
37 fn requires_where_clause() -> bool {
38 true
39 }
40 fn explicit_bounds_attribute_name() -> Option<&'static str> {
41 None
42 }
43}
44
45pub struct Pod;
46
47impl Derivable for Pod {
48 fn ident(_: &DeriveInput, crate_name: &TokenStream) -> Result<syn::Path> {
49 Ok(syn::parse_quote!(#crate_name::Pod))
50 }
51
52 fn asserts(input: &DeriveInput, crate_name: &TokenStream) -> Result<TokenStream> {
53 let repr = get_repr(&input.attrs)?;
54
55 let completly_packed =
56 repr.packed == Some(1) || repr.repr == Repr::Transparent;
57
58 if !completly_packed && !input.generics.params.is_empty() {
59 bail!("\
60 Pod requires cannot be derived for non-packed types containing \
61 generic parameters because the padding requirements can't be verified \
62 for generic non-packed structs\
63 " => input.generics.params.first().unwrap());
64 }
65
66 match &input.data {
67 Data::Struct(_) => {
68 let assert_no_padding = if !completly_packed {
69 Some(generate_assert_no_padding(input)?)
70 } else {
71 None
72 };
73 let assert_fields_are_pod =
74 generate_fields_are_trait(input, Self::ident(input, crate_name)?)?;
75
76 Ok(quote!(
77 #assert_no_padding
78 #assert_fields_are_pod
79 ))
80 }
81 Data::Enum(_) => bail!("Deriving Pod is not supported for enums"),
82 Data::Union(_) => bail!("Deriving Pod is not supported for unions"),
83 }
84 }
85
86 fn check_attributes(_ty: &Data, attributes: &[Attribute]) -> Result<()> {
87 let repr = get_repr(attributes)?;
88 match repr.repr {
89 Repr::C => Ok(()),
90 Repr::Transparent => Ok(()),
91 _ => {
92 bail!("Pod requires the type to be #[repr(C)] or #[repr(transparent)]")
93 }
94 }
95 }
96}
97
98pub struct AnyBitPattern;
99
100impl Derivable for AnyBitPattern {
101 fn ident(_: &DeriveInput, crate_name: &TokenStream) -> Result<syn::Path> {
102 Ok(syn::parse_quote!(#crate_name::AnyBitPattern))
103 }
104
105 fn implies_trait(crate_name: &TokenStream) -> Option<TokenStream> {
106 Some(quote!(#crate_name::Zeroable))
107 }
108
109 fn asserts(input: &DeriveInput, crate_name: &TokenStream) -> Result<TokenStream> {
110 match &input.data {
111 Data::Union(_) => Ok(quote!()), // unions are always `AnyBitPattern`
112 Data::Struct(_) => generate_fields_are_trait(input, Self::ident(input, crate_name)?),
113 Data::Enum(_) => {
114 bail!("Deriving AnyBitPattern is not supported for enums")
115 }
116 }
117 }
118}
119
120pub struct Zeroable;
121
122impl Derivable for Zeroable {
123 fn ident(_: &DeriveInput, crate_name: &TokenStream) -> Result<syn::Path> {
124 Ok(syn::parse_quote!(#crate_name::Zeroable))
125 }
126
127 fn asserts(input: &DeriveInput, crate_name: &TokenStream) -> Result<TokenStream> {
128 match &input.data {
129 Data::Union(_) => Ok(quote!()), // unions are always `Zeroable`
130 Data::Struct(_) => generate_fields_are_trait(input, Self::ident(input, crate_name)?),
131 Data::Enum(_) => bail!("Deriving Zeroable is not supported for enums"),
132 }
133 }
134
135 fn explicit_bounds_attribute_name() -> Option<&'static str> {
136 Some("zeroable")
137 }
138}
139
140pub struct NoUninit;
141
142impl Derivable for NoUninit {
143 fn ident(_: &DeriveInput, crate_name: &TokenStream) -> Result<syn::Path> {
144 Ok(syn::parse_quote!(#crate_name::NoUninit))
145 }
146
147 fn check_attributes(ty: &Data, attributes: &[Attribute]) -> Result<()> {
148 let repr = get_repr(attributes)?;
149 match ty {
150 Data::Struct(_) => match repr.repr {
151 Repr::C | Repr::Transparent => Ok(()),
152 _ => bail!("NoUninit requires the struct to be #[repr(C)] or #[repr(transparent)]"),
153 },
154 Data::Enum(_) => if repr.repr.is_integer() {
155 Ok(())
156 } else {
157 bail!("NoUninit requires the enum to be an explicit #[repr(Int)]")
158 },
159 Data::Union(_) => bail!("NoUninit can only be derived on enums and structs")
160 }
161 }
162
163 fn asserts(input: &DeriveInput, crate_name: &TokenStream) -> Result<TokenStream> {
164 if !input.generics.params.is_empty() {
165 bail!("NoUninit cannot be derived for structs containing generic parameters because the padding requirements can't be verified for generic structs");
166 }
167
168 match &input.data {
169 Data::Struct(DataStruct { .. }) => {
170 let assert_no_padding = generate_assert_no_padding(&input)?;
171 let assert_fields_are_no_padding =
172 generate_fields_are_trait(&input, Self::ident(input, crate_name)?)?;
173
174 Ok(quote!(
175 #assert_no_padding
176 #assert_fields_are_no_padding
177 ))
178 }
179 Data::Enum(DataEnum { variants, .. }) => {
180 if variants.iter().any(|variant| !variant.fields.is_empty()) {
181 bail!("Only fieldless enums are supported for NoUninit")
182 } else {
183 Ok(quote!())
184 }
185 }
186 Data::Union(_) => bail!("NoUninit cannot be derived for unions"), /* shouldn't be possible since we already error in attribute check for this case */
187 }
188 }
189
190 fn trait_impl(_input: &DeriveInput, _crate_name: &TokenStream) -> Result<(TokenStream, TokenStream)> {
191 Ok((quote!(), quote!()))
192 }
193}
194
195pub struct CheckedBitPattern;
196
197impl Derivable for CheckedBitPattern {
198 fn ident(_: &DeriveInput, crate_name: &TokenStream) -> Result<syn::Path> {
199 Ok(syn::parse_quote!(#crate_name::CheckedBitPattern))
200 }
201
202 fn check_attributes(ty: &Data, attributes: &[Attribute]) -> Result<()> {
203 let repr = get_repr(attributes)?;
204 match ty {
205 Data::Struct(_) => match repr.repr {
206 Repr::C | Repr::Transparent => Ok(()),
207 _ => bail!("CheckedBitPattern derive requires the struct to be #[repr(C)] or #[repr(transparent)]"),
208 },
209 Data::Enum(DataEnum { variants,.. }) => {
210 if !enum_has_fields(variants.iter()){
211 if repr.repr.is_integer() {
212 Ok(())
213 } else {
214 bail!("CheckedBitPattern requires the enum to be an explicit #[repr(Int)]")
215 }
216 } else if matches!(repr.repr, Repr::Rust) {
217 bail!("CheckedBitPattern requires an explicit repr annotation because `repr(Rust)` doesn't have a specified type layout")
218 } else {
219 Ok(())
220 }
221 }
222 Data::Union(_) => bail!("CheckedBitPattern can only be derived on enums and structs")
223 }
224 }
225
226 fn asserts(input: &DeriveInput, crate_name: &TokenStream) -> Result<TokenStream> {
227 if !input.generics.params.is_empty() {
228 bail!("CheckedBitPattern cannot be derived for structs containing generic parameters");
229 }
230
231 match &input.data {
232 Data::Struct(DataStruct { .. }) => {
233 let assert_fields_are_maybe_pod =
234 generate_fields_are_trait(&input, Self::ident(input, crate_name)?)?;
235
236 Ok(assert_fields_are_maybe_pod)
237 }
238 Data::Enum(_) => Ok(quote!()), /* nothing needed, already guaranteed OK by NoUninit */
239 Data::Union(_) => bail!("Internal error in CheckedBitPattern derive"), /* shouldn't be possible since we already error in attribute check for this case */
240 }
241 }
242
243 fn trait_impl(input: &DeriveInput, crate_name: &TokenStream) -> Result<(TokenStream, TokenStream)> {
244 match &input.data {
245 Data::Struct(DataStruct { fields, .. }) => {
246 generate_checked_bit_pattern_struct(&input.ident, fields, &input.attrs, crate_name)
247 }
248 Data::Enum(DataEnum { variants, .. }) => {
249 generate_checked_bit_pattern_enum(input, variants, crate_name)
250 }
251 Data::Union(_) => bail!("Internal error in CheckedBitPattern derive"), /* shouldn't be possible since we already error in attribute check for this case */
252 }
253 }
254}
255
256pub struct TransparentWrapper;
257
258impl TransparentWrapper {
259 fn get_wrapper_type(
260 attributes: &[Attribute], fields: &Fields,
261 ) -> Option<TokenStream> {
262 let transparent_param: Option = get_simple_attr(attributes, attr_name:"transparent");
263 transparent_param.map(|ident: Ident| ident.to_token_stream()).or_else(|| {
264 let mut types: impl Iterator = get_field_types(&fields);
265 let first_type: Option<&Type> = types.next();
266 if let Some(_) = types.next() {
267 // can't guess param type if there is more than one field
268 return None;
269 } else {
270 first_type.map(|ty: &Type| ty.to_token_stream())
271 }
272 })
273 }
274}
275
276impl Derivable for TransparentWrapper {
277 fn ident(input: &DeriveInput, crate_name: &TokenStream) -> Result<syn::Path> {
278 let fields = get_struct_fields(input)?;
279
280 let ty = match Self::get_wrapper_type(&input.attrs, &fields) {
281 Some(ty) => ty,
282 None => bail!(
283 "\
284 when deriving TransparentWrapper for a struct with more than one field \
285 you need to specify the transparent field using #[transparent(T)]\
286 "
287 ),
288 };
289
290 Ok(syn::parse_quote!(#crate_name::TransparentWrapper<#ty>))
291 }
292
293 fn asserts(input: &DeriveInput, crate_name: &TokenStream) -> Result<TokenStream> {
294 let (impl_generics, _ty_generics, where_clause) =
295 input.generics.split_for_impl();
296 let fields = get_struct_fields(input)?;
297 let wrapped_type = match Self::get_wrapper_type(&input.attrs, &fields) {
298 Some(wrapped_type) => wrapped_type.to_string(),
299 None => unreachable!(), /* other code will already reject this derive */
300 };
301 let mut wrapped_field_ty = None;
302 let mut nonwrapped_field_tys = vec![];
303 for field in fields.iter() {
304 let field_ty = &field.ty;
305 if field_ty.to_token_stream().to_string() == wrapped_type {
306 if wrapped_field_ty.is_some() {
307 bail!(
308 "TransparentWrapper can only have one field of the wrapped type"
309 );
310 }
311 wrapped_field_ty = Some(field_ty);
312 } else {
313 nonwrapped_field_tys.push(field_ty);
314 }
315 }
316 if let Some(wrapped_field_ty) = wrapped_field_ty {
317 Ok(quote!(
318 const _: () = {
319 #[repr(transparent)]
320 struct AssertWrappedIsWrapped #impl_generics((u8, ::core::marker::PhantomData<#wrapped_field_ty>), #(#nonwrapped_field_tys),*) #where_clause;
321 fn assert_zeroable<Z: #crate_name::Zeroable>() {}
322 fn check #impl_generics () #where_clause {
323 #(
324 assert_zeroable::<#nonwrapped_field_tys>();
325 )*
326 }
327 };
328 ))
329 } else {
330 bail!("TransparentWrapper must have one field of the wrapped type")
331 }
332 }
333
334 fn check_attributes(_ty: &Data, attributes: &[Attribute]) -> Result<()> {
335 let repr = get_repr(attributes)?;
336
337 match repr.repr {
338 Repr::Transparent => Ok(()),
339 _ => {
340 bail!(
341 "TransparentWrapper requires the struct to be #[repr(transparent)]"
342 )
343 }
344 }
345 }
346
347 fn requires_where_clause() -> bool {
348 false
349 }
350}
351
352pub struct Contiguous;
353
354impl Derivable for Contiguous {
355 fn ident(_: &DeriveInput, crate_name: &TokenStream) -> Result<syn::Path> {
356 Ok(syn::parse_quote!(#crate_name::Contiguous))
357 }
358
359 fn trait_impl(input: &DeriveInput, _crate_name: &TokenStream) -> Result<(TokenStream, TokenStream)> {
360 let repr = get_repr(&input.attrs)?;
361
362 let integer_ty = if let Some(integer_ty) = repr.repr.as_integer() {
363 integer_ty
364 } else {
365 bail!("Contiguous requires the enum to be #[repr(Int)]");
366 };
367
368 let variants = get_enum_variants(input)?;
369 if enum_has_fields(variants.clone()) {
370 return Err(Error::new_spanned(
371 &input,
372 "Only fieldless enums are supported",
373 ));
374 }
375
376 let mut variants_with_discriminator =
377 VariantDiscriminantIterator::new(variants);
378
379 let (min, max, count) = variants_with_discriminator.try_fold(
380 (i64::max_value(), i64::min_value(), 0),
381 |(min, max, count), res| {
382 let discriminator = res?;
383 Ok::<_, Error>((
384 i64::min(min, discriminator),
385 i64::max(max, discriminator),
386 count + 1,
387 ))
388 },
389 )?;
390
391 if max - min != count - 1 {
392 bail! {
393 "Contiguous requires the enum discriminants to be contiguous",
394 }
395 }
396
397 let min_lit = LitInt::new(&format!("{}", min), input.span());
398 let max_lit = LitInt::new(&format!("{}", max), input.span());
399
400 // `from_integer` and `into_integer` are usually provided by the trait's default implementation.
401 // We override this implementation because it goes through `transmute_copy`, which can lead to
402 // inefficient assembly as seen in https://github.com/Lokathor/bytemuck/issues/175 .
403
404 Ok((
405 quote!(),
406 quote! {
407 type Int = #integer_ty;
408 const MIN_VALUE: #integer_ty = #min_lit;
409 const MAX_VALUE: #integer_ty = #max_lit;
410
411 #[inline]
412 fn from_integer(value: Self::Int) -> Option<Self> {
413 #[allow(clippy::manual_range_contains)]
414 if Self::MIN_VALUE <= value && value <= Self::MAX_VALUE {
415 Some(unsafe { ::core::mem::transmute(value) })
416 } else {
417 None
418 }
419 }
420
421 #[inline]
422 fn into_integer(self) -> Self::Int {
423 self as #integer_ty
424 }
425 },
426 ))
427 }
428}
429
430fn get_struct_fields(input: &DeriveInput) -> Result<&Fields> {
431 if let Data::Struct(DataStruct { fields: &Fields, .. }) = &input.data {
432 Ok(fields)
433 } else {
434 bail!("deriving this trait is only supported for structs")
435 }
436}
437
438fn get_fields(input: &DeriveInput) -> Result<Fields> {
439 match &input.data {
440 Data::Struct(DataStruct { fields: &Fields, .. }) => Ok(fields.clone()),
441 Data::Union(DataUnion { fields: &FieldsNamed, .. }) => Ok(Fields::Named(fields.clone())),
442 Data::Enum(_) => bail!("deriving this trait is not supported for enums"),
443 }
444}
445
446fn get_enum_variants<'a>(
447 input: &'a DeriveInput,
448) -> Result<impl Iterator<Item = &'a Variant> + Clone + 'a> {
449 if let Data::Enum(DataEnum { variants: &Punctuated, .. }) = &input.data {
450 Ok(variants.iter())
451 } else {
452 bail!("deriving this trait is only supported for enums")
453 }
454}
455
456fn get_field_types<'a>(
457 fields: &'a Fields,
458) -> impl Iterator<Item = &'a Type> + 'a {
459 fields.iter().map(|field: &Field| &field.ty)
460}
461
462fn generate_checked_bit_pattern_struct(
463 input_ident: &Ident, fields: &Fields, attrs: &[Attribute], crate_name: &TokenStream
464) -> Result<(TokenStream, TokenStream)> {
465 let bits_ty = Ident::new(&format!("{}Bits", input_ident), input_ident.span());
466
467 let repr = get_repr(attrs)?;
468
469 let field_names = fields
470 .iter()
471 .enumerate()
472 .map(|(i, field)| {
473 field.ident.clone().unwrap_or_else(|| {
474 Ident::new(&format!("field{}", i), input_ident.span())
475 })
476 })
477 .collect::<Vec<_>>();
478 let field_tys = fields.iter().map(|field| &field.ty).collect::<Vec<_>>();
479
480 let field_name = &field_names[..];
481 let field_ty = &field_tys[..];
482
483 let derive_dbg =
484 quote!(#[cfg_attr(not(target_arch = "spirv"), derive(Debug))]);
485
486 Ok((
487 quote! {
488 #repr
489 #[derive(Clone, Copy, #crate_name::AnyBitPattern)]
490 #derive_dbg
491 pub struct #bits_ty {
492 #(#field_name: <#field_ty as #crate_name::CheckedBitPattern>::Bits,)*
493 }
494 },
495 quote! {
496 type Bits = #bits_ty;
497
498 #[inline]
499 #[allow(clippy::double_comparisons)]
500 fn is_valid_bit_pattern(bits: &#bits_ty) -> bool {
501 #(<#field_ty as #crate_name::CheckedBitPattern>::is_valid_bit_pattern(&{ bits.#field_name }) && )* true
502 }
503 },
504 ))
505}
506
507fn generate_checked_bit_pattern_enum(
508 input: &DeriveInput, variants: &Punctuated<Variant, Token![,]>, crate_name: &TokenStream
509) -> Result<(TokenStream, TokenStream)> {
510 if enum_has_fields(variants:variants.iter()) {
511 generate_checked_bit_pattern_enum_with_fields(input, variants, crate_name)
512 } else {
513 generate_checked_bit_pattern_enum_without_fields(input, variants)
514 }
515}
516
517fn generate_checked_bit_pattern_enum_without_fields(
518 input: &DeriveInput, variants: &Punctuated<Variant, Token![,]>,
519) -> Result<(TokenStream, TokenStream)> {
520 let span = input.span();
521 let mut variants_with_discriminant =
522 VariantDiscriminantIterator::new(variants.iter());
523
524 let (min, max, count) = variants_with_discriminant.try_fold(
525 (i64::max_value(), i64::min_value(), 0),
526 |(min, max, count), res| {
527 let discriminant = res?;
528 Ok::<_, Error>((
529 i64::min(min, discriminant),
530 i64::max(max, discriminant),
531 count + 1,
532 ))
533 },
534 )?;
535
536 let check = if count == 0 {
537 quote_spanned!(span => false)
538 } else if max - min == count - 1 {
539 // contiguous range
540 let min_lit = LitInt::new(&format!("{}", min), span);
541 let max_lit = LitInt::new(&format!("{}", max), span);
542
543 quote!(*bits >= #min_lit && *bits <= #max_lit)
544 } else {
545 // not contiguous range, check for each
546 let variant_lits = VariantDiscriminantIterator::new(variants.iter())
547 .map(|res| {
548 let variant = res?;
549 Ok(LitInt::new(&format!("{}", variant), span))
550 })
551 .collect::<Result<Vec<_>>>()?;
552
553 // count is at least 1
554 let first = &variant_lits[0];
555 let rest = &variant_lits[1..];
556
557 quote!(matches!(*bits, #first #(| #rest )*))
558 };
559
560 let repr = get_repr(&input.attrs)?;
561 let integer = repr.repr.as_integer().unwrap(); // should be checked in attr check already
562 Ok((
563 quote!(),
564 quote! {
565 type Bits = #integer;
566
567 #[inline]
568 #[allow(clippy::double_comparisons)]
569 fn is_valid_bit_pattern(bits: &Self::Bits) -> bool {
570 #check
571 }
572 },
573 ))
574}
575
/// Generate `CheckedBitPattern` support for an enum whose variants carry fields.
///
/// Returns `(auxiliary item definitions, associated items for the impl body)`.
/// The enum is re-modelled as plain data according to its `#[repr]`:
///
/// * `repr(C)` / `repr(C, Int)`: a `#[repr(C)]` `{Enum}Bits` struct holding a
///   `tag` (`c_int` for plain `repr(C)`) and a `{Enum}Variants` union of the
///   per-variant `...Bits` types.
/// * `repr(transparent)`: the single variant becomes a `#[repr(C)]` tuple
///   struct named `{Enum}Bits`, whose own `CheckedBitPattern` impl is reused.
/// * `repr(Int)`: a `#[repr(C)]` `{Enum}Bits` union of per-variant structs that
///   each start with the `Int` tag, plus a `__tag` field to read that tag.
///
/// Each per-variant struct derives `CheckedBitPattern` itself, so field
/// validity is delegated to that nested derive.
///
/// # Errors
/// Propagates `get_repr` and `VariantDiscriminantIterator` errors, and fails
/// if a `repr(transparent)` enum does not have exactly one variant.
///
/// # Panics
/// Hits `unreachable!()` on `repr(Rust)`; `check_attributes` is expected to
/// have rejected that case already.
fn generate_checked_bit_pattern_enum_with_fields(
    input: &DeriveInput, variants: &Punctuated<Variant, Token![,]>, crate_name: &TokenStream
) -> Result<(TokenStream, TokenStream)> {
    let representation = get_repr(&input.attrs)?;
    let vis = &input.vis;

    let derive_dbg =
        quote!(#[cfg_attr(not(target_arch = "spirv"), derive(Debug))]);

    match representation.repr {
        Repr::Rust => unreachable!(),
        repr @ (Repr::C | Repr::CWithDiscriminant(_)) => {
            // Tag type: `c_int` for plain `repr(C)`, else the explicit integer.
            let integer = match repr {
                Repr::C => quote!(::core::ffi::c_int),
                Repr::CWithDiscriminant(integer) => quote!(#integer),
                _ => unreachable!(),
            };
            let input_ident = &input.ident;

            // Keep any `packed`/`align` hints, but lay the Bits struct out as plain `repr(C)`.
            let bits_repr = Representation { repr: Repr::C, ..representation };

            // the enum manually re-configured as the actual tagged union it represents,
            // thus circumventing the requirements Rust imposes on the tag even when using
            // #[repr(C)] enum layout
            // see: https://doc.rust-lang.org/reference/type-layout.html#reprc-enums-with-fields
            let bits_ty_ident = Ident::new(&format!("{input_ident}Bits"), input.span());

            // the variants union part of the tagged union. These get put into a union which gets the
            // AnyBitPattern derive applied to it, thus checking that the fields of the union obey the requirements of AnyBitPattern.
            // The types that actually go in the union are one more level of indirection deep: we generate new structs for each variant
            // (`variant_struct_definitions`) which themselves have the `CheckedBitPattern` derive applied, thus generating `{variant_struct_ident}Bits`
            // structs, which are the ones that go into this union.
            let variants_union_ident =
                Ident::new(&format!("{}Variants", input.ident), input.span());

            let variant_struct_idents = variants
                .iter()
                .map(|v| Ident::new(&format!("{input_ident}Variant{}", v.ident), v.span()));

            // One `#[repr(C)]` tuple struct per variant, holding that variant's field types.
            let variant_struct_definitions =
                variant_struct_idents.clone().zip(variants.iter()).map(|(variant_struct_ident, v)| {
                    let fields = v.fields.iter().map(|v| &v.ty);

                    quote! {
                        #[derive(::core::clone::Clone, ::core::marker::Copy, #crate_name::CheckedBitPattern)]
                        #[repr(C)]
                        #vis struct #variant_struct_ident(#(#fields),*);
                    }
                });

            // Union fields are named after the enum variants.
            let union_fields =
                variant_struct_idents.clone().zip(variants.iter()).map(|(variant_struct_ident, v)| {
                    let variant_struct_bits_ident =
                        Ident::new(&format!("{variant_struct_ident}Bits"), input.span());
                    let field_ident = &v.ident;
                    quote! {
                        #field_ident: #variant_struct_bits_ident
                    }
                });

            // One match arm per variant: the tag selects which union field to validate.
            let variant_checks = variant_struct_idents
                .clone()
                .zip(VariantDiscriminantIterator::new(variants.iter()))
                .zip(variants.iter())
                .map(|((variant_struct_ident, discriminant), v)| -> Result<_> {
                    let discriminant = discriminant?;
                    let discriminant = LitInt::new(&discriminant.to_string(), v.span());
                    let ident = &v.ident;
                    Ok(quote! {
                        #discriminant => {
                            let payload = unsafe { &bits.payload.#ident };
                            <#variant_struct_ident as #crate_name::CheckedBitPattern>::is_valid_bit_pattern(payload)
                        }
                    })
                })
                .collect::<Result<Vec<_>>>()?;

            Ok((
                quote! {
                    #[derive(::core::clone::Clone, ::core::marker::Copy, #crate_name::AnyBitPattern)]
                    #derive_dbg
                    #bits_repr
                    #vis struct #bits_ty_ident {
                        tag: #integer,
                        payload: #variants_union_ident,
                    }

                    #[derive(::core::clone::Clone, ::core::marker::Copy, #crate_name::AnyBitPattern)]
                    #[repr(C)]
                    #[allow(non_snake_case)]
                    #vis union #variants_union_ident {
                        #(#union_fields,)*
                    }

                    #[cfg(not(target_arch = "spirv"))]
                    impl ::core::fmt::Debug for #variants_union_ident {
                        fn fmt(&self, f: &mut ::core::fmt::Formatter<'_>) -> ::core::fmt::Result {
                            let mut debug_struct = ::core::fmt::Formatter::debug_struct(f, ::core::stringify!(#variants_union_ident));
                            ::core::fmt::DebugStruct::finish_non_exhaustive(&mut debug_struct)
                        }
                    }

                    #(#variant_struct_definitions)*
                },
                quote! {
                    type Bits = #bits_ty_ident;

                    #[inline]
                    #[allow(clippy::double_comparisons)]
                    fn is_valid_bit_pattern(bits: &Self::Bits) -> bool {
                        match bits.tag {
                            #(#variant_checks)*
                            _ => false,
                        }
                    }
                },
            ))
        }
        Repr::Transparent => {
            if variants.len() != 1 {
                bail!("enums with more than one variant cannot be transparent")
            }

            let variant = &variants[0];

            // The single variant's fields as a tuple struct; its own derive provides the Bits type.
            let bits_ty = Ident::new(&format!("{}Bits", input.ident), input.span());
            let fields = variant.fields.iter().map(|v| &v.ty);

            Ok((
                quote! {
                    #[derive(::core::clone::Clone, ::core::marker::Copy, #crate_name::CheckedBitPattern)]
                    #[repr(C)]
                    #vis struct #bits_ty(#(#fields),*);
                },
                quote! {
                    type Bits = <#bits_ty as #crate_name::CheckedBitPattern>::Bits;

                    #[inline]
                    #[allow(clippy::double_comparisons)]
                    fn is_valid_bit_pattern(bits: &Self::Bits) -> bool {
                        <#bits_ty as #crate_name::CheckedBitPattern>::is_valid_bit_pattern(bits)
                    }
                },
            ))
        }
        Repr::Integer(integer) => {
            let bits_repr = Representation { repr: Repr::C, ..representation };
            let input_ident = &input.ident;

            // the enum manually re-configured as the union it represents. such a union is the union of variants
            // as a repr(c) struct with the discriminator type inserted at the beginning.
            // in our case we union the `Bits` representation of each variant rather than the variant itself, which we generate
            // via a nested `CheckedBitPattern` derive on the `variant_struct_definitions` generated below.
            //
            // see: https://doc.rust-lang.org/reference/type-layout.html#primitive-representation-of-enums-with-fields
            let bits_ty_ident = Ident::new(&format!("{input_ident}Bits"), input.span());

            let variant_struct_idents = variants
                .iter()
                .map(|v| Ident::new(&format!("{input_ident}Variant{}", v.ident), v.span()));

            let variant_struct_definitions =
                variant_struct_idents.clone().zip(variants.iter()).map(|(variant_struct_ident, v)| {
                    let fields = v.fields.iter().map(|v| &v.ty);

                    // adding the discriminant repr integer as first field, as described above
                    quote! {
                        #[derive(::core::clone::Clone, ::core::marker::Copy, #crate_name::CheckedBitPattern)]
                        #[repr(C)]
                        #vis struct #variant_struct_ident(#integer, #(#fields),*);
                    }
                });

            let union_fields =
                variant_struct_idents.clone().zip(variants.iter()).map(|(variant_struct_ident, v)| {
                    let variant_struct_bits_ident =
                        Ident::new(&format!("{variant_struct_ident}Bits"), input.span());
                    let field_ident = &v.ident;
                    quote! {
                        #field_ident: #variant_struct_bits_ident
                    }
                });

            // Same per-variant dispatch as the `repr(C)` case, but the union itself is `Bits`.
            let variant_checks = variant_struct_idents
                .clone()
                .zip(VariantDiscriminantIterator::new(variants.iter()))
                .zip(variants.iter())
                .map(|((variant_struct_ident, discriminant), v)| -> Result<_> {
                    let discriminant = discriminant?;
                    let discriminant = LitInt::new(&discriminant.to_string(), v.span());
                    let ident = &v.ident;
                    Ok(quote! {
                        #discriminant => {
                            let payload = unsafe { &bits.#ident };
                            <#variant_struct_ident as #crate_name::CheckedBitPattern>::is_valid_bit_pattern(payload)
                        }
                    })
                })
                .collect::<Result<Vec<_>>>()?;

            // `__tag` overlays the leading tag field shared by every variant struct.
            Ok((
                quote! {
                    #[derive(::core::clone::Clone, ::core::marker::Copy, #crate_name::AnyBitPattern)]
                    #bits_repr
                    #[allow(non_snake_case)]
                    #vis union #bits_ty_ident {
                        __tag: #integer,
                        #(#union_fields,)*
                    }

                    #[cfg(not(target_arch = "spirv"))]
                    impl ::core::fmt::Debug for #bits_ty_ident {
                        fn fmt(&self, f: &mut ::core::fmt::Formatter<'_>) -> ::core::fmt::Result {
                            let mut debug_struct = ::core::fmt::Formatter::debug_struct(f, ::core::stringify!(#bits_ty_ident));
                            ::core::fmt::DebugStruct::field(&mut debug_struct, "tag", unsafe { &self.__tag });
                            ::core::fmt::DebugStruct::finish_non_exhaustive(&mut debug_struct)
                        }
                    }

                    #(#variant_struct_definitions)*
                },
                quote! {
                    type Bits = #bits_ty_ident;

                    #[inline]
                    #[allow(clippy::double_comparisons)]
                    fn is_valid_bit_pattern(bits: &Self::Bits) -> bool {
                        match unsafe { bits.__tag } {
                            #(#variant_checks)*
                            _ => false,
                        }
                    }
                },
            ))
        }
    }
}
813
814/// Check that a struct has no padding by asserting that the size of the struct
815/// is equal to the sum of the size of it's fields
816fn generate_assert_no_padding(input: &DeriveInput) -> Result<TokenStream> {
817 let struct_type: &Ident = &input.ident;
818 let span: Span = input.ident.span();
819 let fields: Fields = get_fields(input)?;
820
821 let mut field_types: impl Iterator = get_field_types(&fields);
822 let size_sum: TokenStream = if let Some(first: &Type) = field_types.next() {
823 let size_first: TokenStream = quote_spanned!(span => ::core::mem::size_of::<#first>());
824 let size_rest: TokenStream =
825 quote_spanned!(span => #( + ::core::mem::size_of::<#field_types>() )*);
826
827 quote_spanned!(span => #size_first #size_rest)
828 } else {
829 quote_spanned!(span => 0)
830 };
831
832 Ok(quote_spanned! {span => const _: fn() = || {
833 #[doc(hidden)]
834 struct TypeWithoutPadding([u8; #size_sum]);
835 let _ = ::core::mem::transmute::<#struct_type, TypeWithoutPadding>;
836 };})
837}
838
839/// Check that all fields implement a given trait
840fn generate_fields_are_trait(
841 input: &DeriveInput, trait_: syn::Path,
842) -> Result<TokenStream> {
843 let (impl_generics: ImplGenerics<'_>, _ty_generics: TypeGenerics<'_>, where_clause: Option<&WhereClause>) =
844 input.generics.split_for_impl();
845 let fields: Fields = get_fields(input)?;
846 let span: Span = input.span();
847 let field_types: impl Iterator = get_field_types(&fields);
848 Ok(quote_spanned! {span => #(const _: fn() = || {
849 #[allow(clippy::missing_const_for_fn)]
850 #[doc(hidden)]
851 fn check #impl_generics () #where_clause {
852 fn assert_impl<T: #trait_>() {}
853 assert_impl::<#field_types>();
854 }
855 };)*
856 })
857}
858
859fn get_ident_from_stream(tokens: TokenStream) -> Option<Ident> {
860 match tokens.into_iter().next() {
861 Some(TokenTree::Group(group: Group)) => get_ident_from_stream(tokens:group.stream()),
862 Some(TokenTree::Ident(ident: Ident)) => Some(ident),
863 _ => None,
864 }
865}
866
867/// get a simple #[foo(bar)] attribute, returning "bar"
868fn get_simple_attr(attributes: &[Attribute], attr_name: &str) -> Option<Ident> {
869 for attr: &Attribute in attributes {
870 if let (AttrStyle::Outer, Meta::List(list: &MetaList)) = (&attr.style, &attr.meta) {
871 if list.path.is_ident(attr_name) {
872 if let Some(ident: Ident) = get_ident_from_stream(tokens:list.tokens.clone()) {
873 return Some(ident);
874 }
875 }
876 }
877 }
878
879 None
880}
881
882fn get_repr(attributes: &[Attribute]) -> Result<Representation> {
883 attributes
884 .iter()
885 .filter_map(|attr| {
886 if attr.path().is_ident("repr") {
887 Some(attr.parse_args::<Representation>())
888 } else {
889 None
890 }
891 })
892 .try_fold(Representation::default(), |a, b| {
893 let b = b?;
894 Ok(Representation {
895 repr: match (a.repr, b.repr) {
896 (a, Repr::Rust) => a,
897 (Repr::Rust, b) => b,
898 _ => bail!("conflicting representation hints"),
899 },
900 packed: match (a.packed, b.packed) {
901 (a, None) => a,
902 (None, b) => b,
903 _ => bail!("conflicting representation hints"),
904 },
905 align: match (a.align, b.align) {
906 (Some(a), Some(b)) => Some(cmp::max(a, b)),
907 (a, None) => a,
908 (None, b) => b,
909 },
910 })
911 })
912}
913
914mk_repr! {
915 U8 => u8,
916 I8 => i8,
917 U16 => u16,
918 I16 => i16,
919 U32 => u32,
920 I32 => i32,
921 U64 => u64,
922 I64 => i64,
923 I128 => i128,
924 U128 => u128,
925 Usize => usize,
926 Isize => isize,
927}
928// where
929macro_rules! mk_repr {(
930 $(
931 $Xn:ident => $xn:ident
932 ),* $(,)?
933) => (
934 #[derive(Debug, Clone, Copy, PartialEq, Eq)]
935 enum IntegerRepr {
936 $($Xn),*
937 }
938
939 impl<'a> TryFrom<&'a str> for IntegerRepr {
940 type Error = &'a str;
941
942 fn try_from(value: &'a str) -> std::result::Result<Self, &'a str> {
943 match value {
944 $(
945 stringify!($xn) => Ok(Self::$Xn),
946 )*
947 _ => Err(value),
948 }
949 }
950 }
951
952 impl ToTokens for IntegerRepr {
953 fn to_tokens(&self, tokens: &mut TokenStream) {
954 match self {
955 $(
956 Self::$Xn => tokens.extend(quote!($xn)),
957 )*
958 }
959 }
960 }
961)}
962use mk_repr;
963
964#[derive(Debug, Clone, Copy, PartialEq, Eq)]
965enum Repr {
966 Rust,
967 C,
968 Transparent,
969 Integer(IntegerRepr),
970 CWithDiscriminant(IntegerRepr),
971}
972
973impl Repr {
974 fn is_integer(&self) -> bool {
975 matches!(self, Self::Integer(..))
976 }
977
978 fn as_integer(&self) -> Option<IntegerRepr> {
979 if let Self::Integer(v: &IntegerRepr) = self {
980 Some(*v)
981 } else {
982 None
983 }
984 }
985}
986
987#[derive(Debug, Clone, Copy, PartialEq, Eq)]
988struct Representation {
989 packed: Option<u32>,
990 align: Option<u32>,
991 repr: Repr,
992}
993
994impl Default for Representation {
995 fn default() -> Self {
996 Self { packed: None, align: None, repr: Repr::Rust }
997 }
998}
999
1000impl Parse for Representation {
1001 fn parse(input: ParseStream<'_>) -> Result<Representation> {
1002 let mut ret = Representation::default();
1003 while !input.is_empty() {
1004 let keyword = input.parse::<Ident>()?;
1005 // preëmptively call `.to_string()` *once* (rather than on `is_ident()`)
1006 let keyword_str = keyword.to_string();
1007 let new_repr = match keyword_str.as_str() {
1008 "C" => Repr::C,
1009 "transparent" => Repr::Transparent,
1010 "packed" => {
1011 ret.packed = Some(if input.peek(token::Paren) {
1012 let contents;
1013 parenthesized!(contents in input);
1014 LitInt::base10_parse::<u32>(&contents.parse()?)?
1015 } else {
1016 1
1017 });
1018 let _: Option<Token![,]> = input.parse()?;
1019 continue;
1020 }
1021 "align" => {
1022 let contents;
1023 parenthesized!(contents in input);
1024 let new_align = LitInt::base10_parse::<u32>(&contents.parse()?)?;
1025 ret.align = Some(
1026 ret
1027 .align
1028 .map_or(new_align, |old_align| cmp::max(old_align, new_align)),
1029 );
1030 let _: Option<Token![,]> = input.parse()?;
1031 continue;
1032 }
1033 ident => {
1034 let primitive = IntegerRepr::try_from(ident)
1035 .map_err(|_| input.error("unrecognized representation hint"))?;
1036 Repr::Integer(primitive)
1037 }
1038 };
1039 ret.repr = match (ret.repr, new_repr) {
1040 (Repr::Rust, new_repr) => {
1041 // This is the first explicit repr.
1042 new_repr
1043 }
1044 (Repr::C, Repr::Integer(integer))
1045 | (Repr::Integer(integer), Repr::C) => {
1046 // Both the C repr and an integer repr have been specified
1047 // -> merge into a C wit discriminant.
1048 Repr::CWithDiscriminant(integer)
1049 }
1050 (_, _) => {
1051 return Err(input.error("duplicate representation hint"));
1052 }
1053 };
1054 let _: Option<Token![,]> = input.parse()?;
1055 }
1056 Ok(ret)
1057 }
1058}
1059
1060impl ToTokens for Representation {
1061 fn to_tokens(&self, tokens: &mut TokenStream) {
1062 let mut meta = Punctuated::<_, Token![,]>::new();
1063
1064 match self.repr {
1065 Repr::Rust => {}
1066 Repr::C => meta.push(quote!(C)),
1067 Repr::Transparent => meta.push(quote!(transparent)),
1068 Repr::Integer(primitive) => meta.push(quote!(#primitive)),
1069 Repr::CWithDiscriminant(primitive) => {
1070 meta.push(quote!(C));
1071 meta.push(quote!(#primitive));
1072 }
1073 }
1074
1075 if let Some(packed) = self.packed.as_ref() {
1076 let lit = LitInt::new(&packed.to_string(), Span::call_site());
1077 meta.push(quote!(packed(#lit)));
1078 }
1079
1080 if let Some(align) = self.align.as_ref() {
1081 let lit = LitInt::new(&align.to_string(), Span::call_site());
1082 meta.push(quote!(align(#lit)));
1083 }
1084
1085 tokens.extend(quote!(
1086 #[repr(#meta)]
1087 ));
1088 }
1089}
1090
1091fn enum_has_fields<'a>(
1092 mut variants: impl Iterator<Item = &'a Variant>,
1093) -> bool {
1094 variants.any(|v: &Variant| matches!(v.fields, Fields::Named(_) | Fields::Unnamed(_)))
1095}
1096
1097struct VariantDiscriminantIterator<'a, I: Iterator<Item = &'a Variant> + 'a> {
1098 inner: I,
1099 last_value: i64,
1100}
1101
1102impl<'a, I: Iterator<Item = &'a Variant> + 'a>
1103 VariantDiscriminantIterator<'a, I>
1104{
1105 fn new(inner: I) -> Self {
1106 VariantDiscriminantIterator { inner, last_value: -1 }
1107 }
1108}
1109
1110impl<'a, I: Iterator<Item = &'a Variant> + 'a> Iterator
1111 for VariantDiscriminantIterator<'a, I>
1112{
1113 type Item = Result<i64>;
1114
1115 fn next(&mut self) -> Option<Self::Item> {
1116 let variant: &Variant = self.inner.next()?;
1117
1118 if let Some((_, discriminant: &Expr)) = &variant.discriminant {
1119 let discriminant_value: i64 = match parse_int_expr(discriminant) {
1120 Ok(value: i64) => value,
1121 Err(e: Error) => return Some(Err(e)),
1122 };
1123 self.last_value = discriminant_value;
1124 } else {
1125 self.last_value += 1;
1126 }
1127
1128 Some(Ok(self.last_value))
1129 }
1130}
1131
1132fn parse_int_expr(expr: &Expr) -> Result<i64> {
1133 match expr {
1134 Expr::Unary(ExprUnary { op: UnOp::Neg(_), expr: &Box, .. }) => {
1135 parse_int_expr(expr).map(|int: i64| -int)
1136 }
1137 Expr::Lit(ExprLit { lit: Lit::Int(int: &LitInt), .. }) => int.base10_parse(),
1138 Expr::Lit(ExprLit { lit: Lit::Byte(byte: &LitByte), .. }) => Ok(byte.value().into()),
1139 _ => bail!("Not an integer expression"),
1140 }
1141}
1142
#[cfg(test)]
mod tests {
    use syn::{parse_quote, Attribute};

    use super::{get_repr, IntegerRepr, Repr, Representation};

    /// Runs `get_repr` over `attrs`, panicking if it fails.
    fn repr_of(attrs: &[Attribute]) -> Representation {
        get_repr(attrs).unwrap()
    }

    /// Asserts that each single attribute parses to its paired representation.
    fn check_single(cases: Vec<(Attribute, Representation)>) {
        for (index, (attr, expected)) in cases.into_iter().enumerate() {
            assert_eq!(repr_of(&[attr]), expected, "case #{}", index);
        }
    }

    #[test]
    fn parse_basic_repr() {
        let default = Representation::default();
        let cases: Vec<(Attribute, Representation)> = vec![
            (parse_quote!(#[repr(C)]), Representation { repr: Repr::C, ..default }),
            (
                parse_quote!(#[repr(transparent)]),
                Representation { repr: Repr::Transparent, ..default },
            ),
            (
                parse_quote!(#[repr(u8)]),
                Representation { repr: Repr::Integer(IntegerRepr::U8), ..default },
            ),
            (parse_quote!(#[repr(packed)]), Representation { packed: Some(1), ..default }),
            (parse_quote!(#[repr(packed(1))]), Representation { packed: Some(1), ..default }),
            (parse_quote!(#[repr(packed(2))]), Representation { packed: Some(2), ..default }),
            (parse_quote!(#[repr(align(2))]), Representation { align: Some(2), ..default }),
        ];
        check_single(cases);
    }

    #[test]
    fn parse_advanced_repr() {
        let default = Representation::default();

        // Within a single attribute the largest `align` wins.
        let attr: Attribute = parse_quote!(#[repr(align(4), align(2))]);
        assert_eq!(repr_of(&[attr]), Representation { align: Some(4), ..default });

        // Across several attributes the largest `align` wins as well.
        let attrs: [Attribute; 3] = [
            parse_quote!(#[repr(align(1))]),
            parse_quote!(#[repr(align(4))]),
            parse_quote!(#[repr(align(2))]),
        ];
        assert_eq!(repr_of(&attrs), Representation { align: Some(4), ..default });

        // `C` and an integer merge into `CWithDiscriminant` in either order.
        let merged = Representation {
            repr: Repr::CWithDiscriminant(IntegerRepr::U8),
            ..default
        };
        let c_first: Attribute = parse_quote!(#[repr(C, u8)]);
        assert_eq!(repr_of(&[c_first]), merged);
        let int_first: Attribute = parse_quote!(#[repr(u8, C)]);
        assert_eq!(repr_of(&[int_first]), merged);
    }
}
1222
1223pub fn bytemuck_crate_name(input: &DeriveInput) -> TokenStream {
1224 const ATTR_NAME: &'static str = "crate";
1225
1226 let mut crate_name = quote!(::bytemuck);
1227 for attr in &input.attrs {
1228 if !attr.path().is_ident("bytemuck") {
1229 continue;
1230 }
1231
1232 attr.parse_nested_meta(|meta| {
1233 if meta.path.is_ident(ATTR_NAME) {
1234 let expr: syn::Expr = meta.value()?.parse()?;
1235 let mut value = &expr;
1236 while let syn::Expr::Group(e) = value {
1237 value = &e.expr;
1238 }
1239 if let syn::Expr::Lit(syn::ExprLit {
1240 lit: syn::Lit::Str(lit), ..
1241 }) = value
1242 {
1243 let suffix = lit.suffix();
1244 if !suffix.is_empty() {
1245 bail!(format!("Unexpected suffix `{}` on string literal", suffix))
1246 }
1247 let path: syn::Path = match lit.parse() {
1248 Ok(path) => path,
1249 Err(_) => {
1250 bail!(format!("Failed to parse path: {:?}", lit.value()))
1251 }
1252 };
1253 crate_name = path.into_token_stream();
1254 } else {
1255 bail!(
1256 "Expected bytemuck `crate` attribute to be a string: `crate = \"...\"`",
1257 )
1258 }
1259 }
1260 Ok(())
1261 }).unwrap();
1262 }
1263
1264 return crate_name;
1265}
1266