vtable_rs_proc_macros/
lib.rs

1use proc_macro::TokenStream;
2use proc_macro2::Span;
3use quote::{quote, ToTokens};
4use syn::{
5    parse_quote, punctuated::Punctuated, Abi, BareFnArg, FnArg, Ident, ItemTrait, Lifetime, LitStr,
6    Pat, Signature, Token, TraitItem, TraitItemFn, Type, TypeBareFn, TypeParamBound, TypeReference,
7};
8
9fn check_restrictions(trait_def: &ItemTrait) {
10    // First, make sure we support the trait
11    if trait_def.generics.lt_token.is_some() {
12        panic!("vtable trait cannot be given a lifetime")
13    }
14    if !trait_def.generics.params.empty_or_trailing() {
15        panic!("vtable traits do not support generic parameters yet")
16    }
17    if trait_def.auto_token.is_some() {
18        panic!("vtable trait cannot be auto")
19    }
20    if trait_def.unsafety.is_some() {
21        panic!("vtable trait cannot be unsafe")
22    }
23    if trait_def.supertraits.len() > 1 {
24        panic!("vtable trait can only have a single supertrait")
25    }
26    if trait_def.items.is_empty() {
27        panic!("vtable trait must contain at least one function")
28    }
29}
30
31fn extract_base_trait(trait_def: &ItemTrait) -> Vec<proc_macro2::TokenStream> {
32    match trait_def.supertraits.first() {
33        None => None,
34        Some(TypeParamBound::Trait(t)) => Some(t.to_token_stream()),
35        Some(_) => panic!(
36            "vtable trait's bounds must be a single trait representing the base class's vtable."
37        ),
38    }
39    .into_iter()
40    .collect()
41}
42
43fn set_method_abis(trait_def: &mut ItemTrait, abi: &str) {
44    for item in trait_def.items.iter_mut() {
45        if let TraitItem::Fn(fun) = item {
46            // Add "extern C" ABI to the function if not present
47            fun.sig.abi.get_or_insert(Abi {
48                extern_token: Token![extern](Span::call_site()),
49                name: Some(LitStr::new(abi, Span::call_site())),
50            });
51        }
52        else {
53            panic!("vtable trait can only contain functions")
54        }
55    }
56}
57
58fn trait_fn_to_bare_fn(fun: &TraitItemFn) -> TypeBareFn {
59    let lifetimes = fun
60        .sig
61        .generics
62        .lifetimes()
63        .map(|lt| syn::GenericParam::Lifetime(lt.to_owned()));
64
65    TypeBareFn {
66        lifetimes: syn::parse2(quote! { for <#(#lifetimes),*> }).unwrap(),
67        unsafety: fun.sig.unsafety,
68        abi: fun.sig.abi.clone(),
69        fn_token: Token![fn](Span::call_site()),
70        paren_token: fun.sig.paren_token,
71        inputs: {
72            let mut inputs = Punctuated::new();
73            let mut has_ref_receiver = false;
74            for input in fun.sig.inputs.iter() {
75                inputs.push(match input {
76                    FnArg::Receiver(r) => {
77                        has_ref_receiver = r.reference.is_some();
78                        BareFnArg {
79                            attrs: r.attrs.clone(),
80                            name: Some((
81                                Ident::new("this", Span::call_site()),
82                                Token![:](Span::call_site()),
83                            )),
84                            ty: Type::Reference(TypeReference {
85                                and_token: Token![&](Span::call_site()),
86                                lifetime: r.lifetime().cloned(),
87                                mutability: r.mutability,
88                                elem: Box::new(parse_quote!(T)),
89                            }),
90                        }
91                    }
92                    FnArg::Typed(arg) => BareFnArg {
93                        attrs: arg.attrs.clone(),
94                        name: match arg.pat.as_ref() {
95                            Pat::Ident(ident) => {
96                                Some((ident.ident.clone(), Token![:](Span::call_site())))
97                            }
98                            _ => None,
99                        },
100                        ty: *arg.ty.to_owned(),
101                    },
102                });
103            }
104            if !has_ref_receiver {
105                panic!(
106                    "vtable trait method \"{0}\" must have &self or &mut self parameter",
107                    fun.sig.ident.to_string()
108                )
109            }
110            inputs
111        },
112        variadic: None,
113        output: fun.sig.output.clone(),
114    }
115}
116
117// TODO (WIP): Handle all lifetime edge cases before implementing
118fn sig_to_vtable_thunk(sig: &Signature) -> proc_macro2::TokenStream {
119    let (receiver_mut, receiver_lt) = match sig.inputs.first() {
120        Some(FnArg::Receiver(r)) => (r.mutability.clone(), r.lifetime().cloned()),
121        _ => unreachable!(),
122    };
123
124    let self_arg: FnArg = syn::parse2(quote! { &self }).unwrap();
125    let t_arg: FnArg = syn::parse2(quote! { this: & #receiver_lt #receiver_mut T }).unwrap();
126
127    let mut with_t = sig.clone();
128
129    *with_t.inputs.first_mut().unwrap() = self_arg;
130    with_t.inputs.insert(1, t_arg);
131    with_t.abi = None; // No need for an ABI on the thunk method
132
133    let ident = &sig.ident;
134    let arg_idents = with_t.inputs.iter().skip(1).map(|arg| match arg {
135        FnArg::Typed(pt) => match pt.pat.as_ref() {
136            Pat::Ident(ident_pat) => ident_pat.ident.clone(),
137            _ => unreachable!(),
138        },
139        _ => unreachable!(),
140    });
141
142    quote! {
143        #[inline]
144        pub #with_t {
145            (self.#ident)(#(#arg_idents),*)
146        }
147    }
148}
149
150/// Attribute proc macro that can be used to turn a dyn-compatible trait definition
151/// into a C++ compatible vtable definition.
152///
153/// For example, say we have a C++ abstract class of the form
154/// ```cpp
155/// struct Obj {
156///     uint32_t field;
157///
158///     virtual ~Obj() = default;
159///     virtual uint32_t method(uint32_t arg) const noexcept = 0;
160/// };
161/// ```
162///
163/// This macro then allows us to represent `Obj`'s virtual function table in Rust
164/// and provide our own implementations:
165///
166/// ```rs
167/// use vtable_rs::{vtable, VPtr};
168///
169/// #[vtable]
170/// pub trait ObjVmt {
171///     fn destructor(&mut self) {
172///         // We can provide a default implementation too!
173///     }
174///     fn method(&self, arg: u32) -> u32;
175/// }
176///
177/// // VPtr implements Default for types that implement the trait, and provides
178/// // a compile-time generated vtable!
179/// #[derive(Default)]
180/// #[repr(C)]
181/// struct RustObj {
182///     vftable: VPtr<dyn ObjVmt, Self>,
183///     field: u32
184/// }
185///
186/// impl ObjVmt for RustObj {
187///     extern "C" fn method(&self, arg: u32) -> u32 {
188///         self.field + arg
189///     }
190/// }
191///
192/// ```
193///
194/// `RustObj` could then be passed to a C++ function that takes in a pointer to `Obj`.
195///
196/// The macro supports single inhertiance through a single trait bound, e.g.
197///
198/// ```rs
199/// #[vtable]
200/// pub trait DerivedObjVmt: ObjVmt {
201///     unsafe fn additional_method(&mut self, s: *const c_char);
202/// }
203/// ```
204///
205/// The vtable layout is fully typed and can be accessed as `<dyn TraitName as VmtLayout>::Layout<T>`.
206/// A `VPtr` can be `Deref`'d into it to obtain the bare function pointers and thus call through
207/// the vtable directly:
208///
209/// ```rs
210/// let obj = RustObj::default();
211/// let method_impl = obj.vftable.method; // extern "C" fn(&RustObj, u32) -> u32
212/// let call_result = method_impl(obj, 42);
213/// ```
214#[proc_macro_attribute]
215pub fn vtable(_attr: TokenStream, item: TokenStream) -> TokenStream {
216    let mut trait_def: ItemTrait = syn::parse(item).unwrap();
217
218    check_restrictions(&trait_def);
219
220    let base_trait = extract_base_trait(&trait_def);
221
222    // Add 'static lifetime bound to the trait
223    trait_def.supertraits.push(TypeParamBound::Lifetime(Lifetime::new(
224        "'static",
225        Span::call_site(),
226    )));
227
228    // TODO: generate a #[cfg] to switch to fastcall for x86 windows support
229    set_method_abis(&mut trait_def, "C");
230
231    let layout_ident = Ident::new(&(trait_def.ident.to_string() + "Layout"), Span::call_site());
232    let signatures: Vec<_> = trait_def
233        .items
234        .iter()
235        .filter_map(|item| {
236            if let TraitItem::Fn(fun) = item {
237                Some(&fun.sig)
238            }
239            else {
240                None
241            }
242        })
243        .collect();
244
245    let trait_ident = &trait_def.ident;
246    let trait_vis = &trait_def.vis;
247    let fn_idents: Vec<_> = signatures.iter().map(|sig| &sig.ident).collect();
248    let bare_fns = trait_def.items.iter().filter_map(|item| match item {
249        TraitItem::Fn(fun) => Some(trait_fn_to_bare_fn(fun)),
250        _ => None,
251    });
252
253    // Create token stream with base layout declaration if a base trait is present
254    let base_decl = if base_trait.is_empty() {
255        proc_macro2::TokenStream::new()
256    }
257    else {
258        quote! { _base: self._base, }
259    };
260
261    let base_deref_impl = match base_trait.first() {
262        None => proc_macro2::TokenStream::new(),
263        Some(base) => quote! {
264            impl<T: 'static> ::core::ops::Deref for #layout_ident<T> {
265                type Target = <dyn #base as ::vtable_rs::VmtLayout>::Layout<T>;
266
267                fn deref(&self) -> &Self::Target {
268                    &self._base
269                }
270            }
271            impl<T: 'static> ::core::ops::DerefMut for #layout_ident<T> {
272                fn deref_mut(&mut self) -> &mut Self::Target {
273                    &mut self._base
274                }
275            }
276        },
277    };
278
279    // TODO: Figure out 100% reliable strategy to adjust lifetimes
280    // so that lifetime inference works as expected in the trait definition
281    //let thunk_impls = signatures.iter().map(|&s| sig_to_vtable_thunk(s));
282
283    let mut tokens = trait_def.to_token_stream();
284    tokens.extend(quote! {
285        #[repr(C)]
286        #trait_vis struct #layout_ident<T: 'static> {
287            #(_base: <dyn #base_trait as ::vtable_rs::VmtLayout>::Layout<T>,)*
288            #(pub #fn_idents: #bare_fns,)*
289        }
290
291        // impl<T: 'static> #layout_ident<T> {
292        //     #(#thunk_impls)*
293        // }
294
295        impl<T> ::core::clone::Clone for #layout_ident<T> {
296            fn clone(&self) -> Self {
297                Self {
298                    #base_decl
299                    #(#fn_idents: self.#fn_idents),*
300                }
301            }
302        }
303        impl<T> ::core::marker::Copy for #layout_ident<T> {}
304
305        #base_deref_impl
306
307        unsafe impl ::vtable_rs::VmtLayout for dyn #trait_ident {
308            type Layout<T: 'static> = #layout_ident<T>;
309        }
310
311        impl<T: #trait_ident> ::vtable_rs::VmtInstance<T> for dyn #trait_ident {
312            const VTABLE: &'static Self::Layout<T> = &#layout_ident {
313                #(_base: *<dyn #base_trait as ::vtable_rs::VmtInstance<T>>::VTABLE,)*
314                #(#fn_idents: <T as #trait_ident>::#fn_idents),*
315            };
316        }
317    });
318
319    tokens.into()
320}