structview_derive/
lib.rs

1//! Custom safe derive for the `View` trait of the `structview` crate.
2
3extern crate proc_macro;
4
5use proc_macro2::TokenStream;
6use quote::{quote, quote_spanned};
7use syn::spanned::Spanned;
8use syn::{
9    parse_macro_input, parse_quote, AttrStyle, Data, DeriveInput, Fields, FieldsNamed,
10    FieldsUnnamed, Index,
11};
12
13/// Derive `structview::View` on a struct or union.
14///
15/// To ensure safety, this derive:
16///   - ensures the deriving type is a struct or a union
17///   - ensures the deriving type is repr(C)
18///   - ensures that all fields implement `structview::View`
19#[proc_macro_derive(View)]
20pub fn derive_view(item: proc_macro::TokenStream) -> proc_macro::TokenStream {
21    let item = parse_macro_input!(item as DeriveInput);
22
23    if let Some(error_msg) = check_item(&item) {
24        let error = quote! { compile_error!(#error_msg); };
25        return proc_macro::TokenStream::from(error);
26    }
27
28    let name = item.ident;
29
30    let generics = item.generics;
31    let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
32
33    let view_asserts = view_asserts(&item.data);
34
35    let expanded = quote! {
36        unsafe impl #impl_generics structview::View for #name #ty_generics #where_clause {}
37
38        impl #impl_generics #name #ty_generics #where_clause {
39            fn assert_all_fields_are_view(&self) {
40                fn assert_view<T: structview::View>(_: &T) {}
41                #view_asserts
42            }
43        }
44    };
45
46    proc_macro::TokenStream::from(expanded)
47}
48
49/// Check if `item` can safely derive `structview::View`.
50///
51///   - Enums cannot derive `structview::View`.
52///   - Deriving types must be repr(C).
53fn check_item(item: &DeriveInput) -> Option<&'static str> {
54    if let Data::Enum(_) = item.data {
55        return Some("enums cannot derive `structview::View`");
56    }
57
58    if is_repr_c(item) {
59        None
60    } else {
61        Some("types that derive `structview::View` must be repr(C)")
62    }
63}
64
65/// Check if `item` has a #[repr(C)] attribute.
66fn is_repr_c(item: &DeriveInput) -> bool {
67    item.attrs
68        .iter()
69        .filter(|a| a.style == AttrStyle::Outer)
70        .any(|a| match a.parse_meta() {
71            Ok(meta) => meta == parse_quote!(repr(C)),
72            Err(_) => false,
73        })
74}
75
76/// Generate assert calls to ensure `structview::View` is implemented
77/// every field.
78fn view_asserts(data: &Data) -> TokenStream {
79    match data {
80        Data::Struct(data) => match data.fields {
81            Fields::Named(ref fields) => named_fields_asserts(fields),
82            Fields::Unnamed(ref fields) => unnamed_fields_asserts(fields),
83            Fields::Unit => TokenStream::new(),
84        },
85        Data::Union(data) => {
86            let asserts = named_fields_asserts(&data.fields);
87            quote! { unsafe { #asserts } }
88        }
89        Data::Enum(_) => unreachable!(),
90    }
91}
92
93fn named_fields_asserts(fields: &FieldsNamed) -> TokenStream {
94    let asserts = fields.named.iter().map(|f| {
95        let name = &f.ident;
96        quote_spanned! { f.span() => assert_view(&self.#name); }
97    });
98    quote! { #(#asserts)* }
99}
100
101fn unnamed_fields_asserts(fields: &FieldsUnnamed) -> TokenStream {
102    let asserts = fields.unnamed.iter().enumerate().map(|(i, f)| {
103        let index = Index {
104            index: i as u32,
105            span: f.span(),
106        };
107        quote_spanned! { f.span() => assert_view(&self.#index); }
108    });
109    quote! { #(#asserts)* }
110}