safe_bytes_derive/
lib.rs

1use {proc_macro2::TokenStream, quote::quote, syn::spanned::Spanned as _};
2
3/// Safely implements [`SafeBytes`] via [`PaddingBane`] implementation.
4///
5/// [`SafeBytes`]: https://docs.rs/safe-bytes/0.1.0/safe_bytes/trait.SafeBytes.html
6/// [`PaddingBane`]: https://docs.rs/safe-bytes/0.1.0/safe_bytes/trait.PaddingBane.html
7#[proc_macro_derive(SafeBytes)]
8pub fn safe_bytes(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
9    let ast = syn::parse(input).unwrap();
10    impl_safe_bytes(&ast).into()
11}
12
13fn impl_safe_bytes(ast: &syn::DeriveInput) -> TokenStream {
14    let type_name = &ast.ident;
15    let fields = match &ast.data {
16        syn::Data::Struct(datastruct) => &datastruct.fields,
17        _ => panic!("safe_bytes cannot be derived for enums or unions"),
18    };
19
20    let field_types = fields.iter().map(|f| f.ty.clone()).collect::<Vec<_>>();
21    let field_names = fields
22        .iter()
23        .enumerate()
24        .map(|(i, f)| {
25            f.ident
26                .clone()
27                .unwrap_or_else(|| syn::Ident::new(&format!("_{}", i), ast.span()))
28        })
29        .collect::<Vec<_>>();
30
31    let (impl_generics, type_generics, where_clause) = ast.generics.split_for_impl();
32
33    quote! {
34        #[automatically_derived]
35        unsafe impl #impl_generics ::safe_bytes::PaddingBane for #type_name #type_generics #where_clause {
36            type Fields = (#(::safe_bytes::TypedField<#field_types>,)*);
37
38            #[inline(always)]
39            fn get_fields(&self) -> Self::Fields {
40                (#(::safe_bytes::typed_field!(*self, #type_name, #field_names),)*)
41            }
42
43            #[inline]
44            unsafe fn init_padding(fields: Self::Fields, bytes: &mut [::safe_bytes::core::mem::MaybeUninit<u8>]) {
45                use {
46                    ::safe_bytes::core::{mem::size_of, ptr::write_bytes},
47                };
48
49                let (#(#field_names,)*) = fields;
50                let mut raw_fields = [#(#field_names.raw,)*];
51                raw_fields.sort_unstable_by_key(|f| f.offset);
52                let mut offset = 0;
53                for field in &raw_fields {
54                    if field.offset > offset {
55                        let count = field.offset - offset;
56                        write_bytes(&mut bytes[offset], 0xfe, count);
57                    }
58                    offset = field.offset + field.size;
59                }
60
61                if size_of::<Self>() > offset {
62                    let count = size_of::<Self>() - offset;
63                    write_bytes(&mut bytes[offset], 0xfe, count);
64                }
65
66                #(
67                    let field_bytes = &mut bytes[#field_names.raw.offset .. #field_names.raw.offset + #field_names.raw.size];
68                    <#field_types as ::safe_bytes::PaddingBane>::init_padding(#field_names.sub, field_bytes);
69                )*
70            }
71        }
72    }
73}