1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94
extern crate proc_macro;

use proc_macro2::TokenStream;
use quote::{quote, quote_spanned};
use syn::spanned::Spanned;
use syn::{parse_macro_input, Attribute, Data, DeriveInput, Fields};

/// Derives `spirv_struct_layout::CheckSpirvStruct` for a struct with named fields.
///
/// The generated `check_spirv_layout(name, spirv)` function:
///
/// 1. converts the SPIR-V words into a `spirq::SpirvBinary` and reflects it,
/// 2. looks up the descriptor called `name` on the first reflected entry point,
/// 3. for every Rust field that also exists in that descriptor, asserts that the
///    field's size and offset match the SPIR-V side.
///
/// The generated function panics when the SPIR-V cannot be reflected, when it has
/// no entry points, when the symbol cannot be resolved, or when a size/offset
/// assertion fails.
///
/// The derive fails to compile when the type carries no `#[repr(...)]` attribute
/// or is not a struct with named fields.
#[proc_macro_derive(SpirvLayout)]
pub fn spirv_layout_derive(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
    let input = parse_macro_input!(input as DeriveInput);

    let name = &input.ident;
    let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();

    let repr_check = ensure_repr(&input.attrs);
    let body = build_function_body(&input.data);

    let expanded = quote! {
        impl #impl_generics spirv_struct_layout::CheckSpirvStruct for #name #ty_generics #where_clause {
            fn check_spirv_layout(name: &str, spirv: Vec<u32>) {
                #repr_check

                let spv: spirq::SpirvBinary = spirv.into();
                let entries = spv.reflect().expect("Failed to parse SPIR-V");
                // Only the first entry point is inspected; fail with a clear
                // message instead of an index-out-of-bounds panic when there is none.
                let entry_point = entries
                    .first()
                    .expect("SPIR-V module does not contain any entry points");
                let buffer_desc = entry_point
                    .resolve_desc(spirq::sym::Sym::new(name))
                    .expect(format!("Failed to find symbol with name \"{}\"", name).as_str());

                // Running byte offset on the Rust side, advanced by each field's size.
                let mut _rust_offset = 0;

                #body
            }
        }
    };

    proc_macro::TokenStream::from(expanded)
}

/// Returns an empty token stream when `attrs` contains a `repr` attribute, and a
/// `compile_error!` invocation otherwise.
///
/// Only the presence of `repr` is checked; its arguments are not inspected.
fn ensure_repr(attrs: &[Attribute]) -> TokenStream {
    let has_repr = attrs.iter().any(|attr| {
        attr.parse_meta()
            .map(|meta| meta.path().is_ident("repr"))
            .unwrap_or(false)
    });

    if has_repr {
        TokenStream::new()
    } else {
        quote! {
            compile_error!("structs exposed to SPIRV must have a declared repr");
        }
    }
}

/// Produces a `compile_error!` for inputs the derive cannot handle, so the user
/// gets a regular compiler diagnostic rather than a proc-macro panic.
fn unsupported_input() -> TokenStream {
    quote! {
        compile_error!("SpirvLayout can only be derived for structs with named fields");
    }
}

/// Builds the per-field checks spliced into the generated `check_spirv_layout`.
///
/// For each named field, the emitted code resolves the field name inside
/// `buffer_desc`; if the SPIR-V side has such a member, its size and offset are
/// asserted against the Rust field. Fields without a SPIR-V counterpart are not
/// checked but still advance the running Rust offset.
///
/// NOTE: the Rust offset is the sum of the sizes of the preceding fields, so any
/// padding the compiler inserts between fields is not accounted for.
///
/// Enums, unions, tuple structs and unit structs yield a `compile_error!`.
fn build_function_body(data: &Data) -> TokenStream {
    let fields = match data {
        Data::Struct(data) => match &data.fields {
            Fields::Named(fields) => &fields.named,
            _ => return unsupported_input(),
        },
        _ => return unsupported_input(),
    };

    let checks = fields.iter().map(|f| {
        let name = &f.ident;
        let ty = &f.ty;
        // Span the generated code at the field so assertion/type errors point there.
        quote_spanned! { f.span() =>
            {
                let symbol = stringify!(#name);
                let rust_size = std::mem::size_of::<#ty>();

                if let Some(desc) = buffer_desc.desc_ty.resolve(spirq::sym::Sym::new(&symbol)) {
                    let spirv_offset = desc.offset;
                    let spirv_size = desc.ty.nbyte().expect(format!("Rust struct field named \"{}\" does not have a basic data type (float, vec, mat, array, struct) as a SPIR-V counterpart", &symbol).as_str());

                    assert_eq!(
                        spirv_size, rust_size,
                        "field {} should be {} bytes, but was {} bytes",
                        symbol, spirv_size, rust_size
                    );
                    assert_eq!(
                        spirv_offset, _rust_offset,
                        "field {} should have an offset of {} bytes, but was {} bytes",
                        symbol, spirv_offset, _rust_offset
                    );
                }

                _rust_offset += rust_size;
            }
        }
    });

    quote! {
        #(#checks)*
    }
}