vtable_rs_proc_macros/
lib.rs1use proc_macro::TokenStream;
2use proc_macro2::Span;
3use quote::{quote, ToTokens};
4use syn::{
5 parse_quote, punctuated::Punctuated, Abi, BareFnArg, FnArg, Ident, ItemTrait, Lifetime, LitStr,
6 Pat, Token, TraitItem, TraitItemFn, Type, TypeBareFn, TypeParamBound, TypeReference,
7};
8
9fn check_restrictions(trait_def: &ItemTrait) {
10 if trait_def.generics.lifetimes().any(|_| true) {
12 panic!("vtable trait cannot be given a lifetime")
13 }
14 if trait_def.auto_token.is_some() {
15 panic!("vtable trait cannot be auto")
16 }
17 if trait_def.unsafety.is_some() {
18 panic!("vtable trait cannot be unsafe")
19 }
20 if trait_def.supertraits.len() > 1 {
21 panic!("vtable trait can only have a single supertrait")
22 }
23 if trait_def.items.is_empty() {
24 panic!("vtable trait must contain at least one function")
25 }
26}
27
28fn extract_base_trait(trait_def: &ItemTrait) -> Vec<proc_macro2::TokenStream> {
29 match trait_def.supertraits.first() {
30 None => None,
31 Some(TypeParamBound::Trait(t)) => Some(t.to_token_stream()),
32 Some(_) => panic!(
33 "vtable trait's bounds must be a single trait representing the base class's vtable."
34 ),
35 }
36 .into_iter()
37 .collect()
38}
39
40fn set_method_abis(trait_def: &mut ItemTrait, abi: &str) {
41 for item in trait_def.items.iter_mut() {
42 if let TraitItem::Fn(fun) = item {
43 fun.sig.abi.get_or_insert(Abi {
45 extern_token: Token),
46 name: Some(LitStr::new(abi, Span::call_site())),
47 });
48 }
49 else {
50 panic!("vtable trait can only contain functions")
51 }
52 }
53}
54
55fn trait_fn_to_bare_fn(fun: &TraitItemFn) -> TypeBareFn {
56 let lifetimes = fun
57 .sig
58 .generics
59 .lifetimes()
60 .map(|lt| syn::GenericParam::Lifetime(lt.to_owned()));
61
62 TypeBareFn {
63 lifetimes: parse_quote! { for <#(#lifetimes),*> },
64 unsafety: fun.sig.unsafety,
65 abi: fun.sig.abi.clone(),
66 fn_token: Token),
67 paren_token: fun.sig.paren_token,
68 inputs: {
69 let mut inputs = Punctuated::new();
70 let mut has_ref_receiver = false;
71 for input in fun.sig.inputs.iter() {
72 inputs.push(match input {
73 FnArg::Receiver(r) => {
74 has_ref_receiver = r.reference.is_some();
75 BareFnArg {
76 attrs: r.attrs.clone(),
77 name: Some((
78 Ident::new("this", Span::call_site()),
79 Token),
80 )),
81 ty: Type::Reference(TypeReference {
82 and_token: Token),
83 lifetime: r.lifetime().cloned(),
84 mutability: r.mutability,
85 elem: Box::new(parse_quote!(__VtableT)),
86 }),
87 }
88 }
89 FnArg::Typed(arg) => BareFnArg {
90 attrs: arg.attrs.clone(),
91 name: match arg.pat.as_ref() {
92 Pat::Ident(ident) => {
93 Some((ident.ident.clone(), Token)))
94 }
95 _ => None,
96 },
97 ty: *arg.ty.to_owned(),
98 },
99 });
100 }
101 if !has_ref_receiver {
102 panic!(
103 "vtable trait method \"{0}\" must have &self or &mut self parameter",
104 fun.sig.ident
105 )
106 }
107 inputs
108 },
109 variadic: None,
110 output: fun.sig.output.clone(),
111 }
112}
113
114#[proc_macro_attribute]
179pub fn vtable(_attr: TokenStream, item: TokenStream) -> TokenStream {
180 let mut trait_def: ItemTrait = syn::parse(item).unwrap();
181
182 check_restrictions(&trait_def);
183
184 let base_trait = extract_base_trait(&trait_def);
185
186 trait_def.supertraits.push(TypeParamBound::Lifetime(Lifetime::new(
188 "'static",
189 Span::call_site(),
190 )));
191
192 trait_def
194 .generics
195 .type_params_mut()
196 .for_each(|tp| tp.bounds.push(parse_quote!('static)));
197
198 set_method_abis(&mut trait_def, "C");
200
201 let generics = &trait_def.generics;
202 let (o_impl_generics, o_ty_generics, o_where_clause) = generics.split_for_impl();
203
204 let layout_ident = Ident::new(&(trait_def.ident.to_string() + "Layout"), Span::call_site());
205 let signatures: Vec<_> = trait_def
206 .items
207 .iter()
208 .filter_map(|item| {
209 if let TraitItem::Fn(fun) = item {
210 Some(&fun.sig)
211 }
212 else {
213 None
214 }
215 })
216 .collect();
217
218 let trait_ident = &trait_def.ident;
219 let trait_vis = &trait_def.vis;
220 let fn_idents: Vec<_> = signatures.iter().map(|sig| &sig.ident).collect();
221 let bare_fns = trait_def.items.iter().filter_map(|item| match item {
222 TraitItem::Fn(fun) => Some(trait_fn_to_bare_fn(fun)),
223 _ => None,
224 });
225
226 let mut concrete_generics = generics.clone();
228 concrete_generics.params.insert(0, parse_quote!(__VtableT: 'static));
229 let (impl_generics, ty_generics, where_clause) = concrete_generics.split_for_impl();
230
231 let mut i_generics = generics.clone();
233 i_generics.params.push(parse_quote!(__VtableT: #trait_ident #o_ty_generics));
234 let (i_impl_generics, _, _) = i_generics.split_for_impl();
235
236 let base_deref_impl = match base_trait.first() {
237 None => proc_macro2::TokenStream::new(),
238 Some(base) => quote! {
239 impl #impl_generics ::core::ops::Deref for #layout_ident #ty_generics #where_clause {
240 type Target = <dyn #base as ::vtable_rs::VmtLayout>::Layout<__VtableT>;
241
242 fn deref(&self) -> &Self::Target {
243 &self._base
244 }
245 }
246 impl #impl_generics ::core::ops::DerefMut for #layout_ident #ty_generics #where_clause {
247 fn deref_mut(&mut self) -> &mut Self::Target {
248 &mut self._base
249 }
250 }
251 },
252 };
253
254 let doc_lit = format!("Virtual method table layout for [`{}`].", trait_ident);
255 let mut tokens = trait_def.to_token_stream();
256 tokens.extend(quote! {
257 #[doc = #doc_lit]
258 #[repr(C)]
259 #trait_vis struct #layout_ident #impl_generics #where_clause {
260 #(_base: <dyn #base_trait as ::vtable_rs::VmtLayout>::Layout<__VtableT>,)*
261 #(pub #fn_idents: #bare_fns,)*
262 }
263
264 impl #impl_generics ::core::clone::Clone for #layout_ident #ty_generics #where_clause {
265 fn clone(&self) -> Self {
266 *self
267 }
268 }
269 impl #impl_generics ::core::marker::Copy for #layout_ident #ty_generics #where_clause {}
270
271 #base_deref_impl
272
273 unsafe impl #o_impl_generics ::vtable_rs::VmtLayout for dyn #trait_ident #o_ty_generics #o_where_clause {
274 type Layout<__VtableT: 'static> = #layout_ident #ty_generics;
275 }
276
277 impl #i_impl_generics ::vtable_rs::VmtInstance<__VtableT> for dyn #trait_ident #o_ty_generics #o_where_clause {
278 const VTABLE: &'static Self::Layout<__VtableT> = &#layout_ident {
279 #(_base: *<dyn #base_trait as ::vtable_rs::VmtInstance<__VtableT>>::VTABLE,)*
280 #(#fn_idents: <__VtableT as #trait_ident #o_ty_generics>::#fn_idents),*
281 };
282 }
283 });
284
285 tokens.into()
286}