1use std::collections::HashMap;
27use std::path::{Path, PathBuf};
28use std::sync::Arc;
29
30use shaperail_core::ShaperailError;
31use tokio::sync::RwLock;
32use tracing::{debug, error, warn};
33use wasmtime::{AsContext, AsContextMut, Engine, Instance, Linker, Memory, Module, Store, Val};
34
35#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
37pub struct PluginContext {
38 pub input: serde_json::Map<String, serde_json::Value>,
40 pub data: Option<serde_json::Value>,
42 pub user: Option<PluginUser>,
44 pub headers: HashMap<String, String>,
46 pub tenant_id: Option<String>,
48}
49
50#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
52pub struct PluginUser {
53 pub id: String,
54 pub role: String,
55}
56
57#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
59pub struct PluginResult {
60 pub ok: bool,
62 #[serde(default)]
64 pub ctx: Option<PluginContext>,
65 #[serde(default)]
67 pub error: Option<String>,
68 #[serde(default)]
70 pub details: Option<Vec<serde_json::Value>>,
71}
72
/// Resource limits applied to each plugin hook invocation.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SandboxConfig {
    /// Linear-memory cap for a plugin instance, in WebAssembly pages of 64 KiB each.
    pub max_memory_pages: u32,
    /// Fuel budget granted to each hook invocation. Every call gets a fresh budget.
    pub max_fuel: u64,
}

impl Default for SandboxConfig {
    /// 256 pages (16 MiB) of linear memory and one million units of fuel.
    fn default() -> Self {
        Self {
            max_memory_pages: 256,
            max_fuel: 1_000_000,
        }
    }
}
90
91struct CompiledPlugin {
93 module: Module,
94}
95
96pub struct WasmRuntime {
100 engine: Engine,
101 plugins: Arc<RwLock<HashMap<PathBuf, CompiledPlugin>>>,
102 sandbox: SandboxConfig,
103}
104
105impl WasmRuntime {
106 pub fn new() -> Result<Self, ShaperailError> {
108 Self::with_sandbox(SandboxConfig::default())
109 }
110
111 pub fn with_sandbox(sandbox: SandboxConfig) -> Result<Self, ShaperailError> {
113 let mut config = wasmtime::Config::new();
114 config.consume_fuel(true);
115
116 let engine = Engine::new(&config)
117 .map_err(|e| ShaperailError::Internal(format!("Failed to create WASM engine: {e}")))?;
118
119 Ok(Self {
120 engine,
121 plugins: Arc::new(RwLock::new(HashMap::new())),
122 sandbox,
123 })
124 }
125
126 pub async fn load_plugin(&self, path: &Path) -> Result<(), ShaperailError> {
128 let canonical = path.canonicalize().map_err(|e| {
129 ShaperailError::Internal(format!(
130 "Failed to resolve WASM plugin path '{}': {e}",
131 path.display()
132 ))
133 })?;
134
135 let mut plugins = self.plugins.write().await;
136 if plugins.contains_key(&canonical) {
137 return Ok(());
138 }
139
140 let wasm_bytes = std::fs::read(&canonical).map_err(|e| {
141 ShaperailError::Internal(format!(
142 "Failed to read WASM plugin '{}': {e}",
143 canonical.display()
144 ))
145 })?;
146
147 let module = Module::new(&self.engine, &wasm_bytes).map_err(|e| {
148 ShaperailError::Internal(format!(
149 "Failed to compile WASM plugin '{}': {e}",
150 canonical.display()
151 ))
152 })?;
153
154 debug!(path = %canonical.display(), "Loaded WASM plugin");
155 plugins.insert(canonical, CompiledPlugin { module });
156 Ok(())
157 }
158
159 pub async fn load_plugin_bytes(
161 &self,
162 name: &str,
163 wasm_bytes: &[u8],
164 ) -> Result<(), ShaperailError> {
165 let key = PathBuf::from(name);
166 let mut plugins = self.plugins.write().await;
167
168 let module = Module::new(&self.engine, wasm_bytes).map_err(|e| {
169 ShaperailError::Internal(format!("Failed to compile WASM module '{name}': {e}"))
170 })?;
171
172 plugins.insert(key, CompiledPlugin { module });
173 Ok(())
174 }
175
176 pub async fn call_hook(
181 &self,
182 plugin_path: &str,
183 hook_name: &str,
184 ctx: &PluginContext,
185 ) -> Result<PluginResult, ShaperailError> {
186 let key = if plugin_path.starts_with('/') || plugin_path.starts_with("__test:") {
187 PathBuf::from(plugin_path)
188 } else {
189 Path::new(plugin_path)
190 .canonicalize()
191 .unwrap_or_else(|_| PathBuf::from(plugin_path))
192 };
193
194 let plugins = self.plugins.read().await;
195 let compiled = plugins.get(&key).ok_or_else(|| {
196 ShaperailError::Internal(format!(
197 "WASM plugin '{}' not loaded. Call load_plugin first.",
198 key.display()
199 ))
200 })?;
201
202 let ctx_json = serde_json::to_vec(ctx).map_err(|e| {
203 ShaperailError::Internal(format!("Failed to serialize plugin context: {e}"))
204 })?;
205
206 let mut store = Store::new(&self.engine, ());
208 store
209 .set_fuel(self.sandbox.max_fuel)
210 .map_err(|e| ShaperailError::Internal(format!("Failed to set fuel: {e}")))?;
211
212 let linker = Linker::new(&self.engine);
214 let instance = linker
215 .instantiate(&mut store, &compiled.module)
216 .map_err(|e| {
217 ShaperailError::Internal(format!("Failed to instantiate WASM plugin: {e}"))
218 })?;
219
220 match self.invoke_hook(&mut store, &instance, hook_name, &ctx_json) {
222 Ok(result) => Ok(result),
223 Err(e) => {
224 warn!(
225 plugin = plugin_path,
226 hook = hook_name,
227 error = %e,
228 "WASM plugin hook trapped — returning error without crashing server"
229 );
230 Ok(PluginResult {
231 ok: false,
232 ctx: None,
233 error: Some(format!("WASM plugin trapped: {e}")),
234 details: None,
235 })
236 }
237 }
238 }
239
240 fn invoke_hook(
242 &self,
243 store: &mut Store<()>,
244 instance: &Instance,
245 hook_name: &str,
246 ctx_json: &[u8],
247 ) -> Result<PluginResult, ShaperailError> {
248 let memory = instance
250 .get_memory(store.as_context_mut(), "memory")
251 .ok_or_else(|| {
252 ShaperailError::Internal("WASM plugin does not export 'memory'".to_string())
253 })?;
254
255 let alloc_fn = instance
256 .get_func(store.as_context_mut(), "alloc")
257 .ok_or_else(|| {
258 ShaperailError::Internal("WASM plugin does not export 'alloc'".to_string())
259 })?;
260
261 let hook_fn = instance
262 .get_func(store.as_context_mut(), hook_name)
263 .ok_or_else(|| {
264 ShaperailError::Internal(format!("WASM plugin does not export '{hook_name}'"))
265 })?;
266
267 let input_len = ctx_json.len() as i32;
269 let mut alloc_result = [Val::I32(0)];
270 alloc_fn
271 .call(
272 store.as_context_mut(),
273 &[Val::I32(input_len)],
274 &mut alloc_result,
275 )
276 .map_err(|e| ShaperailError::Internal(format!("WASM alloc call failed: {e}")))?;
277 let input_ptr = alloc_result[0].unwrap_i32();
278
279 write_to_memory(&memory, store, input_ptr as usize, ctx_json)?;
281
282 let mut hook_result = [Val::I64(0)];
284 hook_fn
285 .call(
286 store.as_context_mut(),
287 &[Val::I32(input_ptr), Val::I32(input_len)],
288 &mut hook_result,
289 )
290 .map_err(|e| {
291 ShaperailError::Internal(format!("WASM hook '{hook_name}' trapped: {e}"))
292 })?;
293
294 let packed = hook_result[0].unwrap_i64();
296 let result_ptr = (packed >> 32) as usize;
297 let result_len = (packed & 0xFFFF_FFFF) as usize;
298
299 if result_len == 0 {
300 return Ok(PluginResult {
302 ok: true,
303 ctx: None,
304 error: None,
305 details: None,
306 });
307 }
308
309 let result_bytes = read_from_memory(&memory, store, result_ptr, result_len)?;
311
312 let result: PluginResult = serde_json::from_slice(&result_bytes).map_err(|e| {
313 error!(
314 raw = %String::from_utf8_lossy(&result_bytes),
315 "WASM plugin returned invalid JSON"
316 );
317 ShaperailError::Internal(format!("WASM plugin returned invalid JSON: {e}"))
318 })?;
319
320 Ok(result)
321 }
322
323 pub async fn is_loaded(&self, path: &str) -> bool {
325 let key = PathBuf::from(path);
326 self.plugins.read().await.contains_key(&key)
327 }
328}
329
330impl Default for WasmRuntime {
331 fn default() -> Self {
332 Self::new().expect("Failed to create default WasmRuntime")
333 }
334}
335
336fn write_to_memory(
338 memory: &Memory,
339 store: &mut Store<()>,
340 offset: usize,
341 data: &[u8],
342) -> Result<(), ShaperailError> {
343 let mem_data = memory.data_mut(store.as_context_mut());
344 let end = offset + data.len();
345 if end > mem_data.len() {
346 return Err(ShaperailError::Internal(format!(
347 "WASM memory write out of bounds: offset={offset}, len={}, memory_size={}",
348 data.len(),
349 mem_data.len()
350 )));
351 }
352 mem_data[offset..end].copy_from_slice(data);
353 Ok(())
354}
355
356fn read_from_memory(
358 memory: &Memory,
359 store: &mut Store<()>,
360 offset: usize,
361 len: usize,
362) -> Result<Vec<u8>, ShaperailError> {
363 let mem_data = memory.data(store.as_context());
364 let end = offset + len;
365 if end > mem_data.len() {
366 return Err(ShaperailError::Internal(format!(
367 "WASM memory read out of bounds: offset={offset}, len={len}, memory_size={}",
368 mem_data.len()
369 )));
370 }
371 Ok(mem_data[offset..end].to_vec())
372}
373
#[cfg(test)]
mod tests {
    use super::*;

    /// Module fields shared by every fixture:
    /// - two pages of exported memory,
    /// - a bump allocator whose first allocation starts at 4096, leaving the low 4 KiB for hook
    ///   output,
    /// - a no-op `dealloc`.
    const ALLOCATOR_PRELUDE: &str = r#"
        (memory (export "memory") 2)

        ;; Simple bump allocator (start after reserved output area)
        (global $bump (mut i32) (i32.const 4096))

        (func (export "alloc") (param $size i32) (result i32)
            (local $ptr i32)
            (local.set $ptr (global.get $bump))
            (global.set $bump (i32.add (global.get $bump) (local.get $size)))
            (local.get $ptr)
        )

        ;; no-op for bump allocator
        (func (export "dealloc") (param $ptr i32) (param $size i32))
    "#;

    /// Assembles a module from the allocator prelude plus a
    /// `before_hook(ptr: i32, len: i32) -> i64` export whose instructions are `hook_body`.
    fn module_with_before_hook(hook_body: &str) -> Vec<u8> {
        let wat = format!(
            r#"
            (module
                {ALLOCATOR_PRELUDE}
                (func (export "before_hook") (param $ptr i32) (param $len i32) (result i64)
                    {hook_body}
                )
            )
            "#
        );
        wat::parse_str(&wat).expect("WAT parse failed")
    }

    /// Builds a plugin whose `before_hook` ignores its input and always answers with `response`.
    ///
    /// The hook stores `response` at address 0 one byte at a time, then returns it packed as
    /// `(0 << 32) | response.len()`.
    fn static_response_wasm(response: &str) -> Vec<u8> {
        let mut body = String::new();
        for (offset, byte) in response.bytes().enumerate() {
            body.push_str(&format!(
                "(i32.store8 (i32.const {offset}) (i32.const {byte}))\n"
            ));
        }
        body.push_str(&format!(
            "(i64.or (i64.shl (i64.extend_i32_u (i32.const 0)) (i64.const 32)) \
             (i64.extend_i32_u (i32.const {})))",
            response.len()
        ));
        module_with_before_hook(&body)
    }

    /// `before_hook` answers `{"ok":true}` without touching the context.
    fn passthrough_wasm() -> Vec<u8> {
        static_response_wasm(r#"{"ok":true}"#)
    }

    /// `before_hook` answers with a context whose `input.name` is `"modified_by_wasm"`.
    fn modifier_wasm() -> Vec<u8> {
        static_response_wasm(
            r#"{"ok":true,"ctx":{"input":{"name":"modified_by_wasm"},"data":null,"user":null,"headers":{},"tenant_id":null}}"#,
        )
    }

    /// `before_hook` traps immediately.
    fn crashing_wasm() -> Vec<u8> {
        module_with_before_hook("unreachable")
    }

    /// `before_hook` rejects the call with a validation error.
    fn error_wasm() -> Vec<u8> {
        static_response_wasm(r#"{"ok":false,"error":"validation failed: email is required"}"#)
    }

    fn test_context() -> PluginContext {
        let mut input = serde_json::Map::new();
        input.insert("name".to_string(), serde_json::json!("Alice"));
        input.insert("email".to_string(), serde_json::json!("alice@example.com"));

        PluginContext {
            input,
            data: None,
            user: Some(PluginUser {
                id: "user-123".to_string(),
                role: "admin".to_string(),
            }),
            headers: HashMap::new(),
            tenant_id: None,
        }
    }

    #[tokio::test]
    async fn passthrough_hook_runs_and_returns_ok() {
        let runtime = WasmRuntime::new().unwrap();
        let wasm = passthrough_wasm();
        runtime
            .load_plugin_bytes("__test:passthrough", &wasm)
            .await
            .unwrap();

        let ctx = test_context();
        let result = runtime
            .call_hook("__test:passthrough", "before_hook", &ctx)
            .await
            .unwrap();

        assert!(result.ok);
    }

    #[tokio::test]
    async fn modifier_hook_modifies_context() {
        let runtime = WasmRuntime::new().unwrap();
        let wasm = modifier_wasm();
        runtime
            .load_plugin_bytes("__test:modifier", &wasm)
            .await
            .unwrap();

        let ctx = test_context();
        let result = runtime
            .call_hook("__test:modifier", "before_hook", &ctx)
            .await
            .unwrap();

        assert!(result.ok);
        let modified_ctx = result.ctx.unwrap();
        assert_eq!(
            modified_ctx.input.get("name").and_then(|v| v.as_str()),
            Some("modified_by_wasm")
        );
    }

    #[tokio::test]
    async fn crashing_plugin_does_not_crash_server() {
        let runtime = WasmRuntime::new().unwrap();
        let wasm = crashing_wasm();
        runtime
            .load_plugin_bytes("__test:crash", &wasm)
            .await
            .unwrap();

        let ctx = test_context();
        let result = runtime
            .call_hook("__test:crash", "before_hook", &ctx)
            .await
            .unwrap();

        assert!(!result.ok);
        assert!(result.error.as_ref().unwrap().contains("trapped"));
    }

    #[tokio::test]
    async fn error_hook_returns_plugin_error() {
        let runtime = WasmRuntime::new().unwrap();
        let wasm = error_wasm();
        runtime
            .load_plugin_bytes("__test:error", &wasm)
            .await
            .unwrap();

        let ctx = test_context();
        let result = runtime
            .call_hook("__test:error", "before_hook", &ctx)
            .await
            .unwrap();

        assert!(!result.ok);
        assert_eq!(
            result.error.as_deref(),
            Some("validation failed: email is required")
        );
    }

    #[tokio::test]
    async fn unloaded_plugin_returns_error() {
        let runtime = WasmRuntime::new().unwrap();
        let ctx = test_context();
        let result = runtime
            .call_hook("__test:nonexistent", "before_hook", &ctx)
            .await;

        assert!(result.is_err());
    }

    #[tokio::test]
    async fn missing_hook_export_is_reported_without_crashing() {
        let runtime = WasmRuntime::new().unwrap();
        runtime
            .load_plugin_bytes("__test:missing_export", &passthrough_wasm())
            .await
            .unwrap();

        let result = runtime
            .call_hook("__test:missing_export", "after_hook", &test_context())
            .await
            .unwrap();

        assert!(!result.ok);
        assert!(result.error.is_some());
    }

    #[tokio::test]
    async fn invalid_module_bytes_are_rejected() {
        let runtime = WasmRuntime::new().unwrap();
        let loaded = runtime
            .load_plugin_bytes("__test:garbage", b"not a wasm module")
            .await;

        assert!(loaded.is_err());
        assert!(!runtime.is_loaded("__test:garbage").await);
    }

    #[tokio::test]
    async fn load_plugin_reports_missing_file() {
        let runtime = WasmRuntime::new().unwrap();
        let loaded = runtime
            .load_plugin(Path::new("/definitely/not/a/shaperail/plugin.wasm"))
            .await;

        assert!(loaded.is_err());
    }

    #[tokio::test]
    async fn fuel_exhaustion_does_not_crash_server() {
        // A single unit of fuel runs out inside the first guest call.
        let sandbox = SandboxConfig {
            max_memory_pages: 256,
            max_fuel: 1,
        };
        let runtime = WasmRuntime::with_sandbox(sandbox).unwrap();

        let wasm = modifier_wasm();
        runtime
            .load_plugin_bytes("__test:fuel", &wasm)
            .await
            .unwrap();

        let ctx = test_context();
        let result = runtime
            .call_hook("__test:fuel", "before_hook", &ctx)
            .await
            .unwrap();

        assert!(!result.ok);
        assert!(result.error.as_ref().unwrap().contains("trapped"));
    }

    #[tokio::test]
    async fn sandbox_no_wasi_by_default() {
        // The fixture imports nothing, so it instantiates against the runtime's empty linker.
        let runtime = WasmRuntime::new().unwrap();
        let wasm = passthrough_wasm();
        runtime
            .load_plugin_bytes("__test:sandbox", &wasm)
            .await
            .unwrap();

        assert!(runtime.is_loaded("__test:sandbox").await);

        let ctx = test_context();
        let result = runtime
            .call_hook("__test:sandbox", "before_hook", &ctx)
            .await
            .unwrap();
        assert!(result.ok);
    }
}