1use proc_macro::TokenStream;
2use proc_macro2::{Ident, Span, TokenStream as TokenStream2};
3use quote::{format_ident, quote};
4use syn::parse_macro_input;
5use wiggle_generate::Names;
6
7mod config;
8
9use config::{AsyncConf, Asyncness, ModuleConf, TargetConf};
10
11#[proc_macro]
46pub fn wasmtime_integration(args: TokenStream) -> TokenStream {
47 let config = parse_macro_input!(args as config::Config);
48 let doc = config.load_document();
49 let names = Names::new(quote!(wasmtime_wiggle));
50
51 let modules = config.modules.iter().map(|(name, module_conf)| {
52 let module = doc
53 .module(&witx::Id::new(name))
54 .unwrap_or_else(|| panic!("witx document did not contain module named '{}'", name));
55 generate_module(
56 &module,
57 &module_conf,
58 &names,
59 &config.target,
60 &config.ctx.name,
61 &config.async_,
62 )
63 });
64 quote!( #(#modules)* ).into()
65}
66
67fn generate_module(
68 module: &witx::Module,
69 module_conf: &ModuleConf,
70 names: &Names,
71 target_conf: &TargetConf,
72 ctx_type: &syn::Type,
73 async_conf: &AsyncConf,
74) -> TokenStream2 {
75 let fields = module.funcs().map(|f| {
76 let name_ident = names.func(&f.name);
77 quote! { pub #name_ident: wasmtime::Func }
78 });
79 let get_exports = module.funcs().map(|f| {
80 let func_name = f.name.as_str();
81 let name_ident = names.func(&f.name);
82 quote! { #func_name => Some(&self.#name_ident) }
83 });
84 let ctor_fields = module.funcs().map(|f| names.func(&f.name));
85
86 let module_name = module.name.as_str();
87
88 let linker_add = module.funcs().map(|f| {
89 let func_name = f.name.as_str();
90 let name_ident = names.func(&f.name);
91 quote! {
92 linker.define(#module_name, #func_name, self.#name_ident.clone())?;
93 }
94 });
95
96 let target_path = &target_conf.path;
97 let module_id = names.module(&module.name);
98 let target_module = quote! { #target_path::#module_id };
99
100 let mut fns = Vec::new();
101 let mut ctor_externs = Vec::new();
102 let mut host_funcs = Vec::new();
103
104 for f in module.funcs() {
105 let asyncness = async_conf.is_async(module.name.as_str(), f.name.as_str());
106 match asyncness {
107 Asyncness::Blocking => {}
108 Asyncness::Async => {
109 assert!(
110 cfg!(feature = "async"),
111 "generating async wasmtime Funcs requires cargo feature \"async\""
112 );
113 }
114 _ => {}
115 }
116 generate_func(
117 &module_id,
118 &f,
119 names,
120 &target_module,
121 ctx_type,
122 asyncness,
123 &mut fns,
124 &mut ctor_externs,
125 &mut host_funcs,
126 );
127 }
128
129 let type_name = module_conf.name.clone();
130 let type_docs = module_conf
131 .docs
132 .as_ref()
133 .map(|docs| quote!( #[doc = #docs] ))
134 .unwrap_or_default();
135 let constructor_docs = format!(
136 "Creates a new [`{}`] instance.
137
138External values are allocated into the `store` provided and
139configuration of the instance itself should be all
140contained in the `cx` parameter.",
141 module_conf.name.to_string()
142 );
143
144 let config_adder_definitions = host_funcs.iter().map(|(func_name, body)| {
145 let adder_func = format_ident!("add_{}_to_config", names.func(&func_name));
146 let docs = format!(
147 "Add the host function for `{}` to a config under a given module and field name.",
148 func_name.as_str()
149 );
150 quote! {
151 #[doc = #docs]
152 pub fn #adder_func(config: &mut wasmtime::Config, module: &str, field: &str) {
153 #body
154 }
155 }
156 });
157 let config_adder_invocations = host_funcs.iter().map(|(func_name, _body)| {
158 let adder_func = format_ident!("add_{}_to_config", names.func(&func_name));
159 let module = module.name.as_str();
160 let field = func_name.as_str();
161 quote! {
162 Self::#adder_func(config, #module, #field);
163 }
164 });
165
166 quote! {
167 #type_docs
168 pub struct #type_name {
169 #(#fields,)*
170 }
171
172 impl #type_name {
173 #[doc = #constructor_docs]
174 pub fn new(store: &wasmtime::Store, ctx: std::rc::Rc<std::cell::RefCell<#ctx_type>>) -> Self {
175 #(#ctor_externs)*
176
177 Self {
178 #(#ctor_fields,)*
179 }
180 }
181
182
183 pub fn get_export(&self, name: &str) -> Option<&wasmtime::Func> {
189 match name {
190 #(#get_exports,)*
191 _ => None,
192 }
193 }
194
195 pub fn add_to_linker(&self, linker: &mut wasmtime::Linker) -> anyhow::Result<()> {
197 #(#linker_add)*
198 Ok(())
199 }
200
201 pub fn add_to_config(config: &mut wasmtime::Config) {
207 #(#config_adder_invocations)*
208 }
209
210 #(#config_adder_definitions)*
211
212 pub fn set_context(store: &wasmtime::Store, ctx: #ctx_type) -> Result<(), #ctx_type> {
219 store.set(std::rc::Rc::new(std::cell::RefCell::new(ctx))).map_err(|ctx| {
220 match std::rc::Rc::try_unwrap(ctx) {
221 Ok(ctx) => ctx.into_inner(),
222 Err(_) => unreachable!(),
223 }
224 })
225 }
226
227 #(#fns)*
228 }
229 }
230}
231
232fn generate_func(
233 module_ident: &Ident,
234 func: &witx::InterfaceFunc,
235 names: &Names,
236 target_module: &TokenStream2,
237 ctx_type: &syn::Type,
238 asyncness: Asyncness,
239 fns: &mut Vec<TokenStream2>,
240 ctors: &mut Vec<TokenStream2>,
241 host_funcs: &mut Vec<(witx::Id, TokenStream2)>,
242) {
243 let rt = names.runtime_mod();
244 let name_ident = names.func(&func.name);
245
246 let (params, results) = func.wasm_signature();
247
248 let arg_names = (0..params.len())
249 .map(|i| Ident::new(&format!("arg{}", i), Span::call_site()))
250 .collect::<Vec<_>>();
251 let arg_decls = params
252 .iter()
253 .enumerate()
254 .map(|(i, ty)| {
255 let name = &arg_names[i];
256 let wasm = names.wasm_type(*ty);
257 quote! { #name: #wasm }
258 })
259 .collect::<Vec<_>>();
260
261 let ret_ty = match results.len() {
262 0 => quote!(()),
263 1 => names.wasm_type(results[0]),
264 _ => unimplemented!(),
265 };
266
267 let async_ = if asyncness.is_sync() {
268 quote!()
269 } else {
270 quote!(async)
271 };
272 let await_ = if asyncness.is_sync() {
273 quote!()
274 } else {
275 quote!(.await)
276 };
277
278 let runtime = names.runtime_mod();
279 let fn_ident = format_ident!("{}_{}", module_ident, name_ident);
280
281 fns.push(quote! {
282 #async_ fn #fn_ident(caller: &wasmtime::Caller<'_>, ctx: &mut #ctx_type #(, #arg_decls)*) -> Result<#ret_ty, wasmtime::Trap> {
283 unsafe {
284 let mem = match caller.get_export("memory") {
285 Some(wasmtime::Extern::Memory(m)) => m,
286 _ => {
287 return Err(wasmtime::Trap::new("missing required memory export"));
288 }
289 };
290 let mem = #runtime::WasmtimeGuestMemory::new(mem);
291 match #target_module::#name_ident(ctx, &mem #(, #arg_names)*) #await_ {
292 Ok(r) => Ok(r.into()),
293 Err(wasmtime_wiggle::Trap::String(err)) => Err(wasmtime::Trap::new(err)),
294 Err(wasmtime_wiggle::Trap::I32Exit(err)) => Err(wasmtime::Trap::i32_exit(err)),
295 }
296 }
297 }
298 });
299
300 match asyncness {
301 Asyncness::Async => {
302 let wrapper = format_ident!("wrap{}_async", params.len());
303 ctors.push(quote! {
304 let #name_ident = wasmtime::Func::#wrapper(
305 store,
306 ctx.clone(),
307 move |caller: wasmtime::Caller<'_>, my_ctx: &std::rc::Rc<std::cell::RefCell<_>> #(,#arg_decls)*|
308 -> Box<dyn std::future::Future<Output = Result<#ret_ty, wasmtime::Trap>>> {
309 Box::new(async move { Self::#fn_ident(&caller, &mut my_ctx.borrow_mut() #(, #arg_names)*).await })
310 }
311 );
312 });
313 }
314 Asyncness::Blocking => {
315 ctors.push(quote! {
319 let my_ctx = ctx.clone();
320 let #name_ident = wasmtime::Func::wrap(
321 store,
322 move |caller: wasmtime::Caller #(, #arg_decls)*| -> Result<#ret_ty, wasmtime::Trap> {
323 #rt::run_in_dummy_executor(Self::#fn_ident(&caller, &mut my_ctx.borrow_mut() #(, #arg_names)*))
324 }
325 );
326 });
327 }
328 Asyncness::Sync => {
329 ctors.push(quote! {
330 let my_ctx = ctx.clone();
331 let #name_ident = wasmtime::Func::wrap(
332 store,
333 move |caller: wasmtime::Caller #(, #arg_decls)*| -> Result<#ret_ty, wasmtime::Trap> {
334 Self::#fn_ident(&caller, &mut my_ctx.borrow_mut() #(, #arg_names)*)
335 }
336 );
337 });
338 }
339 }
340
341 let host_wrapper = match asyncness {
342 Asyncness::Async => {
343 let wrapper = format_ident!("wrap{}_host_func_async", params.len());
344 quote! {
345 config.#wrapper(
346 module,
347 field,
348 move |caller #(,#arg_decls)*|
349 -> Box<dyn std::future::Future<Output = Result<#ret_ty, wasmtime::Trap>>> {
350 Box::new(async move {
351 let ctx = caller.store()
352 .get::<std::rc::Rc<std::cell::RefCell<#ctx_type>>>()
353 .ok_or_else(|| wasmtime::Trap::new("context is missing in the store"))?;
354 let result = Self::#fn_ident(&caller, &mut ctx.borrow_mut() #(, #arg_names)*).await;
355 result
356 })
357 }
358 );
359 }
360 }
361
362 Asyncness::Blocking => {
363 quote! {
367 config.wrap_host_func(
368 module,
369 field,
370 move |caller: wasmtime::Caller #(, #arg_decls)*| -> Result<#ret_ty, wasmtime::Trap> {
371 let ctx = caller
372 .store()
373 .get::<std::rc::Rc<std::cell::RefCell<#ctx_type>>>()
374 .ok_or_else(|| wasmtime::Trap::new("context is missing in the store"))?;
375 #rt::run_in_dummy_executor(Self::#fn_ident(&caller, &mut ctx.borrow_mut() #(, #arg_names)*))
376 },
377 );
378 }
379 }
380 Asyncness::Sync => {
381 quote! {
382 config.wrap_host_func(
383 module,
384 field,
385 move |caller: wasmtime::Caller #(, #arg_decls)*| -> Result<#ret_ty, wasmtime::Trap> {
386 let ctx = caller
387 .store()
388 .get::<std::rc::Rc<std::cell::RefCell<#ctx_type>>>()
389 .ok_or_else(|| wasmtime::Trap::new("context is missing in the store"))?;
390 Self::#fn_ident(&caller, &mut ctx.borrow_mut() #(, #arg_names)*)
391 },
392 );
393 }
394 }
395 };
396 host_funcs.push((func.name.clone(), host_wrapper));
397}