scotch_guest_macros/
lib.rs1use proc_macro::TokenStream;
2use quote::{format_ident, quote};
3use syn::{
4 parse_macro_input, parse_quote, Expr, FnArg, ForeignItem, ItemFn, ItemForeignMod, Pat,
5 ReturnType, Signature, Stmt, Type, TypeReference,
6};
7
/// Returns `true` when `ty` is the name of a primitive scalar that `translate_type`
/// passes across the guest/host boundary unchanged.
///
/// Only a bare identifier is compared; `translate_type` passes the last path segment
/// (so `std::primitive::u32` is checked as `u32`).
fn is_atom_type(ty: &str) -> bool {
    matches!(
        ty,
        "bool" | "char" | "u8" | "u16" | "u32" | "u64" | "i8" | "i16" | "i32" | "i64"
    )
}
15
16#[derive(Clone, Copy)]
17enum WrapMode {
18 Encoded,
19 Managed,
20}
21
22impl WrapMode {
23 fn wrap(self, ty: Type) -> Type {
24 match self {
25 WrapMode::Encoded => parse_quote!(scotch_guest::EncodedPtr<#ty>),
26 WrapMode::Managed => parse_quote!(scotch_guest::MemoryType),
27 }
28 }
29}
30
31enum TypeTranslation {
32 Original,
33 Wrapped(Type),
34}
35
36fn translate_type(ty: Type, mode: WrapMode, allow_owned: bool) -> TypeTranslation {
37 match ty {
38 Type::Path(ref path)
39 if is_atom_type(&path.path.segments.last().unwrap().ident.to_string()) =>
40 {
41 TypeTranslation::Original
42 }
43 Type::Reference(TypeReference {
44 lifetime: None,
45 mutability: None,
46 elem,
47 ..
48 }) => TypeTranslation::Wrapped(mode.wrap(*elem)),
49 Type::Array(_) | Type::Tuple(_) => TypeTranslation::Wrapped(mode.wrap(ty)),
50 Type::Path(_) if allow_owned => TypeTranslation::Wrapped(mode.wrap(ty)),
51 _ => unimplemented!("Type is unsupported, consider using a reference instead."),
52 }
53}
54
55#[derive(Default)]
56struct HostInputTranslation {
57 call_args: Vec<Expr>,
58 prelude: Vec<Stmt>,
59 epilogue: Vec<Stmt>,
60}
61
62fn translate_host_inputs<'a>(it: impl Iterator<Item = &'a mut FnArg>) -> HostInputTranslation {
63 let mut out = HostInputTranslation::default();
64
65 it.map(|arg| {
66 if let FnArg::Typed(arg) = arg {
67 arg
68 } else {
69 panic!("self is not allowed in host functions")
70 }
71 })
72 .map(|arg| {
73 if let Pat::Ident(name) = arg.pat.as_mut() {
74 (name.ident.clone(), &mut arg.ty)
75 } else {
76 panic!("Invalid function argument name")
77 }
78 })
79 .for_each(|(name, ty)| {
80 if let TypeTranslation::Wrapped(new) =
81 translate_type(ty.as_ref().clone(), WrapMode::Managed, false)
82 {
83 *ty = Box::new(new);
84 out.prelude
85 .push(parse_quote!(let #name = scotch_guest::ManagedPtr::new(#name).unwrap();));
86 out.epilogue.push(parse_quote!(#name.free();));
87 out.call_args.push(parse_quote!(#name.offset()));
88 } else {
89 out.call_args.push(parse_quote!(#name));
90 }
91 });
92
93 out
94}
95
96fn translate_host_output(ret: &mut ReturnType) -> Stmt {
97 let mut out = parse_quote!(return out;);
98
99 if let ReturnType::Type(_, ty) = ret {
100 if let TypeTranslation::Wrapped(new) =
101 translate_type(ty.as_ref().clone(), WrapMode::Managed, true)
102 {
103 *ty = Box::new(new);
104 out = parse_quote! {return {
105 let ptr = scotch_guest::ManagedPtr::with_size_by_address(out);
106 let value = ptr.read().expect("Guest received invalid ptr");
107 ptr.free();
108 value
109 };};
110 }
111 }
112
113 out
114}
115
116#[proc_macro_attribute]
124pub fn host_functions(_: TokenStream, input: TokenStream) -> TokenStream {
125 let host_funcs = parse_macro_input!(input as ItemForeignMod);
126 let funcs = host_funcs
127 .items
128 .into_iter()
129 .map(|item| {
130 if let ForeignItem::Fn(func) = item {
132 func
133 } else {
134 panic!("Only functions are allowed in host_functions block")
135 }
136 })
137 .map(|mut func| {
138 let Signature {
139 ident,
140 inputs,
141 output,
142 ..
143 } = func.sig.clone();
144
145 let sig = &mut func.sig;
146 let ending = translate_host_output(&mut sig.output);
147
148 let fake_id = format_ident!("_host_{}", sig.ident);
149 sig.ident = fake_id.clone();
150
151 let HostInputTranslation {
152 prelude,
153 epilogue,
154 call_args,
155 } = translate_host_inputs(sig.inputs.iter_mut());
156
157 quote! {
158 fn #ident(#inputs) #output {
159 extern "C" {
160 #[link_name = stringify!(#ident)]
161 #sig;
162 }
163
164 unsafe {
165 #(#prelude)*
166 let out = #fake_id(#(#call_args),*);
167 #(#epilogue)*
168
169 #ending
170 }
171 }
172 }
173 });
174
175 let out = quote! {
176 #(#funcs)*
177 };
178 out.into()
179}
180
181#[derive(Default)]
182struct GuestInputTranslation {
183 prelude: Vec<Stmt>,
184}
185
186fn translate_guest_inputs<'a>(it: impl Iterator<Item = &'a mut FnArg>) -> GuestInputTranslation {
187 let mut out = GuestInputTranslation::default();
188
189 it.map(|arg| {
190 let FnArg::Typed(arg) = arg else { panic!("self is not allowed in guest functions") };
191 let Pat::Ident(id) = &*arg.pat else { panic!("Invalid function declation") };
192 (id.ident.clone(), &mut arg.ty)
193 })
194 .for_each(|(name, ty)| {
195 if let TypeTranslation::Wrapped(new) = translate_type(ty.as_ref().clone(), WrapMode::Encoded, false) {
196 out.prelude
197 .push(parse_quote!(let #name: #ty = &unsafe { #name.read().expect("Guest was given invalid pointer") };));
198 *ty = Box::new(new);
199 };
200 });
201
202 out
203}
204
205fn translate_guest_output(ret: &mut ReturnType) -> Stmt {
206 let mut out = parse_quote!(return out;);
207
208 if let ReturnType::Type(_, ty) = ret {
209 if let TypeTranslation::Wrapped(new) =
210 translate_type(ty.as_ref().clone(), WrapMode::Managed, true)
211 {
212 *ty = Box::new(new);
213 out = parse_quote!(return scotch_guest::ManagedPtr::new(&out).unwrap().offset(););
214 }
215 }
216
217 out
218}
219
220#[proc_macro_attribute]
228pub fn guest_function(_: TokenStream, input: TokenStream) -> TokenStream {
229 let mut item_fn = parse_macro_input!(input as ItemFn);
230 item_fn.attrs.push(parse_quote!(#[no_mangle]));
231 item_fn.sig.abi = Some(parse_quote!(extern "C"));
232
233 let GuestInputTranslation { prelude } = translate_guest_inputs(item_fn.sig.inputs.iter_mut());
234 let output = item_fn.sig.output.clone();
235 let epilogue = translate_guest_output(&mut item_fn.sig.output);
236 let body = item_fn.block;
237
238 item_fn.block = parse_quote!({
239 #(#prelude)*
240 let out = (move || #output #body)();
241 #epilogue
242 });
243
244 let out = quote! {
245 #item_fn
246 };
247
248 out.into()
249}