viable_impl/
lib.rs

1use proc_macro::TokenStream;
2
3use quote::{quote, format_ident, ToTokens};
4use syn::{parse_macro_input, parse_quote, ItemStruct, ItemImpl, ImplItem, ImplItemMethod, Fields, LitInt};
5
6#[proc_macro_attribute]
7/// Defines a struct that will act as a VTable to a C++ class.
8/// It can also take data as to make sure the class functions as expected.
9/// # Example
10/// ```cpp
11/// #define interface class __declspec(novtable)
12/// interface MathEngine {
13/// public:
14///         virtual int add(int x, int y) = 0;
15///         virtual int add2(int x, int y) = 0;
16/// };
17///
18/// class MyEngine: public MathEngine {
19/// public:
20///     int mynum;
21///
22///     MyEngine(int b) {
23///         mynum = b;
24///     }
25///
26///     virtual int add(int x, int y) {
27///         return x + y;
28///     }
29///
30///     virtual int add2(int x, int y) {
31///         return mynum + x + y;
32///     }
33/// };
34///
35/// extern "C" {
36///     MyEngine* getMath(int b) {
37///     return new MyEngine(b);
38/// }
39/// };
40/// ```
41/// ```rust, no_run
42/// use std::os::raw::c_int;
43/// use viable_impl::*;
44/// extern "C" {
45///     fn getMath(b: c_int) -> *mut MathIFace;
46/// }
47///
48/// use viable_impl::vtable;
49/// #[vtable]
50/// struct MathIFace {
51///     internal: i32,
52///
53///     #[offset(0)]
54///     add: fn(a: c_int, b: c_int) -> c_int,
55///     add2: fn(a: c_int, b: c_int) -> c_int,
56/// }
57///
58/// pub fn main() {
59///     let iface = unsafe { getMath(10) };
60///     let iface = unsafe { iface.as_mut().unwrap() };
61///
62///     let value = iface.add2(5, 5);
63///
64///    assert_eq!(value, 10 + 5 + 5);
65/// }
66/// ```
67pub fn vtable(_attr: TokenStream, item: TokenStream) -> TokenStream {
68	let mut ast = parse_macro_input!(item as ItemStruct);
69
70	let ident = &ast.ident;
71
72	let mut interface: ItemImpl = parse_quote! {
73		impl #ident {}
74	};
75
76	let mut fields = vec![];
77	let mut count = 0usize;
78	for f in &ast.fields {
79		if let Some(id) = &f.ident {
80			match f.ty {
81				syn::Type::BareFn(ref x) => {
82					// Look for custom attributes
83					let mut covered = vec![];
84					for attr in &f.attrs {
85						let s = attr.path.to_token_stream().to_string();
86
87						let str = s.as_str();
88						if covered.contains(&s) {
89							panic!("Repeated attribute: {}", s);
90						}
91
92						match str {
93							"offset" | "check" => {
94								let offset: LitInt = attr
95									.parse_args()
96									.expect("Expected integer for offset");
97								let num = offset.base10_parse::<usize>()
98									.expect("Offset must be usize");
99
100								match str {
101									"offset" => {
102										count = num;
103									},
104									"check" => {
105										if count != num {
106											panic!("Check failed, expected offset to be {}, but was {}", num, count);
107										}
108									},
109									_ => unreachable!()
110								}
111								covered.push(s);
112							},
113							"skip" => {
114								let by: LitInt = attr
115									.parse_args()
116									.expect("Expected integer for skip");
117								let num = by.base10_parse::<isize>()
118									.expect("Skip must be isize");
119
120								// There's surely a more elegant way to do this.
121								if num < 0 {
122									let num = num.abs() as usize;
123									if num > count {
124										panic!("Skip would move offset below 0");
125									}
126									count -= num;
127								} else {
128									count += num as usize;
129								}
130
131								covered.push(s);
132							}
133							_ => ()
134						}
135					}
136
137					let (name, ty) = (id.to_string(), x.clone());
138					let ret = &ty.output;
139
140					// Full signature with the pointer to the original class.
141					let mut ty_full = ty.clone();
142					ty_full.inputs.insert(0,  parse_quote! {_self: *mut Self} );
143
144					let inputs = &ty.inputs;
145
146					let name = format_ident!("{name}");
147
148					let mut call: syn::ExprCall = parse_quote! {
149						func(self)
150					};
151
152					for (pnum, i) in inputs.iter().enumerate() {
153						let b = i.name.as_ref().map(|(i, _)| i.to_string()).unwrap_or(format!("argn{pnum}"));
154						let b = format_ident!("{}", b);
155
156						call.args.push_punct( parse_quote! { , } );
157						call.args.push_value( parse_quote! { #b } );
158					}
159
160					let mut item: ImplItemMethod = parse_quote! {
161						// #[offset(#count)]
162						fn #name(&mut self, #inputs) #ret {
163							let vtable = self.vtable as *const #ty_full;
164							let func = unsafe { vtable.add(#count).read() };
165						}
166					};
167
168					item.vis = f.vis.clone();
169
170					item.block.stmts.push( syn::Stmt::Expr( syn::Expr::Call(call) ) );
171					interface.items.push( ImplItem::Method(item) );
172					count += 1;
173				},
174				_ => {
175					fields.push(f);
176					//panic!("VTable fields must be bare functions!")
177				},
178			}
179		}
180	}
181
182
183	let mut struc: ItemStruct = parse_quote! {
184		#[repr(C)]
185		struct #ident {
186			pub vtable: *mut *mut usize,
187		}
188	};
189
190	struc.vis = ast.vis;
191	struc.attrs.append(&mut ast.attrs);
192
193	// Add data fields (non bare functions)
194	for f in fields {
195		if let Fields::Named(ref mut x) = struc.fields {
196			x.named.push( f.to_owned() );
197		}
198	}
199
200	quote! {
201		#struc
202		#interface
203	}.into()
204}