#![recursion_limit = "128"]
mod ext;
mod repr;
use proc_macro2::Span;
use syn::visit::{self, Visit};
use syn::{
parse_quote, punctuated::Punctuated, token::Comma, Data, DataEnum, DataStruct, DeriveInput,
Error, GenericParam, Ident, Lifetime, Type, TypePath,
};
use synstructure::{decl_derive, quote, Structure};
use ext::*;
use repr::*;
// Register the three custom derives. Each `#[derive(X)]` forwards the parsed
// item (as a synstructure `Structure`) to the matching `derive_*` function below.
decl_derive!([FromBytes] => derive_from_bytes);
decl_derive!([AsBytes] => derive_as_bytes);
decl_derive!([Unaligned] => derive_unaligned);
fn derive_from_bytes(s: Structure<'_>) -> proc_macro2::TokenStream {
match &s.ast().data {
Data::Struct(strct) => derive_from_bytes_struct(&s, strct),
Data::Enum(enm) => derive_from_bytes_enum(&s, enm),
Data::Union(_) => Error::new(Span::call_site(), "unsupported on unions").to_compile_error(),
}
}
fn derive_as_bytes(s: Structure<'_>) -> proc_macro2::TokenStream {
match &s.ast().data {
Data::Struct(strct) => derive_as_bytes_struct(&s, strct),
Data::Enum(enm) => derive_as_bytes_enum(&s, enm),
Data::Union(_) => Error::new(Span::call_site(), "unsupported on unions").to_compile_error(),
}
}
fn derive_unaligned(s: Structure<'_>) -> proc_macro2::TokenStream {
match &s.ast().data {
Data::Struct(strct) => derive_unaligned_struct(&s, strct),
Data::Enum(enm) => derive_unaligned_enum(&s, enm),
Data::Union(_) => Error::new(Span::call_site(), "unsupported on unions").to_compile_error(),
}
}
/// Evaluates a `Result`. On `Ok(value)` it yields `value`. On `Err(errors)` it
/// returns early from the enclosing function with `print_all_errors(errors)`,
/// using whichever `print_all_errors` is in scope at the call site.
macro_rules! try_or_print {
    ($result:expr) => {
        match $result {
            Err(errors) => return print_all_errors(errors),
            Ok(value) => value,
        }
    };
}
/// Derives `FromBytes` for a struct.
///
/// No repr is required. Every field type must implement `FromBytes`, and no
/// padding check is generated.
fn derive_from_bytes_struct(s: &Structure<'_>, strct: &DataStruct) -> proc_macro2::TokenStream {
    let (require_trait_bound, require_size_check) = (true, false);
    impl_block(s.ast(), strct, "FromBytes", require_trait_bound, require_size_check)
}
fn derive_from_bytes_enum(s: &Structure<'_>, enm: &DataEnum) -> proc_macro2::TokenStream {
if !enm.is_c_like() {
return Error::new_spanned(s.ast(), "only C-like enums can implement FromBytes")
.to_compile_error();
}
let reprs = try_or_print!(ENUM_FROM_BYTES_CFG.validate_reprs(s.ast()));
let variants_required = match reprs.as_slice() {
[EnumRepr::U8] | [EnumRepr::I8] => 1usize << 8,
[EnumRepr::U16] | [EnumRepr::I16] => 1usize << 16,
_ => unreachable!(),
};
if enm.variants.len() != variants_required {
return Error::new_spanned(
s.ast(),
format!(
"FromBytes only supported on {} enum with {} variants",
reprs[0], variants_required
),
)
.to_compile_error();
}
impl_block(s.ast(), enm, "FromBytes", true, false)
}
#[rustfmt::skip]
const ENUM_FROM_BYTES_CFG: Config<EnumRepr> = {
use EnumRepr::*;
Config {
allowed_combinations_message: r#"FromBytes requires repr of "u8", "u16", "i8", or "i16""#,
derive_unaligned: false,
allowed_combinations: &[
&[U8],
&[U16],
&[I8],
&[I16],
],
disallowed_but_legal_combinations: &[
&[C],
&[U32],
&[I32],
&[U64],
&[I64],
&[Usize],
&[Isize],
],
}
};
/// Derives `AsBytes` for a struct.
///
/// Generic structs are rejected; this check covers any generic parameter, not
/// only type parameters. The repr must be one allowed by `STRUCT_AS_BYTES_CFG`.
/// For `repr(C)` and `repr(transparent)`, a padding check is generated. Packed
/// structs have no padding between fields, so they skip that check.
fn derive_as_bytes_struct(s: &Structure<'_>, strct: &DataStruct) -> proc_macro2::TokenStream {
    let ast = s.ast();
    let has_generics = !ast.generics.params.is_empty();
    if has_generics {
        return Error::new(Span::call_site(), "unsupported on types with type parameters")
            .to_compile_error();
    }

    let reprs = try_or_print!(STRUCT_AS_BYTES_CFG.validate_reprs(ast));
    let is_packed = match reprs.as_slice() {
        [StructRepr::Packed] | [StructRepr::C, StructRepr::Packed] => true,
        [StructRepr::C] | [StructRepr::Transparent] => false,
        _ => unreachable!(),
    };
    impl_block(ast, strct, "AsBytes", true, !is_packed)
}
#[rustfmt::skip]
const STRUCT_AS_BYTES_CFG: Config<StructRepr> = {
use StructRepr::*;
Config {
allowed_combinations_message: r#"AsBytes requires repr of "C", "transparent", or "packed""#,
derive_unaligned: false,
allowed_combinations: &[
&[C],
&[Transparent],
&[C, Packed],
&[Packed],
],
disallowed_but_legal_combinations: &[],
}
};
fn derive_as_bytes_enum(s: &Structure<'_>, enm: &DataEnum) -> proc_macro2::TokenStream {
if !enm.is_c_like() {
return Error::new_spanned(s.ast(), "only C-like enums can implement AsBytes")
.to_compile_error();
}
try_or_print!(ENUM_AS_BYTES_CFG.validate_reprs(s.ast()));
impl_block(s.ast(), enm, "AsBytes", false, false)
}
#[rustfmt::skip]
const ENUM_AS_BYTES_CFG: Config<EnumRepr> = {
use EnumRepr::*;
Config {
allowed_combinations_message: r#"AsBytes requires repr of "C", "u8", "u16", "u32", "u64", "usize", "i8", "i16", "i32", "i64", or "isize""#,
derive_unaligned: false,
allowed_combinations: &[
&[C],
&[U8],
&[U16],
&[I8],
&[I16],
&[U32],
&[I32],
&[U64],
&[I64],
&[Usize],
&[Isize],
],
disallowed_but_legal_combinations: &[],
}
};
/// Derives `Unaligned` for a struct.
///
/// A packed struct has alignment 1 whatever its fields are, so field bounds are
/// only required for `repr(C)` and `repr(transparent)`.
fn derive_unaligned_struct(s: &Structure<'_>, strct: &DataStruct) -> proc_macro2::TokenStream {
    let ast = s.ast();
    let reprs = try_or_print!(STRUCT_UNALIGNED_CFG.validate_reprs(ast));
    let is_packed = match reprs.as_slice() {
        [StructRepr::Packed] | [StructRepr::C, StructRepr::Packed] => true,
        [StructRepr::C] | [StructRepr::Transparent] => false,
        _ => unreachable!(),
    };
    impl_block(ast, strct, "Unaligned", !is_packed, false)
}
#[rustfmt::skip]
const STRUCT_UNALIGNED_CFG: Config<StructRepr> = {
use StructRepr::*;
Config {
allowed_combinations_message:
r#"Unaligned requires either a) repr "C" or "transparent" with all fields implementing Unaligned or, b) repr "packed""#,
derive_unaligned: true,
allowed_combinations: &[
&[C],
&[Transparent],
&[Packed],
&[C, Packed],
],
disallowed_but_legal_combinations: &[],
}
};
fn derive_unaligned_enum(s: &Structure<'_>, enm: &DataEnum) -> proc_macro2::TokenStream {
if !enm.is_c_like() {
return Error::new_spanned(s.ast(), "only C-like enums can implement Unaligned")
.to_compile_error();
}
try_or_print!(ENUM_UNALIGNED_CFG.validate_reprs(s.ast()));
impl_block(s.ast(), enm, "Unaligned", true, false)
}
#[rustfmt::skip]
const ENUM_UNALIGNED_CFG: Config<EnumRepr> = {
use EnumRepr::*;
Config {
allowed_combinations_message:
r#"Unaligned requires repr of "u8" or "i8", and no alignment (i.e., repr(align(N > 1)))"#,
derive_unaligned: true,
allowed_combinations: &[
&[U8],
&[I8],
],
disallowed_but_legal_combinations: &[
&[C],
&[U16],
&[U32],
&[U64],
&[Usize],
&[I16],
&[I32],
&[I64],
&[Isize],
],
}
};
/// Generates `unsafe impl zerocopy::<trait_name> for <type>` for a derive input.
///
/// - `input`: the item being derived for. Its generic parameters (with defaults
///   stripped) become the impl's parameters and are passed back to the type as
///   arguments.
/// - `data`: supplies the nested field types via `DataExt::nested_types`.
/// - `trait_name`: name of the `zerocopy` trait to implement (e.g. `"FromBytes"`).
/// - `require_trait_bound`: if true, every field type must implement the trait.
///   Field types that mention one of the input's generic type or lifetime
///   parameters become `where` predicates. All other field types are checked
///   inside the generated method body, so a violation is a compile error.
/// - `require_size_check`: if true and there is at least one field, emit code
///   that fails to compile when `size_of::<Self>()` exceeds the sum of the field
///   sizes (i.e. when the type contains padding). The check names the type
///   without generic arguments, so it is only usable for non-generic types.
fn impl_block<D: DataExt>(
    input: &DeriveInput,
    data: &D,
    trait_name: &str,
    require_trait_bound: bool,
    require_size_check: bool,
) -> proc_macro2::TokenStream {
    // Sets the flag to `true` if a visited type mentions one of the input's own
    // generic type parameters (as the first path segment) or lifetimes.
    struct FromTypeParamVisit<'a, 'b>(&'a Punctuated<GenericParam, Comma>, &'b mut bool);

    impl<'a, 'b> Visit<'a> for FromTypeParamVisit<'a, 'b> {
        fn visit_type_path(&mut self, i: &'a TypePath) {
            visit::visit_type_path(self, i);
            // Matches `T`, `T::Assoc`, etc. A path without segments names no parameter.
            let first_ident = match i.path.segments.first() {
                Some(segment) => &segment.ident,
                None => return,
            };
            if self.0.iter().any(|param| {
                if let GenericParam::Type(param) = param {
                    first_ident == &param.ident
                } else {
                    false
                }
            }) {
                *self.1 = true;
            }
        }

        fn visit_lifetime(&mut self, i: &'a Lifetime) {
            visit::visit_lifetime(self, i);
            if self.0.iter().any(|param| {
                if let GenericParam::Lifetime(param) = param {
                    param.lifetime.ident == i.ident
                } else {
                    false
                }
            }) {
                *self.1 = true;
            }
        }
    }

    let is_from_type_param = |ty: &Type| {
        let mut ret = false;
        FromTypeParamVisit(&input.generics.params, &mut ret).visit_type(ty);
        ret
    };

    let trait_ident = Ident::new(trait_name, Span::call_site());

    let field_types = data.nested_types();
    let type_param_field_types = field_types.iter().filter(|ty| is_from_type_param(ty));
    let non_type_param_field_types = field_types.iter().filter(|ty| !is_from_type_param(ty));

    // Field types that depend on the input's generics cannot be checked
    // unconditionally, so they are required via the impl's `where` clause.
    let mut generics = input.generics.clone();
    let where_clause = generics.make_where_clause();
    if require_trait_bound {
        for ty in type_param_field_types {
            let bound = parse_quote!(#ty: zerocopy::#trait_ident);
            where_clause.predicates.push(bound);
        }
    }

    let type_ident = &input.ident;

    // Parameters for `impl<...>`: bounds are kept, but defaults are not allowed
    // in impl headers, so they are stripped.
    let params = input.generics.params.clone().into_iter().map(|mut param| {
        match &mut param {
            GenericParam::Type(ty) => ty.default = None,
            GenericParam::Const(cnst) => cnst.default = None,
            GenericParam::Lifetime(_) => {}
        }
        quote!(#param)
    });

    // Arguments for `Type<...>`: only the parameter names may appear here, not
    // their bounds, attributes, types or defaults.
    let param_idents = input.generics.params.iter().map(|param| match param {
        GenericParam::Type(ty) => {
            let ident = &ty.ident;
            quote!(#ident)
        }
        // Emit `'a`, not `'a: 'b`.
        GenericParam::Lifetime(l) => {
            let lifetime = &l.lifetime;
            quote!(#lifetime)
        }
        // Emit `{N}`, not `const N: usize`. The braces make it an unambiguous
        // const argument.
        GenericParam::Const(cnst) => {
            let ident = &cnst.ident;
            quote!({ #ident })
        }
    });

    // Non-generic field types are checked by naming them as arguments to a local
    // struct whose parameter is bounded by the trait.
    let trait_bound_body = if require_trait_bound {
        let implements_type_ident =
            Ident::new(format!("Implements{}", trait_ident).as_str(), Span::call_site());
        let implements_type_tokens = quote!(#implements_type_ident);
        let types = non_type_param_field_types.map(|ty| quote!(#implements_type_tokens<#ty>));
        quote!(
            struct #implements_type_ident<F: ?Sized + zerocopy::#trait_ident>(::core::marker::PhantomData<F>);
            #(let _: #types;)*
        )
    } else {
        quote!()
    };

    // `HasPadding<false>` is implemented only when the type's size equals the sum
    // of its field sizes. Otherwise `assert_no_padding::<Type>` fails to compile.
    let size_check_body = if require_size_check && !field_types.is_empty() {
        quote!(
            const _: () = {
                trait HasPadding<const HAS_PADDING: bool> {}
                fn assert_no_padding<T: HasPadding<false>>() {}

                const COMPOSITE_TYPE_SIZE: usize = ::core::mem::size_of::<#type_ident>();
                const SUM_FIELD_SIZES: usize = 0 #(+ ::core::mem::size_of::<#field_types>())*;
                const HAS_PADDING: bool = COMPOSITE_TYPE_SIZE > SUM_FIELD_SIZES;
                impl HasPadding<HAS_PADDING> for #type_ident {}
                let _ = assert_no_padding::<#type_ident>;
            };
        )
    } else {
        quote!()
    };

    quote! {
        unsafe impl < #(#params),* > zerocopy::#trait_ident for #type_ident < #(#param_idents),* > #where_clause {
            fn only_derive_is_allowed_to_implement_this_trait() where Self: Sized {
                #trait_bound_body
                #size_check_body
            }
        }
    }
}
/// Converts every error into its `compile_error!` tokens and concatenates them,
/// in order, into one token stream so that all errors are reported together.
fn print_all_errors(errors: Vec<Error>) -> proc_macro2::TokenStream {
    let mut tokens = proc_macro2::TokenStream::new();
    for error in &errors {
        tokens.extend(error.to_compile_error());
    }
    tokens
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Every repr combination listed in every `Config` must be sorted and free of
    /// duplicates. Presumably `validate_reprs` compares an item's reprs in this
    /// canonical order (confirm in `repr.rs`). All five configs are checked,
    /// including the `AsBytes` ones.
    #[test]
    fn test_config_repr_orderings() {
        fn is_sorted_and_deduped<T: Clone + Ord>(ts: &[T]) -> bool {
            let mut sorted = ts.to_vec();
            sorted.sort();
            sorted.dedup();
            ts == sorted.as_slice()
        }

        fn elements_are_sorted_and_deduped<T: Clone + Ord>(lists: &[&[T]]) -> bool {
            lists.iter().all(|list| is_sorted_and_deduped(*list))
        }

        fn config_is_sorted<T: KindRepr + Clone>(config: &Config<T>) -> bool {
            elements_are_sorted_and_deduped(&config.allowed_combinations)
                && elements_are_sorted_and_deduped(&config.disallowed_but_legal_combinations)
        }

        assert!(config_is_sorted(&STRUCT_UNALIGNED_CFG));
        assert!(config_is_sorted(&ENUM_FROM_BYTES_CFG));
        assert!(config_is_sorted(&ENUM_UNALIGNED_CFG));
        assert!(config_is_sorted(&STRUCT_AS_BYTES_CFG));
        assert!(config_is_sorted(&ENUM_AS_BYTES_CFG));
    }

    /// No combination may be listed as both allowed and disallowed-but-legal in
    /// the same `Config`.
    #[test]
    fn test_config_repr_no_overlap() {
        fn overlap<T: Eq>(a: &[T], b: &[T]) -> bool {
            a.iter().any(|elem| b.contains(elem))
        }

        fn config_overlaps<T: KindRepr + Eq>(config: &Config<T>) -> bool {
            overlap(config.allowed_combinations, config.disallowed_but_legal_combinations)
        }

        assert!(!config_overlaps(&STRUCT_UNALIGNED_CFG));
        assert!(!config_overlaps(&ENUM_FROM_BYTES_CFG));
        assert!(!config_overlaps(&ENUM_UNALIGNED_CFG));
        assert!(!config_overlaps(&STRUCT_AS_BYTES_CFG));
        assert!(!config_overlaps(&ENUM_AS_BYTES_CFG));
    }
}