#![recursion_limit = "128"]
extern crate proc_macro;
use proc_macro::TokenStream;
use quote::quote;
use syn::parse::Parser;
use syn::{self, spanned::Spanned, AttributeArgs, ItemStruct};
/// Attribute macro applied to a struct as `#[vptr(Trait, "path::to::Trait", ...)]`.
///
/// Parses the attribute arguments and the annotated struct, then delegates the expansion to
/// `vptr_impl`. Input that fails to parse, as well as any `syn::Error` returned by the
/// expansion, is emitted as `compile_error!` tokens rather than panicking. This lets rustc
/// report the problem at the offending span.
#[proc_macro_attribute]
pub fn vptr(attr: TokenStream, item: TokenStream) -> TokenStream {
    let args = syn::parse_macro_input!(attr as AttributeArgs);
    let input = syn::parse_macro_input!(item as ItemStruct);
    vptr_impl(args, input).unwrap_or_else(|err| err.to_compile_error().into())
}
fn vptr_impl(attr: AttributeArgs, item: ItemStruct) -> Result<TokenStream, syn::Error> {
let ItemStruct {
attrs,
vis,
struct_token,
ident,
generics,
fields,
semi_token,
} = item;
let attr = attr
.iter()
.map(|a| {
if let syn::NestedMeta::Meta(syn::Meta::Path(i)) = a {
Ok(i.clone())
} else if let syn::NestedMeta::Lit(syn::Lit::Str(lit_str)) = a {
lit_str.parse::<syn::Path>()
} else {
Err(syn::Error::new(
a.span(),
"attribute of vptr must be a trait",
))
}
})
.collect::<Result<Vec<_>, _>>()?;
if let Some(tp) = generics.type_params().next() {
return Err(syn::Error::new(tp.span(), "vptr does not support generics"));
}
let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
let (fields, attr_with_names) = if let syn::Fields::Named(mut n) = fields {
let attr_with_names: Vec<_> = attr
.iter()
.map(|t| {
let field_name = quote::format_ident!("vptr_{}", t.segments.last().unwrap().ident);
(t, quote! { #field_name })
})
.collect();
let parser = syn::Field::parse_named;
for (trait_, field_name) in &attr_with_names {
n.named
.push(parser.parse(
quote!(#field_name : vptr::VPtr<#ident #ty_generics, dyn #trait_>).into(),
)?);
}
(syn::Fields::Named(n), attr_with_names)
} else {
let mut n = if let syn::Fields::Unnamed(n) = fields {
n
} else {
syn::FieldsUnnamed {
paren_token: Default::default(),
unnamed: Default::default(),
}
};
let count = n.unnamed.len();
let parser = syn::Field::parse_unnamed;
for trait_ in &attr {
n.unnamed
.push(parser.parse(quote!(vptr::VPtr<#ident #ty_generics, dyn #trait_>).into())?);
}
let attr_with_names: Vec<_> = attr
.iter()
.enumerate()
.map(|(i, t)| {
let field_name = syn::Index::from(i + count);
(t, quote! { #field_name })
})
.collect();
(syn::Fields::Unnamed(n), attr_with_names)
};
let mut result = quote!(
#(#attrs)* #[allow(non_snake_case)] #vis #struct_token #ident #generics #fields #semi_token
);
for (trait_, field_name) in attr_with_names {
result = quote!(#result
unsafe impl #impl_generics vptr::HasVPtr<dyn #trait_> for #ident #ty_generics #where_clause {
fn init() -> &'static vptr::VTableData {
use vptr::internal::{TransmuterTO, TransmuterPtr};
static VTABLE : vptr::VTableData = vptr::VTableData{
offset: ::core::mem::offset_of!(#ident, #field_name) as isize,
vtable: unsafe {
let x: &'static #ident = TransmuterPtr::<#ident> { int: 0 }.ptr;
TransmuterTO::<dyn #trait_>{ ptr: x }.to.vtable
}
};
&VTABLE
}
fn get_vptr(&self) -> &vptr::VPtr<Self, dyn #trait_> { &self.#field_name }
fn get_vptr_mut(&mut self) -> &mut vptr::VPtr<Self, dyn #trait_> { &mut self.#field_name }
}
);
}
Ok(result.into())
}