Skip to main content

vtable_rs_proc_macros/
lib.rs

1use proc_macro::TokenStream;
2use proc_macro2::Span;
3use quote::{quote, ToTokens};
4use syn::{
5    parse_quote, punctuated::Punctuated, Abi, BareFnArg, FnArg, Ident, ItemTrait, Lifetime, LitStr,
6    Pat, Token, TraitItem, TraitItemFn, Type, TypeBareFn, TypeParamBound, TypeReference,
7};
8
9fn check_restrictions(trait_def: &ItemTrait) {
10    // First, make sure we support the trait
11    if trait_def.generics.lifetimes().any(|_| true) {
12        panic!("vtable trait cannot be given a lifetime")
13    }
14    if trait_def.auto_token.is_some() {
15        panic!("vtable trait cannot be auto")
16    }
17    if trait_def.unsafety.is_some() {
18        panic!("vtable trait cannot be unsafe")
19    }
20    if trait_def.supertraits.len() > 1 {
21        panic!("vtable trait can only have a single supertrait")
22    }
23    if trait_def.items.is_empty() {
24        panic!("vtable trait must contain at least one function")
25    }
26}
27
28fn extract_base_trait(trait_def: &ItemTrait) -> Vec<proc_macro2::TokenStream> {
29    match trait_def.supertraits.first() {
30        None => None,
31        Some(TypeParamBound::Trait(t)) => Some(t.to_token_stream()),
32        Some(_) => panic!(
33            "vtable trait's bounds must be a single trait representing the base class's vtable."
34        ),
35    }
36    .into_iter()
37    .collect()
38}
39
40fn set_method_abis(trait_def: &mut ItemTrait, abi: &str) {
41    for item in trait_def.items.iter_mut() {
42        if let TraitItem::Fn(fun) = item {
43            // Add "extern C" ABI to the function if not present
44            fun.sig.abi.get_or_insert(Abi {
45                extern_token: Token![extern](Span::call_site()),
46                name: Some(LitStr::new(abi, Span::call_site())),
47            });
48        }
49        else {
50            panic!("vtable trait can only contain functions")
51        }
52    }
53}
54
55fn trait_fn_to_bare_fn(fun: &TraitItemFn) -> TypeBareFn {
56    let lifetimes = fun
57        .sig
58        .generics
59        .lifetimes()
60        .map(|lt| syn::GenericParam::Lifetime(lt.to_owned()));
61
62    TypeBareFn {
63        lifetimes: parse_quote! { for <#(#lifetimes),*> },
64        unsafety: fun.sig.unsafety,
65        abi: fun.sig.abi.clone(),
66        fn_token: Token![fn](Span::call_site()),
67        paren_token: fun.sig.paren_token,
68        inputs: {
69            let mut inputs = Punctuated::new();
70            let mut has_ref_receiver = false;
71            for input in fun.sig.inputs.iter() {
72                inputs.push(match input {
73                    FnArg::Receiver(r) => {
74                        has_ref_receiver = r.reference.is_some();
75                        BareFnArg {
76                            attrs: r.attrs.clone(),
77                            name: Some((
78                                Ident::new("this", Span::call_site()),
79                                Token![:](Span::call_site()),
80                            )),
81                            ty: Type::Reference(TypeReference {
82                                and_token: Token![&](Span::call_site()),
83                                lifetime: r.lifetime().cloned(),
84                                mutability: r.mutability,
85                                elem: Box::new(parse_quote!(__VtableT)),
86                            }),
87                        }
88                    }
89                    FnArg::Typed(arg) => BareFnArg {
90                        attrs: arg.attrs.clone(),
91                        name: match arg.pat.as_ref() {
92                            Pat::Ident(ident) => {
93                                Some((ident.ident.clone(), Token![:](Span::call_site())))
94                            }
95                            _ => None,
96                        },
97                        ty: *arg.ty.to_owned(),
98                    },
99                });
100            }
101            if !has_ref_receiver {
102                panic!(
103                    "vtable trait method \"{0}\" must have &self or &mut self parameter",
104                    fun.sig.ident
105                )
106            }
107            inputs
108        },
109        variadic: None,
110        output: fun.sig.output.clone(),
111    }
112}
113
114/// Attribute proc macro that can be used to turn a dyn-compatible trait definition
115/// into a C++ compatible vtable definition.
116///
117/// For example, say we have a C++ abstract class of the form
118/// ```cpp
119/// struct Obj {
120///     uint32_t field;
121///
122///     virtual ~Obj() = default;
123///     virtual uint32_t method(uint32_t arg) const noexcept = 0;
124/// };
125/// ```
126///
127/// This macro then allows us to represent `Obj`'s virtual function table in Rust
128/// and provide our own implementations:
129///
130/// ```rs
131/// use vtable_rs::{vtable, VPtr};
132///
133/// #[vtable]
134/// pub trait ObjVmt {
135///     fn destructor(&mut self) {
136///         // We can provide a default implementation too!
137///     }
138///     fn method(&self, arg: u32) -> u32;
139/// }
140///
141/// // VPtr implements Default for types that implement the trait, and provides
142/// // a compile-time generated vtable!
143/// #[derive(Default)]
144/// #[repr(C)]
145/// struct RustObj {
146///     vftable: VPtr<dyn ObjVmt, Self>,
147///     field: u32
148/// }
149///
150/// impl ObjVmt for RustObj {
151///     extern "C" fn method(&self, arg: u32) -> u32 {
152///         self.field + arg
153///     }
154/// }
155///
156/// ```
157///
158/// `RustObj` could then be passed to a C++ function that takes in a pointer to `Obj`.
159///
160/// The macro supports single inheritance through a single trait bound, e.g.
161///
162/// ```rs
163/// #[vtable]
164/// pub trait DerivedObjVmt: ObjVmt {
165///     unsafe fn additional_method(&mut self, s: *const c_char);
166/// }
167/// ```
168///
169/// The vtable layout is fully typed and can be accessed as `<dyn TraitName as VmtLayout>::Layout<T>`.
170/// A `VPtr` can be `Deref`'d into it to obtain the bare function pointers and thus call through
171/// the vtable directly:
172///
173/// ```rs
174/// let obj = RustObj::default();
175/// let method_impl = obj.vftable.method; // extern "C" fn(&RustObj, u32) -> u32
176/// let call_result = method_impl(obj, 42);
177/// ```
178#[proc_macro_attribute]
179pub fn vtable(_attr: TokenStream, item: TokenStream) -> TokenStream {
180    let mut trait_def: ItemTrait = syn::parse(item).unwrap();
181
182    check_restrictions(&trait_def);
183
184    let base_trait = extract_base_trait(&trait_def);
185
186    // Add 'static lifetime bound to the trait
187    trait_def.supertraits.push(TypeParamBound::Lifetime(Lifetime::new(
188        "'static",
189        Span::call_site(),
190    )));
191
192    // Add 'static lifetime bounds to type parameters
193    trait_def
194        .generics
195        .type_params_mut()
196        .for_each(|tp| tp.bounds.push(parse_quote!('static)));
197
198    // TODO: generate a #[cfg] to switch to fastcall for x86 windows support
199    set_method_abis(&mut trait_def, "C");
200
201    let generics = &trait_def.generics;
202    let (o_impl_generics, o_ty_generics, o_where_clause) = generics.split_for_impl();
203
204    let layout_ident = Ident::new(&(trait_def.ident.to_string() + "Layout"), Span::call_site());
205    let signatures: Vec<_> = trait_def
206        .items
207        .iter()
208        .filter_map(|item| {
209            if let TraitItem::Fn(fun) = item {
210                Some(&fun.sig)
211            }
212            else {
213                None
214            }
215        })
216        .collect();
217
218    let trait_ident = &trait_def.ident;
219    let trait_vis = &trait_def.vis;
220    let fn_idents: Vec<_> = signatures.iter().map(|sig| &sig.ident).collect();
221    let bare_fns = trait_def.items.iter().filter_map(|item| match item {
222        TraitItem::Fn(fun) => Some(trait_fn_to_bare_fn(fun)),
223        _ => None,
224    });
225
226    // Generics for layout type
227    let mut concrete_generics = generics.clone();
228    concrete_generics.params.insert(0, parse_quote!(__VtableT: 'static));
229    let (impl_generics, ty_generics, where_clause) = concrete_generics.split_for_impl();
230
231    // Generics for VmtInstance impl
232    let mut i_generics = generics.clone();
233    i_generics.params.push(parse_quote!(__VtableT: #trait_ident #o_ty_generics));
234    let (i_impl_generics, _, _) = i_generics.split_for_impl();
235
236    let base_deref_impl = match base_trait.first() {
237        None => proc_macro2::TokenStream::new(),
238        Some(base) => quote! {
239            impl #impl_generics ::core::ops::Deref for #layout_ident #ty_generics #where_clause {
240                type Target = <dyn #base as ::vtable_rs::VmtLayout>::Layout<__VtableT>;
241
242                fn deref(&self) -> &Self::Target {
243                    &self._base
244                }
245            }
246            impl #impl_generics ::core::ops::DerefMut for #layout_ident #ty_generics #where_clause {
247                fn deref_mut(&mut self) -> &mut Self::Target {
248                    &mut self._base
249                }
250            }
251        },
252    };
253
254    let doc_lit = format!("Virtual method table layout for [`{}`].", trait_ident);
255    let mut tokens = trait_def.to_token_stream();
256    tokens.extend(quote! {
257        #[doc = #doc_lit]
258        #[repr(C)]
259        #trait_vis struct #layout_ident #impl_generics #where_clause {
260            #(_base: <dyn #base_trait as ::vtable_rs::VmtLayout>::Layout<__VtableT>,)*
261            #(pub #fn_idents: #bare_fns,)*
262        }
263
264        impl #impl_generics ::core::clone::Clone for #layout_ident #ty_generics #where_clause {
265            fn clone(&self) -> Self {
266                *self
267            }
268        }
269        impl #impl_generics ::core::marker::Copy for #layout_ident #ty_generics #where_clause {}
270
271        #base_deref_impl
272
273        unsafe impl #o_impl_generics ::vtable_rs::VmtLayout for dyn #trait_ident #o_ty_generics #o_where_clause {
274            type Layout<__VtableT: 'static> = #layout_ident #ty_generics;
275        }
276
277        impl #i_impl_generics ::vtable_rs::VmtInstance<__VtableT> for dyn #trait_ident #o_ty_generics #o_where_clause {
278            const VTABLE: &'static Self::Layout<__VtableT> = &#layout_ident {
279                #(_base: *<dyn #base_trait as ::vtable_rs::VmtInstance<__VtableT>>::VTABLE,)*
280                #(#fn_idents: <__VtableT as #trait_ident #o_ty_generics>::#fn_idents),*
281            };
282        }
283    });
284
285    tokens.into()
286}