Skip to main content

spirv_struct_layout_derive/
lib.rs

1extern crate proc_macro;
2
3use proc_macro2::TokenStream;
4use quote::{quote, quote_spanned};
5use syn::spanned::Spanned;
6use syn::{parse_macro_input, Attribute, Data, DeriveInput, Fields};
7
8#[proc_macro_derive(SpirvLayout)]
9pub fn spirv_layout_derive(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
10    let input = parse_macro_input!(input as DeriveInput);
11
12    let name = input.ident;
13    let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
14
15    let repr_check = ensure_repr(&input.attrs);
16    let body = build_function_body(&input.data);
17
18    let expanded = quote! {
19        impl #impl_generics spirv_struct_layout::CheckSpirvStruct for #name
20            #ty_generics #where_clause {
21
22            fn check_spirv_layout(name: &str, spirv: Vec<u32>) {
23                #repr_check
24
25                let spv: spirq::SpirvBinary = spirv.into();
26                let entries = spv.reflect().expect("Failed to parse SPIR-V");
27
28                let buffer_desc = entries[0].resolve_desc(spirq::sym::Sym::new(name)).expect(format!("Failed to find symbol with name \"{}\"", name).as_str());
29
30                let mut _rust_offset = 0;
31
32                #body
33            }
34        }
35    };
36
37    proc_macro::TokenStream::from(expanded)
38}
39
40fn ensure_repr(attrs: &Vec<Attribute>) -> TokenStream {
41    for attr in attrs {
42        if let Ok(meta) = attr.parse_meta() {
43            if meta.path().is_ident("repr") {
44                return quote! {};
45            }
46        }
47    }
48
49    quote! { compile_error!("structs exposed to SPIRV must have a declared repr"); }
50}
51
52fn build_function_body(data: &Data) -> TokenStream {
53    match *data {
54        Data::Struct(ref data) => match data.fields {
55            Fields::Named(ref fields) => {
56                let inner = fields.named.iter().map(|f| {
57                    let name = &f.ident;
58                    let ty = &f.ty;
59                    quote_spanned! {
60                        f.span() => {
61                            {
62                                let symbol = stringify!(#name);
63                                let rust_size = std::mem::size_of::<#ty>();
64
65                                if let Some(desc) = buffer_desc.desc_ty.resolve(spirq::sym::Sym::new(&symbol)) {
66                                    let spirv_offset = desc.offset;
67                                    let spirv_size = desc.ty.nbyte().expect(format!("Rust struct field named \"{}\" does not have a basic data type (float, vec, mat, array, struct) as a SPIR-V counterpart", &symbol).as_str());
68
69                                    assert_eq!(
70                                        spirv_size, rust_size,
71                                        "field {} should be {} bytes, but was {} bytes",
72                                        symbol, spirv_size, rust_size
73                                    );
74                                    assert_eq!(
75                                        spirv_offset, _rust_offset,
76                                        "field {} should have an offset of {} bytes, but was {} bytes",
77                                        symbol, spirv_offset, _rust_offset
78                                    );
79                                }
80
81                                _rust_offset += rust_size;
82                            }
83                        }
84                    }
85                });
86                quote! {
87                    #(#inner)*
88                }
89            }
90            _ => unimplemented!(),
91        },
92        _ => unimplemented!(),
93    }
94}