wasmtime_wiggle_macro/
lib.rs

1use proc_macro::TokenStream;
2use proc_macro2::{Ident, Span, TokenStream as TokenStream2};
3use quote::{format_ident, quote};
4use syn::parse_macro_input;
5use wiggle_generate::Names;
6
7mod config;
8
9use config::{AsyncConf, Asyncness, ModuleConf, TargetConf};
10
11/// Define the structs required to integrate a Wiggle implementation with Wasmtime.
12///
13/// ## Arguments
14///
15/// Arguments are provided using struct syntax e.g. `{ arg_name: value }`.
16///
17/// * `target`: The path of the module where the Wiggle implementation is defined.
18/// * `witx` or `witx_literal`: the .witx document where the interface is defined.
19///   `witx` takes a list of filesystem paths, e.g. `["/path/to/file1.witx",
20///   "./path/to_file2.witx"]`. Relative paths are relative to the root of the crate
21///   where the macro is invoked. `witx_literal` takes a string of the witx document, e.g.
22///   `"(typename $foo u8)"`.
23/// * `ctx`: The context struct used for the Wiggle implementation. This must be the same
24///   type as the `wasmtime_wiggle::from_witx` macro at `target` was invoked with. However, it
25///   must be imported to the current scope so that it is a bare identifier e.g. `CtxType`, not
26///   `path::to::CtxType`.
27/// * `modules`: Describes how any modules in the witx document will be implemented as Wasmtime
28///    instances. `modules` takes a map from the witx module name to a configuration struct, e.g.
29///    `foo => { name: Foo }, bar => { name: Bar }` will generate integrations for the modules
30///    named `foo` and `bar` in the witx document, as `pub struct Foo` and `pub struct Bar`
31///    respectively.
32///    The module configuration uses struct syntax with the following fields:
33///      * `name`: required, gives the name of the struct which encapsulates the instance for
34///         Wasmtime.
35///      * `docs`: optional, a doc string that will be used for the definition of the struct.
36///      * `function_override`: A map of witx function names to Rust function symbols for
37///         functions that should not call the Wiggle-generated functions, but instead use
38///         a separate implementation. This is typically used for functions that need to interact
39///         with Wasmtime in a manner that Wiggle does not permit, e.g. wasi's `proc_exit` function
40///         needs to return a Trap directly to the runtime.
41///    Example:
42///    `modules: { some_module => { name: SomeTypeName, docs: "Doc string for definition of
43///     SomeTypeName here", function_override: { foo => my_own_foo } }`.
44///
45#[proc_macro]
46pub fn wasmtime_integration(args: TokenStream) -> TokenStream {
47    let config = parse_macro_input!(args as config::Config);
48    let doc = config.load_document();
49    let names = Names::new(quote!(wasmtime_wiggle));
50
51    let modules = config.modules.iter().map(|(name, module_conf)| {
52        let module = doc
53            .module(&witx::Id::new(name))
54            .unwrap_or_else(|| panic!("witx document did not contain module named '{}'", name));
55        generate_module(
56            &module,
57            &module_conf,
58            &names,
59            &config.target,
60            &config.ctx.name,
61            &config.async_,
62        )
63    });
64    quote!( #(#modules)* ).into()
65}
66
67fn generate_module(
68    module: &witx::Module,
69    module_conf: &ModuleConf,
70    names: &Names,
71    target_conf: &TargetConf,
72    ctx_type: &syn::Type,
73    async_conf: &AsyncConf,
74) -> TokenStream2 {
75    let fields = module.funcs().map(|f| {
76        let name_ident = names.func(&f.name);
77        quote! { pub #name_ident: wasmtime::Func }
78    });
79    let get_exports = module.funcs().map(|f| {
80        let func_name = f.name.as_str();
81        let name_ident = names.func(&f.name);
82        quote! { #func_name => Some(&self.#name_ident) }
83    });
84    let ctor_fields = module.funcs().map(|f| names.func(&f.name));
85
86    let module_name = module.name.as_str();
87
88    let linker_add = module.funcs().map(|f| {
89        let func_name = f.name.as_str();
90        let name_ident = names.func(&f.name);
91        quote! {
92            linker.define(#module_name, #func_name, self.#name_ident.clone())?;
93        }
94    });
95
96    let target_path = &target_conf.path;
97    let module_id = names.module(&module.name);
98    let target_module = quote! { #target_path::#module_id };
99
100    let mut fns = Vec::new();
101    let mut ctor_externs = Vec::new();
102    let mut host_funcs = Vec::new();
103
104    for f in module.funcs() {
105        let asyncness = async_conf.is_async(module.name.as_str(), f.name.as_str());
106        match asyncness {
107            Asyncness::Blocking => {}
108            Asyncness::Async => {
109                assert!(
110                    cfg!(feature = "async"),
111                    "generating async wasmtime Funcs requires cargo feature \"async\""
112                );
113            }
114            _ => {}
115        }
116        generate_func(
117            &module_id,
118            &f,
119            names,
120            &target_module,
121            ctx_type,
122            asyncness,
123            &mut fns,
124            &mut ctor_externs,
125            &mut host_funcs,
126        );
127    }
128
129    let type_name = module_conf.name.clone();
130    let type_docs = module_conf
131        .docs
132        .as_ref()
133        .map(|docs| quote!( #[doc = #docs] ))
134        .unwrap_or_default();
135    let constructor_docs = format!(
136        "Creates a new [`{}`] instance.
137
138External values are allocated into the `store` provided and
139configuration of the instance itself should be all
140contained in the `cx` parameter.",
141        module_conf.name.to_string()
142    );
143
144    let config_adder_definitions = host_funcs.iter().map(|(func_name, body)| {
145        let adder_func = format_ident!("add_{}_to_config", names.func(&func_name));
146        let docs = format!(
147            "Add the host function for `{}` to a config under a given module and field name.",
148            func_name.as_str()
149        );
150        quote! {
151            #[doc = #docs]
152            pub fn #adder_func(config: &mut wasmtime::Config, module: &str, field: &str) {
153                #body
154            }
155        }
156    });
157    let config_adder_invocations = host_funcs.iter().map(|(func_name, _body)| {
158        let adder_func = format_ident!("add_{}_to_config", names.func(&func_name));
159        let module = module.name.as_str();
160        let field = func_name.as_str();
161        quote! {
162            Self::#adder_func(config, #module, #field);
163        }
164    });
165
166    quote! {
167        #type_docs
168        pub struct #type_name {
169            #(#fields,)*
170        }
171
172        impl #type_name {
173            #[doc = #constructor_docs]
174            pub fn new(store: &wasmtime::Store, ctx: std::rc::Rc<std::cell::RefCell<#ctx_type>>) -> Self {
175                #(#ctor_externs)*
176
177                Self {
178                    #(#ctor_fields,)*
179                }
180            }
181
182
183            /// Looks up a field called `name` in this structure, returning it
184            /// if found.
185            ///
186            /// This is often useful when instantiating a `wasmtime` instance
187            /// where name resolution often happens with strings.
188            pub fn get_export(&self, name: &str) -> Option<&wasmtime::Func> {
189                match name {
190                    #(#get_exports,)*
191                    _ => None,
192                }
193            }
194
195            /// Adds all instance items to the specified `Linker`.
196            pub fn add_to_linker(&self, linker: &mut wasmtime::Linker) -> anyhow::Result<()> {
197                #(#linker_add)*
198                Ok(())
199            }
200
201            /// Adds the host functions to the given [`wasmtime::Config`].
202            ///
203            /// Host functions added to the config expect [`set_context`] to be called.
204            ///
205            /// Host functions will trap if the context is not set in the calling [`wasmtime::Store`].
206            pub fn add_to_config(config: &mut wasmtime::Config) {
207                #(#config_adder_invocations)*
208            }
209
210            #(#config_adder_definitions)*
211
212            /// Sets the context in the given store.
213            ///
214            /// Context must be set in the store when using [`add_to_config`] and prior to any
215            /// host function being called.
216            ///
217            /// If the context is already set in the store, the given context is returned as an error.
218            pub fn set_context(store: &wasmtime::Store, ctx: #ctx_type) -> Result<(), #ctx_type> {
219                store.set(std::rc::Rc::new(std::cell::RefCell::new(ctx))).map_err(|ctx| {
220                    match std::rc::Rc::try_unwrap(ctx) {
221                        Ok(ctx) => ctx.into_inner(),
222                        Err(_) => unreachable!(),
223                    }
224                })
225            }
226
227            #(#fns)*
228        }
229    }
230}
231
232fn generate_func(
233    module_ident: &Ident,
234    func: &witx::InterfaceFunc,
235    names: &Names,
236    target_module: &TokenStream2,
237    ctx_type: &syn::Type,
238    asyncness: Asyncness,
239    fns: &mut Vec<TokenStream2>,
240    ctors: &mut Vec<TokenStream2>,
241    host_funcs: &mut Vec<(witx::Id, TokenStream2)>,
242) {
243    let rt = names.runtime_mod();
244    let name_ident = names.func(&func.name);
245
246    let (params, results) = func.wasm_signature();
247
248    let arg_names = (0..params.len())
249        .map(|i| Ident::new(&format!("arg{}", i), Span::call_site()))
250        .collect::<Vec<_>>();
251    let arg_decls = params
252        .iter()
253        .enumerate()
254        .map(|(i, ty)| {
255            let name = &arg_names[i];
256            let wasm = names.wasm_type(*ty);
257            quote! { #name: #wasm }
258        })
259        .collect::<Vec<_>>();
260
261    let ret_ty = match results.len() {
262        0 => quote!(()),
263        1 => names.wasm_type(results[0]),
264        _ => unimplemented!(),
265    };
266
267    let async_ = if asyncness.is_sync() {
268        quote!()
269    } else {
270        quote!(async)
271    };
272    let await_ = if asyncness.is_sync() {
273        quote!()
274    } else {
275        quote!(.await)
276    };
277
278    let runtime = names.runtime_mod();
279    let fn_ident = format_ident!("{}_{}", module_ident, name_ident);
280
281    fns.push(quote! {
282        #async_ fn #fn_ident(caller: &wasmtime::Caller<'_>, ctx: &mut #ctx_type #(, #arg_decls)*) -> Result<#ret_ty, wasmtime::Trap> {
283            unsafe {
284                let mem = match caller.get_export("memory") {
285                    Some(wasmtime::Extern::Memory(m)) => m,
286                    _ => {
287                        return Err(wasmtime::Trap::new("missing required memory export"));
288                    }
289                };
290                let mem = #runtime::WasmtimeGuestMemory::new(mem);
291                match #target_module::#name_ident(ctx, &mem #(, #arg_names)*) #await_ {
292                    Ok(r) => Ok(r.into()),
293                    Err(wasmtime_wiggle::Trap::String(err)) => Err(wasmtime::Trap::new(err)),
294                    Err(wasmtime_wiggle::Trap::I32Exit(err)) => Err(wasmtime::Trap::i32_exit(err)),
295                }
296            }
297        }
298    });
299
300    match asyncness {
301        Asyncness::Async => {
302            let wrapper = format_ident!("wrap{}_async", params.len());
303            ctors.push(quote! {
304            let #name_ident = wasmtime::Func::#wrapper(
305                store,
306                ctx.clone(),
307                move |caller: wasmtime::Caller<'_>, my_ctx: &std::rc::Rc<std::cell::RefCell<_>> #(,#arg_decls)*|
308                    -> Box<dyn std::future::Future<Output = Result<#ret_ty, wasmtime::Trap>>> {
309                    Box::new(async move { Self::#fn_ident(&caller, &mut my_ctx.borrow_mut() #(, #arg_names)*).await })
310                }
311            );
312            });
313        }
314        Asyncness::Blocking => {
315            // Emit a synchronous function. Self::#fn_ident returns a Future, so we need to
316            // use a dummy executor to let any synchronous code inside there execute correctly. If
317            // the future ends up Pending, this func will Trap.
318            ctors.push(quote! {
319                let my_ctx = ctx.clone();
320                let #name_ident = wasmtime::Func::wrap(
321                    store,
322                    move |caller: wasmtime::Caller #(, #arg_decls)*| -> Result<#ret_ty, wasmtime::Trap> {
323                        #rt::run_in_dummy_executor(Self::#fn_ident(&caller, &mut my_ctx.borrow_mut() #(, #arg_names)*))
324                    }
325                );
326            });
327        }
328        Asyncness::Sync => {
329            ctors.push(quote! {
330            let my_ctx = ctx.clone();
331            let #name_ident = wasmtime::Func::wrap(
332                store,
333                move |caller: wasmtime::Caller #(, #arg_decls)*| -> Result<#ret_ty, wasmtime::Trap> {
334                    Self::#fn_ident(&caller, &mut my_ctx.borrow_mut() #(, #arg_names)*)
335                }
336            );
337        });
338        }
339    }
340
341    let host_wrapper = match asyncness {
342        Asyncness::Async => {
343            let wrapper = format_ident!("wrap{}_host_func_async", params.len());
344            quote! {
345                config.#wrapper(
346                    module,
347                    field,
348                    move |caller #(,#arg_decls)*|
349                        -> Box<dyn std::future::Future<Output = Result<#ret_ty, wasmtime::Trap>>> {
350                        Box::new(async move {
351                            let ctx = caller.store()
352                                .get::<std::rc::Rc<std::cell::RefCell<#ctx_type>>>()
353                                .ok_or_else(|| wasmtime::Trap::new("context is missing in the store"))?;
354                            let result = Self::#fn_ident(&caller, &mut ctx.borrow_mut() #(, #arg_names)*).await;
355                            result
356                        })
357                    }
358                );
359            }
360        }
361
362        Asyncness::Blocking => {
363            // Emit a synchronous host function. Self::#fn_ident returns a Future, so we need to
364            // use a dummy executor to let any synchronous code inside there execute correctly. If
365            // the future ends up Pending, this func will Trap.
366            quote! {
367                config.wrap_host_func(
368                    module,
369                    field,
370                    move |caller: wasmtime::Caller #(, #arg_decls)*| -> Result<#ret_ty, wasmtime::Trap> {
371                        let ctx = caller
372                            .store()
373                            .get::<std::rc::Rc<std::cell::RefCell<#ctx_type>>>()
374                            .ok_or_else(|| wasmtime::Trap::new("context is missing in the store"))?;
375                        #rt::run_in_dummy_executor(Self::#fn_ident(&caller, &mut ctx.borrow_mut()  #(, #arg_names)*))
376                    },
377                );
378            }
379        }
380        Asyncness::Sync => {
381            quote! {
382                config.wrap_host_func(
383                    module,
384                    field,
385                    move |caller: wasmtime::Caller #(, #arg_decls)*| -> Result<#ret_ty, wasmtime::Trap> {
386                        let ctx = caller
387                            .store()
388                            .get::<std::rc::Rc<std::cell::RefCell<#ctx_type>>>()
389                            .ok_or_else(|| wasmtime::Trap::new("context is missing in the store"))?;
390                        Self::#fn_ident(&caller, &mut ctx.borrow_mut()  #(, #arg_names)*)
391                    },
392                );
393            }
394        }
395    };
396    host_funcs.push((func.name.clone(), host_wrapper));
397}