scotch_host/
plugin.rs

1use crate::{CallbackRef, GuestFunctionCreator, GuestFunctionHandle, InstanceRef, StoreRef};
2use std::{
3    any::{Any, TypeId},
4    collections::{hash_map::Entry, HashMap},
5    path::Path,
6    sync::{Arc, Weak},
7};
8use wasmer::{
9    CompileError, DeserializeError, Extern, FunctionEnv, Imports, Instance, InstantiationError,
10    Module, SerializeError, Store,
11};
12
13#[doc(hidden)]
14pub struct WasmEnv<S: Any + Send + Sized + 'static> {
15    pub instance: Weak<Instance>,
16    pub state: S,
17}
18
19/// An instantiated plugin with cached exports.
20#[allow(dead_code)]
21pub struct WasmPlugin {
22    exports: HashMap<TypeId, CallbackRef>,
23    store: StoreRef,
24    module: Module,
25    instance: InstanceRef,
26}
27
28impl WasmPlugin {
29    /// Creates a builder to create a new WasmPlugin.
30    pub fn builder<E: Any + Send + Sized + 'static>() -> WasmPluginBuilder<E> {
31        WasmPluginBuilder::new()
32    }
33
34    /// Looks up cached guest export by function handle.
35    pub fn function<H: GuestFunctionHandle + 'static>(&self) -> Option<&H::Callback> {
36        self.exports
37            .get(&TypeId::of::<H>())?
38            .downcast_ref::<H::Callback>()
39    }
40
41    /// Looks up cached guest export by function handle.
42    /// If no matches are found tries to resolve export from wasm instance and cache the result.
43    pub fn function_or_cache<H: GuestFunctionHandle + 'static>(&mut self) -> Option<&H::Callback> {
44        let type_id = TypeId::of::<H>();
45
46        if let Entry::Vacant(e) = self.exports.entry(type_id) {
47            let callback = H::new()
48                .create(self.store.clone(), self.instance.clone())?
49                .1;
50            e.insert(callback);
51        }
52
53        self.exports.get(&type_id).and_then(|f| f.downcast_ref())
54    }
55
56    /// Looks up cached guest export by function handle.
57    /// # Panics
58    /// If function was not cached with `make_exports!`.
59    pub fn function_unwrap<H: GuestFunctionHandle + 'static>(&self) -> &H::Callback {
60        self.exports
61            .get(&TypeId::of::<H>())
62            .expect("Function not found")
63            .downcast_ref::<H::Callback>()
64            .unwrap()
65    }
66
67    /// Looks up cached guest export by function handle.
68    /// If no matches are found tries to resolve export from wasm instance and cache the result.
69    /// # Panics
70    /// If failed to find function in exports and it is missing in wasm instance.
71    pub fn function_unwrap_or_cache<'this: 'cb, 'cb, H: GuestFunctionHandle + 'static>(
72        &'this mut self,
73    ) -> &'cb H::Callback {
74        let type_id = TypeId::of::<H>();
75
76        self.exports
77            .entry(type_id)
78            .or_insert_with(|| {
79                H::new()
80                    .create(self.store.clone(), self.instance.clone())
81                    .expect("Function not found")
82                    .1
83            })
84            .downcast_ref()
85            .unwrap()
86    }
87
88    /// Serializes plugin into bytes to use with headless mode.
89    pub fn serialize(&self) -> Result<Vec<u8>, SerializeError> {
90        self.module.serialize().map(|bytes| bytes.to_vec())
91    }
92
93    /// Serializes plugin into bytes to use with headless mode and writes them to file.
94    pub fn serialize_to_file(&self, path: impl AsRef<Path>) -> Result<(), SerializeError> {
95        self.module.serialize_to_file(path)
96    }
97
98    /// Serializes plugin and compresses bytes to use with headless mode.
99    #[cfg(feature = "flate2")]
100    #[cfg_attr(feature = "unstable-doc-cfg", doc(cfg(feature = "flate2")))]
101    pub fn serialize_compress(&self) -> Result<Vec<u8>, SerializeError> {
102        use flate2::Compression;
103        use std::io::Write;
104
105        let data = self.serialize()?;
106        let mut encoder = flate2::write::GzEncoder::new(vec![], Compression::best());
107        encoder.write_all(&data[..])?;
108
109        Ok(encoder.finish()?)
110    }
111
112    /// Serializes plugin to file and compresses bytes to use with headless mode.
113    #[cfg(feature = "flate2")]
114    #[cfg_attr(feature = "unstable-doc-cfg", doc(cfg(feature = "flate2")))]
115    pub fn serialize_to_file_compress(&self, path: impl AsRef<Path>) -> Result<(), SerializeError> {
116        let compressed = self.serialize_compress()?;
117        Ok(std::fs::write(path, compressed)?)
118    }
119}
120
121/// Builder for creating [`WasmPlugin`].
122pub struct WasmPluginBuilder<E: Any + Send + Sized + 'static> {
123    store: Store,
124    module: Option<Module>,
125    imports: Option<Imports>,
126    exports: Vec<Box<dyn GuestFunctionCreator>>,
127    func_env: Option<FunctionEnv<WasmEnv<E>>>,
128}
129
130impl<S: Any + Send + Sized + 'static> WasmPluginBuilder<S> {
131    /// Creates new [`WasmPluginBuilder`].
132    #[inline]
133    pub fn new() -> Self {
134        Self {
135            store: Store::default(),
136            module: None,
137            imports: None,
138            func_env: None,
139            exports: vec![],
140        }
141    }
142
143    /// Creates new [`WasmPluginBuilder`] and overrides default store with custom.
144    pub fn new_with_store(store: Store) -> Self {
145        Self {
146            store,
147            ..Self::new()
148        }
149    }
150
151    /// Compiles bytecode with selected compiler. To change the compile use feature flags.
152    /// Default compiler is `cranelift`.
153    #[cfg(feature = "compiler")]
154    #[cfg_attr(feature = "unstable-doc-cfg", doc(cfg(feature = "compiler")))]
155    pub fn from_binary(mut self, bytecode: &[u8]) -> Result<Self, CompileError> {
156        self.module = Some(Module::from_binary(&self.store, bytecode)?);
157        Ok(self)
158    }
159
160    /// Creates plugin from bytes created by [`WasmPlugin::serialize`].
161    /// # Safety
162    /// See [`Module::deserialize`].
163    pub unsafe fn from_serialized(mut self, data: &[u8]) -> Result<Self, DeserializeError> {
164        self.module = Some(Module::deserialize(&self.store, data)?);
165        Ok(self)
166    }
167
168    /// Creates plugin from compressed bytes created by [`WasmPlugin::serialize_compress`].
169    /// # Safety
170    /// See [`Module::deserialize`].
171    #[cfg(feature = "flate2")]
172    #[cfg_attr(feature = "unstable-doc-cfg", doc(cfg(feature = "flate2")))]
173    pub unsafe fn from_serialized_compressed(
174        mut self,
175        compressed: &[u8],
176    ) -> Result<Self, DeserializeError> {
177        use std::io::Read;
178
179        let mut decoder = flate2::read::GzDecoder::new(compressed);
180        let mut buf = vec![];
181        decoder.read_to_end(&mut buf)?;
182
183        self.module = Some(Module::deserialize(&self.store, buf)?);
184        Ok(self)
185    }
186
187    /// Creates plugin from bytes created by [`WasmPlugin::serialize_to_file`].
188    /// # Safety
189    /// See [`Module::deserialize_from_file`].
190    pub unsafe fn from_serialized_file(
191        mut self,
192        path: impl AsRef<Path>,
193    ) -> Result<Self, DeserializeError> {
194        self.module = Some(Module::deserialize_from_file(&self.store, path)?);
195        Ok(self)
196    }
197
198    /// Creates plugin from compressed bytes created by [`WasmPlugin::serialize_to_file_compress`].
199    /// # Safety
200    /// See [`Module::deserialize`].
201    #[cfg(feature = "flate2")]
202    #[cfg_attr(feature = "unstable-doc-cfg", doc(cfg(feature = "flate2")))]
203    pub unsafe fn from_serialized_file_compressed(
204        mut self,
205        path: impl AsRef<Path>,
206    ) -> Result<Self, DeserializeError> {
207        use std::io::Read;
208
209        let compressed = std::fs::read(path)?;
210        let mut decoder = flate2::read::GzDecoder::new(&compressed[..]);
211        let mut buf = vec![];
212        decoder.read_to_end(&mut buf)?;
213
214        self.module = Some(Module::deserialize(&self.store, buf)?);
215        Ok(self)
216    }
217
218    /// Creates a state that host function will have mutable access to.
219    /// You *HAVE* to create the state. If you do not need it simply pass `()`.
220    pub fn with_state(mut self, state: S) -> Self {
221        // This should help avoid questionable bugs in `with_imports`
222        assert!(
223            self.func_env.is_none(),
224            "You can call `with_state` only once"
225        );
226
227        self.func_env = Some(FunctionEnv::new(
228            &mut self.store,
229            WasmEnv {
230                instance: Weak::new(),
231                state,
232            },
233        ));
234        self
235    }
236
237    /// Creates imports i.e. host functions that guest imports.
238    /// use `make_imports!` to create the closure.
239    pub fn with_imports(
240        mut self,
241        imports: impl FnOnce(&mut Store, &FunctionEnv<WasmEnv<S>>) -> Imports,
242    ) -> Self {
243        self.imports = Some(imports(
244            &mut self.store,
245            self.func_env
246                .as_ref()
247                .expect("You need to call `with_state` first"),
248        ));
249        self
250    }
251
252    /// Updates exports i.e. guest functions that host imports.
253    /// use `make_exports!` to create the iterator.
254    pub fn with_exports(
255        mut self,
256        exports: impl IntoIterator<Item = Box<dyn GuestFunctionCreator>>,
257    ) -> Self {
258        self.exports.extend(exports);
259        self
260    }
261
262    /// Finishes building a `WasmPlugin`.
263    #[allow(clippy::result_large_err)]
264    pub fn finish(mut self) -> Result<WasmPlugin, InstantiationError> {
265        let module = self
266            .module
267            .expect("You need to call `from_binary` or `from_serialized` first");
268        let instance: InstanceRef =
269            Instance::new(&mut self.store, &module, &self.imports.unwrap_or_default())?.into();
270        instance
271            .exports
272            .get_memory("memory")
273            .unwrap()
274            .grow(&mut self.store, 3)
275            .unwrap();
276
277        if let Some(env) = self.func_env.as_mut() {
278            env.as_mut(&mut self.store).instance = Arc::downgrade(&instance);
279        }
280
281        let store: StoreRef = Arc::new(self.store.into());
282        let exports = self
283            .exports
284            .into_iter()
285            .flat_map(|ex| ex.create(store.clone(), instance.clone()))
286            .collect::<HashMap<_, _>>();
287
288        Ok(WasmPlugin {
289            store,
290            exports,
291            instance,
292            module,
293        })
294    }
295}
296
297impl<E: Any + Send + Sized + 'static> Default for WasmPluginBuilder<E> {
298    #[inline]
299    fn default() -> Self {
300        Self::new()
301    }
302}
303
304#[doc(hidden)]
305pub use wasmer::{Function, FunctionEnvMut};
306
307#[doc(hidden)]
308pub fn create_imports_from_functions<const N: usize>(
309    items: [(&'static str, Function); N],
310) -> Imports {
311    let mut imports = Imports::new();
312    imports.register_namespace(
313        "env",
314        items
315            .into_iter()
316            .map(|(s, f)| (s.to_string(), Extern::Function(f))),
317    );
318    imports
319}