Skip to main content

shaperail_runtime/plugins/
mod.rs

1//! WASM plugin runtime for Shaperail (M19).
2//!
3//! Provides sandboxed execution of WebAssembly plugins as controller hooks.
4//! Plugins receive a JSON context and return a modified JSON context.
5//!
6//! # Plugin Interface
7//!
8//! WASM modules must export:
9//! - `alloc(size: i32) -> i32` — allocate `size` bytes, return pointer
10//! - `dealloc(ptr: i32, size: i32)` — free previously allocated memory
11//! - `before_hook(ptr: i32, len: i32) -> i64` — process context, return `(ptr << 32) | len`
12//!
13//! Optionally:
14//! - `after_hook(ptr: i32, len: i32) -> i64` — same interface, called after DB operation
15//!
16//! # Sandboxing
17//!
18//! By default, plugins have NO access to:
19//! - Host filesystem
20//! - Network
21//! - Environment variables
22//! - System clock
23//!
24//! This is enforced by creating WASM instances without WASI capabilities.
25
use std::collections::HashMap;
use std::path::{Path, PathBuf};
use std::sync::Arc;

use shaperail_core::ShaperailError;
use tokio::sync::RwLock;
use tracing::{debug, error, warn};
use wasmtime::{
    AsContext, AsContextMut, Engine, Instance, Linker, Memory, Module, Store, StoreLimits,
    StoreLimitsBuilder, Val,
};
34
35/// JSON context passed to WASM plugins, matching the controller `Context` shape.
36#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
37pub struct PluginContext {
38    /// Mutable input data (before-hooks can modify this).
39    pub input: serde_json::Map<String, serde_json::Value>,
40    /// DB result data. `null` in before-hooks, populated in after-hooks.
41    pub data: Option<serde_json::Value>,
42    /// Authenticated user info, if present.
43    pub user: Option<PluginUser>,
44    /// Request headers (read-only from plugin perspective).
45    pub headers: HashMap<String, String>,
46    /// Tenant ID, if multi-tenancy is active.
47    pub tenant_id: Option<String>,
48}
49
50/// Minimal user info passed to WASM plugins.
51#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
52pub struct PluginUser {
53    pub id: String,
54    pub role: String,
55}
56
57/// Result returned from a WASM plugin hook.
58#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
59pub struct PluginResult {
60    /// Whether the hook succeeded.
61    pub ok: bool,
62    /// Modified context (only `input` and `data` changes are applied back).
63    #[serde(default)]
64    pub ctx: Option<PluginContext>,
65    /// Error message if `ok` is false.
66    #[serde(default)]
67    pub error: Option<String>,
68    /// Error details for validation errors.
69    #[serde(default)]
70    pub details: Option<Vec<serde_json::Value>>,
71}
72
/// Resource limits for sandboxed plugin execution.
#[derive(Debug, Clone)]
pub struct SandboxConfig {
    /// Maximum guest linear memory, in 64 KiB WebAssembly pages.
    /// Default: 256 pages (16 MiB).
    pub max_memory_pages: u32,
    /// Fuel budget (wasmtime's instruction-count limit) for one invocation.
    /// Default: 1_000_000.
    pub max_fuel: u64,
}

impl SandboxConfig {
    const DEFAULT_MAX_MEMORY_PAGES: u32 = 256;
    const DEFAULT_MAX_FUEL: u64 = 1_000_000;
}

impl Default for SandboxConfig {
    fn default() -> Self {
        SandboxConfig {
            max_fuel: Self::DEFAULT_MAX_FUEL,
            max_memory_pages: Self::DEFAULT_MAX_MEMORY_PAGES,
        }
    }
}
90
91/// A compiled WASM module ready for instantiation.
92struct CompiledPlugin {
93    module: Module,
94}
95
96/// Runtime for executing WASM plugins with sandboxing.
97///
98/// Caches compiled modules to avoid recompilation on every request.
99pub struct WasmRuntime {
100    engine: Engine,
101    plugins: Arc<RwLock<HashMap<PathBuf, CompiledPlugin>>>,
102    sandbox: SandboxConfig,
103}
104
105impl WasmRuntime {
106    /// Creates a new WASM runtime with default sandbox configuration.
107    pub fn new() -> Result<Self, ShaperailError> {
108        Self::with_sandbox(SandboxConfig::default())
109    }
110
111    /// Creates a new WASM runtime with custom sandbox configuration.
112    pub fn with_sandbox(sandbox: SandboxConfig) -> Result<Self, ShaperailError> {
113        let mut config = wasmtime::Config::new();
114        config.consume_fuel(true);
115
116        let engine = Engine::new(&config)
117            .map_err(|e| ShaperailError::Internal(format!("Failed to create WASM engine: {e}")))?;
118
119        Ok(Self {
120            engine,
121            plugins: Arc::new(RwLock::new(HashMap::new())),
122            sandbox,
123        })
124    }
125
126    /// Loads a WASM plugin from a file path. Caches the compiled module.
127    pub async fn load_plugin(&self, path: &Path) -> Result<(), ShaperailError> {
128        let canonical = path.canonicalize().map_err(|e| {
129            ShaperailError::Internal(format!(
130                "Failed to resolve WASM plugin path '{}': {e}",
131                path.display()
132            ))
133        })?;
134
135        let mut plugins = self.plugins.write().await;
136        if plugins.contains_key(&canonical) {
137            return Ok(());
138        }
139
140        let wasm_bytes = std::fs::read(&canonical).map_err(|e| {
141            ShaperailError::Internal(format!(
142                "Failed to read WASM plugin '{}': {e}",
143                canonical.display()
144            ))
145        })?;
146
147        let module = Module::new(&self.engine, &wasm_bytes).map_err(|e| {
148            ShaperailError::Internal(format!(
149                "Failed to compile WASM plugin '{}': {e}",
150                canonical.display()
151            ))
152        })?;
153
154        debug!(path = %canonical.display(), "Loaded WASM plugin");
155        plugins.insert(canonical, CompiledPlugin { module });
156        Ok(())
157    }
158
159    /// Loads a WASM plugin from raw bytes (for testing).
160    pub async fn load_plugin_bytes(
161        &self,
162        name: &str,
163        wasm_bytes: &[u8],
164    ) -> Result<(), ShaperailError> {
165        let key = PathBuf::from(name);
166        let mut plugins = self.plugins.write().await;
167
168        let module = Module::new(&self.engine, wasm_bytes).map_err(|e| {
169            ShaperailError::Internal(format!("Failed to compile WASM module '{name}': {e}"))
170        })?;
171
172        plugins.insert(key, CompiledPlugin { module });
173        Ok(())
174    }
175
176    /// Calls a hook function on a loaded WASM plugin.
177    ///
178    /// The `hook_name` should be `"before_hook"` or `"after_hook"`.
179    /// Returns the modified `PluginContext` on success.
180    pub async fn call_hook(
181        &self,
182        plugin_path: &str,
183        hook_name: &str,
184        ctx: &PluginContext,
185    ) -> Result<PluginResult, ShaperailError> {
186        let key = if plugin_path.starts_with('/') || plugin_path.starts_with("__test:") {
187            PathBuf::from(plugin_path)
188        } else {
189            Path::new(plugin_path)
190                .canonicalize()
191                .unwrap_or_else(|_| PathBuf::from(plugin_path))
192        };
193
194        let plugins = self.plugins.read().await;
195        let compiled = plugins.get(&key).ok_or_else(|| {
196            ShaperailError::Internal(format!(
197                "WASM plugin '{}' not loaded. Call load_plugin first.",
198                key.display()
199            ))
200        })?;
201
202        let ctx_json = serde_json::to_vec(ctx).map_err(|e| {
203            ShaperailError::Internal(format!("Failed to serialize plugin context: {e}"))
204        })?;
205
206        // Create a fresh store per invocation for isolation
207        let mut store = Store::new(&self.engine, ());
208        store
209            .set_fuel(self.sandbox.max_fuel)
210            .map_err(|e| ShaperailError::Internal(format!("Failed to set fuel: {e}")))?;
211
212        // No WASI — plugin runs fully sandboxed (no fs, no network, no env)
213        let linker = Linker::new(&self.engine);
214        let instance = linker
215            .instantiate(&mut store, &compiled.module)
216            .map_err(|e| {
217                ShaperailError::Internal(format!("Failed to instantiate WASM plugin: {e}"))
218            })?;
219
220        // Call the hook, catching any traps (panics, OOM, fuel exhaustion)
221        match self.invoke_hook(&mut store, &instance, hook_name, &ctx_json) {
222            Ok(result) => Ok(result),
223            Err(e) => {
224                warn!(
225                    plugin = plugin_path,
226                    hook = hook_name,
227                    error = %e,
228                    "WASM plugin hook trapped — returning error without crashing server"
229                );
230                Ok(PluginResult {
231                    ok: false,
232                    ctx: None,
233                    error: Some(format!("WASM plugin trapped: {e}")),
234                    details: None,
235                })
236            }
237        }
238    }
239
240    /// Internal: invoke a hook function on a WASM instance.
241    fn invoke_hook(
242        &self,
243        store: &mut Store<()>,
244        instance: &Instance,
245        hook_name: &str,
246        ctx_json: &[u8],
247    ) -> Result<PluginResult, ShaperailError> {
248        // Get required exports
249        let memory = instance
250            .get_memory(store.as_context_mut(), "memory")
251            .ok_or_else(|| {
252                ShaperailError::Internal("WASM plugin does not export 'memory'".to_string())
253            })?;
254
255        let alloc_fn = instance
256            .get_func(store.as_context_mut(), "alloc")
257            .ok_or_else(|| {
258                ShaperailError::Internal("WASM plugin does not export 'alloc'".to_string())
259            })?;
260
261        let hook_fn = instance
262            .get_func(store.as_context_mut(), hook_name)
263            .ok_or_else(|| {
264                ShaperailError::Internal(format!("WASM plugin does not export '{hook_name}'"))
265            })?;
266
267        // Allocate memory in guest for input JSON
268        let input_len = ctx_json.len() as i32;
269        let mut alloc_result = [Val::I32(0)];
270        alloc_fn
271            .call(
272                store.as_context_mut(),
273                &[Val::I32(input_len)],
274                &mut alloc_result,
275            )
276            .map_err(|e| ShaperailError::Internal(format!("WASM alloc call failed: {e}")))?;
277        let input_ptr = alloc_result[0].unwrap_i32();
278
279        // Write input JSON into guest memory
280        write_to_memory(&memory, store, input_ptr as usize, ctx_json)?;
281
282        // Call the hook function: hook(ptr, len) -> i64 (packed ptr|len)
283        let mut hook_result = [Val::I64(0)];
284        hook_fn
285            .call(
286                store.as_context_mut(),
287                &[Val::I32(input_ptr), Val::I32(input_len)],
288                &mut hook_result,
289            )
290            .map_err(|e| {
291                ShaperailError::Internal(format!("WASM hook '{hook_name}' trapped: {e}"))
292            })?;
293
294        // Unpack result: high 32 bits = ptr, low 32 bits = len
295        let packed = hook_result[0].unwrap_i64();
296        let result_ptr = (packed >> 32) as usize;
297        let result_len = (packed & 0xFFFF_FFFF) as usize;
298
299        if result_len == 0 {
300            // Empty result means "no changes, ok"
301            return Ok(PluginResult {
302                ok: true,
303                ctx: None,
304                error: None,
305                details: None,
306            });
307        }
308
309        // Read result JSON from guest memory
310        let result_bytes = read_from_memory(&memory, store, result_ptr, result_len)?;
311
312        let result: PluginResult = serde_json::from_slice(&result_bytes).map_err(|e| {
313            error!(
314                raw = %String::from_utf8_lossy(&result_bytes),
315                "WASM plugin returned invalid JSON"
316            );
317            ShaperailError::Internal(format!("WASM plugin returned invalid JSON: {e}"))
318        })?;
319
320        Ok(result)
321    }
322
323    /// Returns true if a plugin is loaded for the given path.
324    pub async fn is_loaded(&self, path: &str) -> bool {
325        let key = PathBuf::from(path);
326        self.plugins.read().await.contains_key(&key)
327    }
328}
329
330impl Default for WasmRuntime {
331    fn default() -> Self {
332        Self::new().expect("Failed to create default WasmRuntime")
333    }
334}
335
336/// Write bytes into WASM linear memory at the given offset.
337fn write_to_memory(
338    memory: &Memory,
339    store: &mut Store<()>,
340    offset: usize,
341    data: &[u8],
342) -> Result<(), ShaperailError> {
343    let mem_data = memory.data_mut(store.as_context_mut());
344    let end = offset + data.len();
345    if end > mem_data.len() {
346        return Err(ShaperailError::Internal(format!(
347            "WASM memory write out of bounds: offset={offset}, len={}, memory_size={}",
348            data.len(),
349            mem_data.len()
350        )));
351    }
352    mem_data[offset..end].copy_from_slice(data);
353    Ok(())
354}
355
356/// Read bytes from WASM linear memory at the given offset.
357fn read_from_memory(
358    memory: &Memory,
359    store: &mut Store<()>,
360    offset: usize,
361    len: usize,
362) -> Result<Vec<u8>, ShaperailError> {
363    let mem_data = memory.data(store.as_context());
364    let end = offset + len;
365    if end > mem_data.len() {
366        return Err(ShaperailError::Internal(format!(
367            "WASM memory read out of bounds: offset={offset}, len={len}, memory_size={}",
368            mem_data.len()
369        )));
370    }
371    Ok(mem_data[offset..end].to_vec())
372}
373
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a plugin whose `before_hook` ignores its input and writes `response`
    /// byte-by-byte at offset 0. It returns the response packed as `(0 << 32) | len`.
    ///
    /// The module also exports `memory`, a bump `alloc` starting at 4096 (so the
    /// input never overlaps a response shorter than 4096 bytes) and a no-op
    /// `dealloc`.
    fn fixed_response_wasm(response: &str) -> Vec<u8> {
        use std::fmt::Write as _;

        let mut wat = String::from(
            r#"
            (module
                (memory (export "memory") 2)
                (global $bump (mut i32) (i32.const 4096))

                (func (export "alloc") (param $size i32) (result i32)
                    (local $ptr i32)
                    (local.set $ptr (global.get $bump))
                    (global.set $bump (i32.add (global.get $bump) (local.get $size)))
                    (local.get $ptr)
                )

                (func (export "dealloc") (param $ptr i32) (param $size i32))

                (func (export "before_hook") (param $ptr i32) (param $len i32) (result i64)
            "#,
        );

        for (i, byte) in response.bytes().enumerate() {
            writeln!(
                wat,
                "                    (i32.store8 (i32.const {i}) (i32.const {byte}))"
            )
            .expect("writing to a String cannot fail");
        }

        write!(
            wat,
            r#"
                    (i64.or
                        (i64.shl (i64.extend_i32_u (i32.const 0)) (i64.const 32))
                        (i64.extend_i32_u (i32.const {len}))
                    )
                )
            )
            "#,
            len = response.len()
        )
        .expect("writing to a String cannot fail");

        wat::parse_str(&wat).expect("WAT parse failed")
    }

    /// Plugin whose `before_hook` ignores its input and returns `{"ok":true}`
    /// (no modifications).
    fn passthrough_wasm() -> Vec<u8> {
        fixed_response_wasm(r#"{"ok":true}"#)
    }

    /// Plugin whose `before_hook` returns a context whose `input.name` is
    /// `"modified_by_wasm"`, regardless of the input it received.
    fn modifier_wasm() -> Vec<u8> {
        fixed_response_wasm(
            r#"{"ok":true,"ctx":{"input":{"name":"modified_by_wasm"},"data":null,"user":null,"headers":{},"tenant_id":null}}"#,
        )
    }

    /// WASM module that traps (unreachable instruction) to test crash isolation.
    fn crashing_wasm() -> Vec<u8> {
        wat::parse_str(
            r#"
            (module
                (memory (export "memory") 2)
                (global $bump (mut i32) (i32.const 4096))

                (func (export "alloc") (param $size i32) (result i32)
                    (local $ptr i32)
                    (local.set $ptr (global.get $bump))
                    (global.set $bump (i32.add (global.get $bump) (local.get $size)))
                    (local.get $ptr)
                )

                (func (export "dealloc") (param $ptr i32) (param $size i32))

                (func (export "before_hook") (param $ptr i32) (param $len i32) (result i64)
                    unreachable
                )
            )
            "#,
        )
        .expect("WAT parse failed")
    }

    /// Plugin whose `before_hook` reports a validation error.
    fn error_wasm() -> Vec<u8> {
        fixed_response_wasm(r#"{"ok":false,"error":"validation failed: email is required"}"#)
    }

    fn test_context() -> PluginContext {
        let mut input = serde_json::Map::new();
        input.insert("name".to_string(), serde_json::json!("Alice"));
        input.insert("email".to_string(), serde_json::json!("alice@example.com"));

        PluginContext {
            input,
            data: None,
            user: Some(PluginUser {
                id: "user-123".to_string(),
                role: "admin".to_string(),
            }),
            headers: HashMap::new(),
            tenant_id: None,
        }
    }

    #[tokio::test]
    async fn passthrough_hook_runs_and_returns_ok() {
        let runtime = WasmRuntime::new().unwrap();
        let wasm = passthrough_wasm();
        runtime
            .load_plugin_bytes("__test:passthrough", &wasm)
            .await
            .unwrap();

        let ctx = test_context();
        let result = runtime
            .call_hook("__test:passthrough", "before_hook", &ctx)
            .await
            .unwrap();

        assert!(result.ok);
    }

    #[tokio::test]
    async fn modifier_hook_modifies_context() {
        let runtime = WasmRuntime::new().unwrap();
        let wasm = modifier_wasm();
        runtime
            .load_plugin_bytes("__test:modifier", &wasm)
            .await
            .unwrap();

        let ctx = test_context();
        let result = runtime
            .call_hook("__test:modifier", "before_hook", &ctx)
            .await
            .unwrap();

        assert!(result.ok);
        let modified_ctx = result.ctx.unwrap();
        assert_eq!(
            modified_ctx.input.get("name").and_then(|v| v.as_str()),
            Some("modified_by_wasm")
        );
    }

    #[tokio::test]
    async fn crashing_plugin_does_not_crash_server() {
        let runtime = WasmRuntime::new().unwrap();
        let wasm = crashing_wasm();
        runtime
            .load_plugin_bytes("__test:crash", &wasm)
            .await
            .unwrap();

        let ctx = test_context();
        let result = runtime
            .call_hook("__test:crash", "before_hook", &ctx)
            .await
            .unwrap();

        // Plugin trapped but server is fine — returns error result
        assert!(!result.ok);
        assert!(result.error.as_ref().unwrap().contains("trapped"));
    }

    #[tokio::test]
    async fn error_hook_returns_plugin_error() {
        let runtime = WasmRuntime::new().unwrap();
        let wasm = error_wasm();
        runtime
            .load_plugin_bytes("__test:error", &wasm)
            .await
            .unwrap();

        let ctx = test_context();
        let result = runtime
            .call_hook("__test:error", "before_hook", &ctx)
            .await
            .unwrap();

        assert!(!result.ok);
        assert_eq!(
            result.error.as_deref(),
            Some("validation failed: email is required")
        );
    }

    #[tokio::test]
    async fn unloaded_plugin_returns_error() {
        let runtime = WasmRuntime::new().unwrap();
        let ctx = test_context();
        let result = runtime
            .call_hook("__test:nonexistent", "before_hook", &ctx)
            .await;

        assert!(result.is_err());
    }

    #[tokio::test]
    async fn fuel_exhaustion_does_not_crash_server() {
        // Create runtime with very low fuel limit
        let sandbox = SandboxConfig {
            max_memory_pages: 256,
            max_fuel: 1, // Very low — runs out almost immediately
        };
        let runtime = WasmRuntime::with_sandbox(sandbox).unwrap();

        // The modifier executes one store per response byte, far more than 1 fuel
        let wasm = modifier_wasm();
        runtime
            .load_plugin_bytes("__test:fuel", &wasm)
            .await
            .unwrap();

        let ctx = test_context();
        let result = runtime
            .call_hook("__test:fuel", "before_hook", &ctx)
            .await
            .unwrap();

        // Should fail gracefully due to fuel exhaustion
        assert!(!result.ok);
        assert!(result.error.as_ref().unwrap().contains("trapped"));
    }

    #[tokio::test]
    async fn sandbox_no_wasi_by_default() {
        // This test verifies that WASM plugins cannot access host resources.
        // The passthrough module has no WASI imports, proving sandboxing works.
        // If we tried to instantiate a module WITH WASI imports, it would fail
        // because we don't provide WASI in the linker.
        let runtime = WasmRuntime::new().unwrap();
        let wasm = passthrough_wasm();
        runtime
            .load_plugin_bytes("__test:sandbox", &wasm)
            .await
            .unwrap();

        assert!(runtime.is_loaded("__test:sandbox").await);

        let ctx = test_context();
        let result = runtime
            .call_hook("__test:sandbox", "before_hook", &ctx)
            .await
            .unwrap();
        assert!(result.ok);
    }
}