vtable_rs_proc_macros/
lib.rs1use proc_macro::TokenStream;
2use proc_macro2::Span;
3use quote::{quote, ToTokens};
4use syn::{
5 parse_quote, punctuated::Punctuated, Abi, BareFnArg, FnArg, Ident, ItemTrait, Lifetime, LitStr,
6 Pat, Signature, Token, TraitItem, TraitItemFn, Type, TypeBareFn, TypeParamBound, TypeReference,
7};
8
9fn check_restrictions(trait_def: &ItemTrait) {
10 if trait_def.generics.lt_token.is_some() {
12 panic!("vtable trait cannot be given a lifetime")
13 }
14 if !trait_def.generics.params.empty_or_trailing() {
15 panic!("vtable traits do not support generic parameters yet")
16 }
17 if trait_def.auto_token.is_some() {
18 panic!("vtable trait cannot be auto")
19 }
20 if trait_def.unsafety.is_some() {
21 panic!("vtable trait cannot be unsafe")
22 }
23 if trait_def.supertraits.len() > 1 {
24 panic!("vtable trait can only have a single supertrait")
25 }
26 if trait_def.items.is_empty() {
27 panic!("vtable trait must contain at least one function")
28 }
29}
30
31fn extract_base_trait(trait_def: &ItemTrait) -> Vec<proc_macro2::TokenStream> {
32 match trait_def.supertraits.first() {
33 None => None,
34 Some(TypeParamBound::Trait(t)) => Some(t.to_token_stream()),
35 Some(_) => panic!(
36 "vtable trait's bounds must be a single trait representing the base class's vtable."
37 ),
38 }
39 .into_iter()
40 .collect()
41}
42
43fn set_method_abis(trait_def: &mut ItemTrait, abi: &str) {
44 for item in trait_def.items.iter_mut() {
45 if let TraitItem::Fn(fun) = item {
46 fun.sig.abi.get_or_insert(Abi {
48 extern_token: Token),
49 name: Some(LitStr::new(abi, Span::call_site())),
50 });
51 }
52 else {
53 panic!("vtable trait can only contain functions")
54 }
55 }
56}
57
58fn trait_fn_to_bare_fn(fun: &TraitItemFn) -> TypeBareFn {
59 let lifetimes = fun
60 .sig
61 .generics
62 .lifetimes()
63 .map(|lt| syn::GenericParam::Lifetime(lt.to_owned()));
64
65 TypeBareFn {
66 lifetimes: syn::parse2(quote! { for <#(#lifetimes),*> }).unwrap(),
67 unsafety: fun.sig.unsafety,
68 abi: fun.sig.abi.clone(),
69 fn_token: Token),
70 paren_token: fun.sig.paren_token,
71 inputs: {
72 let mut inputs = Punctuated::new();
73 let mut has_ref_receiver = false;
74 for input in fun.sig.inputs.iter() {
75 inputs.push(match input {
76 FnArg::Receiver(r) => {
77 has_ref_receiver = r.reference.is_some();
78 BareFnArg {
79 attrs: r.attrs.clone(),
80 name: Some((
81 Ident::new("this", Span::call_site()),
82 Token),
83 )),
84 ty: Type::Reference(TypeReference {
85 and_token: Token),
86 lifetime: r.lifetime().cloned(),
87 mutability: r.mutability,
88 elem: Box::new(parse_quote!(T)),
89 }),
90 }
91 }
92 FnArg::Typed(arg) => BareFnArg {
93 attrs: arg.attrs.clone(),
94 name: match arg.pat.as_ref() {
95 Pat::Ident(ident) => {
96 Some((ident.ident.clone(), Token)))
97 }
98 _ => None,
99 },
100 ty: *arg.ty.to_owned(),
101 },
102 });
103 }
104 if !has_ref_receiver {
105 panic!(
106 "vtable trait method \"{0}\" must have &self or &mut self parameter",
107 fun.sig.ident.to_string()
108 )
109 }
110 inputs
111 },
112 variadic: None,
113 output: fun.sig.output.clone(),
114 }
115}
116
117fn sig_to_vtable_thunk(sig: &Signature) -> proc_macro2::TokenStream {
119 let (receiver_mut, receiver_lt) = match sig.inputs.first() {
120 Some(FnArg::Receiver(r)) => (r.mutability.clone(), r.lifetime().cloned()),
121 _ => unreachable!(),
122 };
123
124 let self_arg: FnArg = syn::parse2(quote! { &self }).unwrap();
125 let t_arg: FnArg = syn::parse2(quote! { this: & #receiver_lt #receiver_mut T }).unwrap();
126
127 let mut with_t = sig.clone();
128
129 *with_t.inputs.first_mut().unwrap() = self_arg;
130 with_t.inputs.insert(1, t_arg);
131 with_t.abi = None; let ident = &sig.ident;
134 let arg_idents = with_t.inputs.iter().skip(1).map(|arg| match arg {
135 FnArg::Typed(pt) => match pt.pat.as_ref() {
136 Pat::Ident(ident_pat) => ident_pat.ident.clone(),
137 _ => unreachable!(),
138 },
139 _ => unreachable!(),
140 });
141
142 quote! {
143 #[inline]
144 pub #with_t {
145 (self.#ident)(#(#arg_idents),*)
146 }
147 }
148}
149
150#[proc_macro_attribute]
215pub fn vtable(_attr: TokenStream, item: TokenStream) -> TokenStream {
216 let mut trait_def: ItemTrait = syn::parse(item).unwrap();
217
218 check_restrictions(&trait_def);
219
220 let base_trait = extract_base_trait(&trait_def);
221
222 trait_def.supertraits.push(TypeParamBound::Lifetime(Lifetime::new(
224 "'static",
225 Span::call_site(),
226 )));
227
228 set_method_abis(&mut trait_def, "C");
230
231 let layout_ident = Ident::new(&(trait_def.ident.to_string() + "Layout"), Span::call_site());
232 let signatures: Vec<_> = trait_def
233 .items
234 .iter()
235 .filter_map(|item| {
236 if let TraitItem::Fn(fun) = item {
237 Some(&fun.sig)
238 }
239 else {
240 None
241 }
242 })
243 .collect();
244
245 let trait_ident = &trait_def.ident;
246 let trait_vis = &trait_def.vis;
247 let fn_idents: Vec<_> = signatures.iter().map(|sig| &sig.ident).collect();
248 let bare_fns = trait_def.items.iter().filter_map(|item| match item {
249 TraitItem::Fn(fun) => Some(trait_fn_to_bare_fn(fun)),
250 _ => None,
251 });
252
253 let base_decl = if base_trait.is_empty() {
255 proc_macro2::TokenStream::new()
256 }
257 else {
258 quote! { _base: self._base, }
259 };
260
261 let base_deref_impl = match base_trait.first() {
262 None => proc_macro2::TokenStream::new(),
263 Some(base) => quote! {
264 impl<T: 'static> ::core::ops::Deref for #layout_ident<T> {
265 type Target = <dyn #base as ::vtable_rs::VmtLayout>::Layout<T>;
266
267 fn deref(&self) -> &Self::Target {
268 &self._base
269 }
270 }
271 impl<T: 'static> ::core::ops::DerefMut for #layout_ident<T> {
272 fn deref_mut(&mut self) -> &mut Self::Target {
273 &mut self._base
274 }
275 }
276 },
277 };
278
279 let mut tokens = trait_def.to_token_stream();
284 tokens.extend(quote! {
285 #[repr(C)]
286 #trait_vis struct #layout_ident<T: 'static> {
287 #(_base: <dyn #base_trait as ::vtable_rs::VmtLayout>::Layout<T>,)*
288 #(pub #fn_idents: #bare_fns,)*
289 }
290
291 impl<T> ::core::clone::Clone for #layout_ident<T> {
296 fn clone(&self) -> Self {
297 Self {
298 #base_decl
299 #(#fn_idents: self.#fn_idents),*
300 }
301 }
302 }
303 impl<T> ::core::marker::Copy for #layout_ident<T> {}
304
305 #base_deref_impl
306
307 unsafe impl ::vtable_rs::VmtLayout for dyn #trait_ident {
308 type Layout<T: 'static> = #layout_ident<T>;
309 }
310
311 impl<T: #trait_ident> ::vtable_rs::VmtInstance<T> for dyn #trait_ident {
312 const VTABLE: &'static Self::Layout<T> = &#layout_ident {
313 #(_base: *<dyn #base_trait as ::vtable_rs::VmtInstance<T>>::VTABLE,)*
314 #(#fn_idents: <T as #trait_ident>::#fn_idents),*
315 };
316 }
317 });
318
319 tokens.into()
320}