1use std::collections::HashSet;
8use std::path::Path;
9use std::sync::atomic::{AtomicBool, Ordering};
10use std::sync::Arc;
11
12use wasmtime::{Caller, Engine, Func, Instance, Linker, Module, Store, Trap, Val, ValType};
13
14use crate::plugin::PluginError;
15
16pub type CapabilityCheck = Arc<dyn Fn(&CapabilityArgs) -> bool + Send + Sync + 'static>;
17
/// Description of a single capability request handed to a [`CapabilityCheck`].
///
/// All fields are plain strings, so the type is cheaply comparable and hashable
/// (useful for caching decisions or asserting on requests in tests).
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct CapabilityArgs {
    /// Actor identity of the plugin that triggered the request.
    pub plugin_actor: String,
    /// Capability being requested; for wasm imports this is `wasm.import.<module>.<name>`.
    pub capability: String,
    /// Identity of the caller; the wasm runtime in this file sets it to `plugin_actor`.
    pub caller: String,
}
24
25pub struct WasmPluginRuntimeOptions {
26 pub plugin_actor: String,
27 pub allowed_imports: HashSet<String>,
29 pub capability_check: Option<CapabilityCheck>,
32}
33
34pub struct WasmPlugin {
35 pub plugin_actor: String,
36 instance: Instance,
37 store: Store<HostState>,
38}
39
/// Data kept inside each plugin's wasmtime `Store`.
struct HostState {
    /// Becomes `true` once the plugin is revoked. Held in an `Arc` so that every
    /// clone of the handle observes the same flag.
    revoked: Arc<AtomicBool>,
}
43
44impl WasmPlugin {
45 pub fn from_file<P: AsRef<Path>>(
47 path: P,
48 opts: WasmPluginRuntimeOptions,
49 ) -> Result<Self, PluginError> {
50 let bytes = std::fs::read(path.as_ref()).map_err(|e| PluginError::Io(e.to_string()))?;
51 Self::from_bytes(&bytes, opts)
52 }
53
54 pub fn from_bytes(wasm: &[u8], opts: WasmPluginRuntimeOptions) -> Result<Self, PluginError> {
55 let engine = Engine::default();
56 let module = Module::from_binary(&engine, wasm)
57 .map_err(|e| PluginError::Parse(format!("wasm compile: {e}")))?;
58 let revoked = Arc::new(AtomicBool::new(false));
59 let mut store = Store::new(
60 &engine,
61 HostState {
62 revoked: revoked.clone(),
63 },
64 );
65 let mut linker: Linker<HostState> = Linker::new(&engine);
66
67 for import in module.imports() {
76 let module_name = import.module().to_string();
77 let field_name = import.name().to_string();
78 let combined = format!("{module_name}.{field_name}");
79 let allowed = opts.allowed_imports.contains(&combined);
80 let plugin_actor = opts.plugin_actor.clone();
81 let capability_check = opts.capability_check.clone();
82
83 let func_ty = match import.ty() {
88 wasmtime::ExternType::Func(ft) => ft,
89 _ => continue,
90 };
91
92 let plugin_actor_for_func = plugin_actor.clone();
93 let cap_for_func = combined.clone();
94 let func = Func::new(
95 &mut store,
96 func_ty.clone(),
97 move |_caller: Caller<'_, HostState>,
98 params: &[Val],
99 results: &mut [Val]|
100 -> Result<(), wasmtime::Error> {
101 if !allowed {
102 return Err(Trap::UnreachableCodeReached.into());
103 }
104 if let Some(cb) = &capability_check {
105 let ok = cb(&CapabilityArgs {
106 plugin_actor: plugin_actor_for_func.clone(),
107 capability: format!("wasm.import.{}", cap_for_func),
108 caller: plugin_actor_for_func.clone(),
109 });
110 if !ok {
111 return Err(Trap::UnreachableCodeReached.into());
112 }
113 }
114 let _ = params;
122 for (i, result_ty) in func_ty.results().enumerate() {
123 results[i] = match result_ty {
124 ValType::I32 => Val::I32(0),
125 ValType::I64 => Val::I64(0),
126 ValType::F32 => Val::F32(0),
127 ValType::F64 => Val::F64(0),
128 _ => Val::I32(0),
129 };
130 }
131 Ok(())
132 },
133 );
134 linker
135 .define(&store, &module_name, &field_name, func)
136 .map_err(|e| PluginError::Parse(format!("wasm linker.define: {e}")))?;
137 }
138
139 let instance = linker
140 .instantiate(&mut store, &module)
141 .map_err(|e| PluginError::Parse(format!("wasm instantiate: {e}")))?;
142 Ok(WasmPlugin {
143 plugin_actor: opts.plugin_actor,
144 instance,
145 store,
146 })
147 }
148
149 pub fn revoke(&self) {
152 self.revoked().store(true, Ordering::SeqCst);
153 }
154
155 fn revoked(&self) -> Arc<AtomicBool> {
156 self.store.data().revoked.clone()
157 }
158
159 pub fn call_i32(&mut self, name: &str, arg: i32) -> Result<i32, PluginError> {
163 if self.revoked().load(Ordering::SeqCst) {
164 return Err(PluginError::BadSignature(format!(
165 "plugin {} actor was revoked",
166 self.plugin_actor
167 )));
168 }
169 let f = self
170 .instance
171 .get_typed_func::<i32, i32>(&mut self.store, name)
172 .map_err(|e| PluginError::Parse(format!("get_typed_func {name}: {e}")))?;
173 f.call(&mut self.store, arg)
174 .map_err(|e| PluginError::Parse(format!("wasm call {name}: {e}")))
175 }
176}
177
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal module exporting one function `f: () -> i32` that returns `0`.
    const TINY_WASM: &[u8] = &[
        // magic "\0asm", version 1
        0x00, 0x61, 0x73, 0x6d, 0x01, 0x00, 0x00, 0x00,
        // type section (id 1, 5 bytes): 1 type, func with 0 params and 1 i32 result
        0x01, 0x05, 0x01, 0x60, 0x00, 0x01, 0x7f,
        // function section (id 3, 2 bytes): 1 function using type 0
        0x03, 0x02, 0x01, 0x00,
        // export section (id 7, 5 bytes): "f" -> function 0
        0x07, 0x05, 0x01, 0x01, 0x66, 0x00, 0x00,
        // code section (id 10, 6 bytes): 1 body, no locals, `i32.const 0; end`
        0x0a, 0x06, 0x01, 0x04, 0x00, 0x41, 0x00, 0x0b,
    ];

    const ACTOR: &str = "tf:actor:plugin:example.com/test";

    /// Options with no allowed imports and no capability check.
    fn options() -> WasmPluginRuntimeOptions {
        WasmPluginRuntimeOptions {
            plugin_actor: ACTOR.to_string(),
            allowed_imports: HashSet::new(),
            capability_check: None,
        }
    }

    fn instantiate() -> WasmPlugin {
        WasmPlugin::from_bytes(TINY_WASM, options()).expect("instantiate")
    }

    #[test]
    fn instantiate_minimal_wasm_no_imports() {
        let plugin = instantiate();
        assert_eq!(plugin.plugin_actor, ACTOR);
    }

    #[test]
    fn revocation_blocks_subsequent_calls() {
        let mut plugin = instantiate();
        plugin.revoke();
        let err = plugin.call_i32("nope", 0).unwrap_err();
        assert!(matches!(err, PluginError::BadSignature(_)));
    }

    #[test]
    fn unrevoked_call_reaches_export_lookup() {
        let mut plugin = instantiate();
        // Not revoked, so the failure must come from the export lookup instead.
        let err = plugin.call_i32("nope", 0).unwrap_err();
        assert!(matches!(err, PluginError::Parse(_)));
        // `f` exists but is `() -> i32`, not `(i32) -> i32`.
        let err = plugin.call_i32("f", 0).unwrap_err();
        assert!(matches!(err, PluginError::Parse(_)));
    }
}