1use std::collections::HashMap;
26use std::path::{Path, PathBuf};
27use std::sync::Arc;
28use std::time::SystemTime;
29
30use anyhow::{Context, Result};
31use wasmtime::{Config, Engine, Linker, Module, Store};
32
33use crate::types::{PluginInput, PluginOutput};
34
/// Resource budget applied to every plugin invocation.
#[derive(Debug, Clone)]
pub struct RuntimeConfig {
    /// Maximum linear memory a plugin may use, in bytes.
    pub max_memory: usize,
    /// Execution budget in "seconds". Each second is converted into 1,000,000 units
    /// of wasmtime fuel, so this bounds executed work rather than wall-clock time.
    pub max_time_secs: u64,
}

impl Default for RuntimeConfig {
    /// 256 MiB of memory and a 30-second execution budget.
    fn default() -> Self {
        let max_memory = 256 << 20;
        let max_time_secs = 30;
        RuntimeConfig {
            max_memory,
            max_time_secs,
        }
    }
}
52
53pub fn validate_plugin_module(bytes: &[u8]) -> Result<()> {
64 let engine = Engine::default();
65 let module = Module::new(&engine, bytes)?;
66
67 if let Some(import) = module.imports().next() {
69 anyhow::bail!(
70 "plugin has forbidden import: {}::{}",
71 import.module(),
72 import.name()
73 );
74 }
75
76 let exports: Vec<_> = module.exports().map(|e| e.name()).collect();
78
79 if !exports.contains(&"memory") {
80 anyhow::bail!("plugin must export 'memory'");
81 }
82 if !exports.contains(&"alloc") {
83 anyhow::bail!("plugin must export 'alloc' function");
84 }
85 if !exports.contains(&"process") {
86 anyhow::bail!("plugin must export 'process' function");
87 }
88
89 Ok(())
90}
91
92pub struct Plugin {
94 name: String,
96 module: Module,
98 engine: Arc<Engine>,
100}
101
102impl Plugin {
103 pub fn load(path: &Path, _config: &RuntimeConfig) -> Result<Self> {
105 let name = path
106 .file_stem()
107 .and_then(|s| s.to_str())
108 .unwrap_or("unknown")
109 .to_string();
110
111 let mut engine_config = Config::new();
113 engine_config.consume_fuel(true); let engine = Arc::new(Engine::new(&engine_config)?);
116
117 let wasm_bytes =
119 std::fs::read(path).with_context(|| format!("failed to read {}", path.display()))?;
120
121 let module = Module::new(&engine, &wasm_bytes)
122 .with_context(|| format!("failed to compile {}", path.display()))?;
123
124 Ok(Self {
125 name,
126 module,
127 engine,
128 })
129 }
130
131 pub fn load_bytes(
133 name: impl Into<String>,
134 bytes: &[u8],
135 _config: &RuntimeConfig,
136 ) -> Result<Self> {
137 let name = name.into();
138
139 let mut engine_config = Config::new();
140 engine_config.consume_fuel(true);
141
142 let engine = Arc::new(Engine::new(&engine_config)?);
143 let module = Module::new(&engine, bytes)?;
144
145 Ok(Self {
146 name,
147 module,
148 engine,
149 })
150 }
151
152 pub fn name(&self) -> &str {
154 &self.name
155 }
156
157 pub fn execute(&self, input: &PluginInput, config: &RuntimeConfig) -> Result<PluginOutput> {
159 let mut store = Store::new(&self.engine, ());
161
162 let fuel = config.max_time_secs * 1_000_000;
164 store.set_fuel(fuel)?;
165
166 let linker = Linker::new(&self.engine);
169
170 let instance = linker.instantiate(&mut store, &self.module)?;
172
173 let input_bytes = rmp_serde::to_vec(input)?;
175
176 let memory = instance
178 .get_memory(&mut store, "memory")
179 .context("plugin must export 'memory'")?;
180
181 let alloc = instance
183 .get_typed_func::<u32, u32>(&mut store, "alloc")
184 .context("plugin must export 'alloc' function")?;
185
186 let input_ptr = alloc.call(&mut store, input_bytes.len() as u32)?;
188
189 memory.write(&mut store, input_ptr as usize, &input_bytes)?;
191
192 let process = instance
194 .get_typed_func::<(u32, u32), u64>(&mut store, "process")
195 .context("plugin must export 'process' function")?;
196
197 let result = process.call(&mut store, (input_ptr, input_bytes.len() as u32))?;
198
199 let output_ptr = (result >> 32) as u32;
201 let output_len = (result & 0xFFFF_FFFF) as u32;
202
203 let mut output_bytes = vec![0u8; output_len as usize];
205 memory.read(&store, output_ptr as usize, &mut output_bytes)?;
206
207 let output: PluginOutput = rmp_serde::from_slice(&output_bytes)?;
209
210 Ok(output)
211 }
212}
213
214pub struct PluginManager {
216 config: RuntimeConfig,
218 plugins: Vec<Plugin>,
220}
221
222impl PluginManager {
223 pub fn new() -> Self {
225 Self::with_config(RuntimeConfig::default())
226 }
227
228 pub const fn with_config(config: RuntimeConfig) -> Self {
230 Self {
231 config,
232 plugins: Vec::new(),
233 }
234 }
235
236 pub fn load(&mut self, path: &Path) -> Result<usize> {
238 let plugin = Plugin::load(path, &self.config)?;
239 let index = self.plugins.len();
240 self.plugins.push(plugin);
241 Ok(index)
242 }
243
244 pub fn load_bytes(&mut self, name: impl Into<String>, bytes: &[u8]) -> Result<usize> {
246 let plugin = Plugin::load_bytes(name, bytes, &self.config)?;
247 let index = self.plugins.len();
248 self.plugins.push(plugin);
249 Ok(index)
250 }
251
252 pub fn execute(&self, index: usize, input: &PluginInput) -> Result<PluginOutput> {
254 let plugin = self
255 .plugins
256 .get(index)
257 .context("plugin index out of bounds")?;
258 plugin.execute(input, &self.config)
259 }
260
261 pub fn execute_all(&self, mut input: PluginInput) -> Result<PluginOutput> {
263 let mut all_errors = Vec::new();
264
265 for plugin in &self.plugins {
266 let output = plugin.execute(&input, &self.config)?;
267 all_errors.extend(output.errors);
268 input.directives = output.directives;
269 }
270
271 Ok(PluginOutput {
272 directives: input.directives,
273 errors: all_errors,
274 })
275 }
276
277 pub fn len(&self) -> usize {
279 self.plugins.len()
280 }
281
282 pub fn is_empty(&self) -> bool {
284 self.plugins.is_empty()
285 }
286}
287
288impl Default for PluginManager {
289 fn default() -> Self {
290 Self::new()
291 }
292}
293
294struct TrackedPlugin {
296 plugin: Plugin,
298 path: PathBuf,
300 modified: SystemTime,
302}
303
304pub struct WatchingPluginManager {
324 config: RuntimeConfig,
326 plugins: Vec<TrackedPlugin>,
328 name_index: HashMap<String, usize>,
330 on_reload: Option<Box<dyn Fn(&str) + Send + Sync>>,
332}
333
334impl WatchingPluginManager {
335 pub fn new() -> Self {
337 Self::with_config(RuntimeConfig::default())
338 }
339
340 pub fn with_config(config: RuntimeConfig) -> Self {
342 Self {
343 config,
344 plugins: Vec::new(),
345 name_index: HashMap::new(),
346 on_reload: None,
347 }
348 }
349
350 pub fn on_reload<F>(&mut self, callback: F)
352 where
353 F: Fn(&str) + Send + Sync + 'static,
354 {
355 self.on_reload = Some(Box::new(callback));
356 }
357
358 pub fn load(&mut self, path: impl AsRef<Path>) -> Result<usize> {
360 let path = path.as_ref();
361 let abs_path = path.canonicalize().unwrap_or_else(|_| path.to_path_buf());
362
363 let metadata = std::fs::metadata(&abs_path)
365 .with_context(|| format!("failed to stat {}", abs_path.display()))?;
366 let modified = metadata.modified()?;
367
368 let plugin = Plugin::load(&abs_path, &self.config)?;
370 let name = plugin.name().to_string();
371 let index = self.plugins.len();
372
373 self.plugins.push(TrackedPlugin {
375 plugin,
376 path: abs_path,
377 modified,
378 });
379 self.name_index.insert(name, index);
380
381 Ok(index)
382 }
383
384 pub fn check_and_reload(&mut self) -> Result<bool> {
388 let mut reloaded = false;
389
390 for tracked in &mut self.plugins {
391 let metadata = match std::fs::metadata(&tracked.path) {
393 Ok(m) => m,
394 Err(_) => continue, };
396
397 let current_modified = match metadata.modified() {
398 Ok(m) => m,
399 Err(_) => continue,
400 };
401
402 if current_modified > tracked.modified {
404 match Plugin::load(&tracked.path, &self.config) {
406 Ok(new_plugin) => {
407 let name = tracked.plugin.name().to_string();
408 tracked.plugin = new_plugin;
409 tracked.modified = current_modified;
410 reloaded = true;
411
412 if let Some(ref callback) = self.on_reload {
414 callback(&name);
415 }
416 }
417 Err(e) => {
418 eprintln!(
420 "warning: failed to reload plugin {}: {}",
421 tracked.path.display(),
422 e
423 );
424 }
425 }
426 }
427 }
428
429 Ok(reloaded)
430 }
431
432 pub fn reload_all(&mut self) -> Result<()> {
434 for tracked in &mut self.plugins {
435 let new_plugin = Plugin::load(&tracked.path, &self.config)?;
436 let metadata = std::fs::metadata(&tracked.path)?;
437 tracked.plugin = new_plugin;
438 tracked.modified = metadata.modified()?;
439 }
440 Ok(())
441 }
442
443 pub fn get(&self, name: &str) -> Option<&Plugin> {
445 self.name_index.get(name).map(|&i| &self.plugins[i].plugin)
446 }
447
448 pub fn execute(&self, index: usize, input: &PluginInput) -> Result<PluginOutput> {
450 let tracked = self
451 .plugins
452 .get(index)
453 .context("plugin index out of bounds")?;
454 tracked.plugin.execute(input, &self.config)
455 }
456
457 pub fn execute_by_name(&self, name: &str, input: &PluginInput) -> Result<PluginOutput> {
459 let index = self
460 .name_index
461 .get(name)
462 .with_context(|| format!("plugin '{name}' not found"))?;
463 self.execute(*index, input)
464 }
465
466 pub fn execute_all(&self, mut input: PluginInput) -> Result<PluginOutput> {
468 let mut all_errors = Vec::new();
469
470 for tracked in &self.plugins {
471 let output = tracked.plugin.execute(&input, &self.config)?;
472 all_errors.extend(output.errors);
473 input.directives = output.directives;
474 }
475
476 Ok(PluginOutput {
477 directives: input.directives,
478 errors: all_errors,
479 })
480 }
481
482 pub fn len(&self) -> usize {
484 self.plugins.len()
485 }
486
487 pub fn is_empty(&self) -> bool {
489 self.plugins.is_empty()
490 }
491
492 pub fn plugin_info(&self) -> Vec<(&Path, SystemTime)> {
494 self.plugins
495 .iter()
496 .map(|t| (t.path.as_path(), t.modified))
497 .collect()
498 }
499}
500
501impl Default for WatchingPluginManager {
502 fn default() -> Self {
503 Self::new()
504 }
505}
506
#[cfg(test)]
mod tests {
    use super::*;

    /// Assembles a WAT fixture into wasm bytes; a malformed fixture is a test bug.
    fn assemble(source: &str) -> Vec<u8> {
        wat::parse_str(source).expect("valid wat")
    }

    /// Validates the WAT `source` and returns the rejection message.
    ///
    /// Fails the test with `why` if the module is unexpectedly accepted.
    fn rejection(source: &str, why: &str) -> String {
        validate_plugin_module(&assemble(source))
            .expect_err(why)
            .to_string()
    }

    #[test]
    fn test_valid_plugin_validation() {
        let wasm = assemble(
            r#"
            (module
                (memory (export "memory") 1)
                (func (export "alloc") (param i32) (result i32)
                    i32.const 0
                )
                (func (export "process") (param i32 i32) (result i64)
                    i64.const 0
                )
            )
            "#,
        );

        if let Err(e) = validate_plugin_module(&wasm) {
            panic!("valid plugin should pass validation: {:?}", Some(e));
        }
    }

    #[test]
    fn test_wasi_import_rejected() {
        let err = rejection(
            r#"
            (module
                (import "wasi_snapshot_preview1" "fd_write"
                    (func $fd_write (param i32 i32 i32 i32) (result i32))
                )
                (memory (export "memory") 1)
                (func (export "alloc") (param i32) (result i32)
                    i32.const 0
                )
                (func (export "process") (param i32 i32) (result i64)
                    i64.const 0
                )
            )
            "#,
            "module with WASI import should be rejected",
        );

        assert!(
            err.contains("forbidden import"),
            "error should mention forbidden import: {err}"
        );
        assert!(
            err.contains("wasi_snapshot_preview1"),
            "error should mention WASI: {err}"
        );
    }

    #[test]
    fn test_env_import_rejected() {
        rejection(
            r#"
            (module
                (import "env" "some_func" (func $some_func))
                (memory (export "memory") 1)
                (func (export "alloc") (param i32) (result i32)
                    i32.const 0
                )
                (func (export "process") (param i32 i32) (result i64)
                    i64.const 0
                )
            )
            "#,
            "module with env import should be rejected",
        );
    }

    #[test]
    fn test_missing_exports_rejected() {
        let err = rejection(
            r#"
            (module
                (memory (export "memory") 1)
                (func (export "process") (param i32 i32) (result i64)
                    i64.const 0
                )
            )
            "#,
            "module missing alloc should be rejected",
        );
        assert!(err.contains("alloc"));
    }

    #[test]
    fn test_runtime_config_defaults() {
        let RuntimeConfig {
            max_memory,
            max_time_secs,
        } = RuntimeConfig::default();
        assert_eq!(max_memory, 256 * 1024 * 1024);
        assert_eq!(max_time_secs, 30);
    }

    #[test]
    fn test_missing_memory_rejected() {
        let err = rejection(
            r#"
            (module
                (func (export "alloc") (param i32) (result i32)
                    i32.const 0
                )
                (func (export "process") (param i32 i32) (result i64)
                    i64.const 0
                )
            )
            "#,
            "module missing memory should be rejected",
        );
        assert!(err.contains("memory"));
    }

    #[test]
    fn test_missing_process_rejected() {
        let err = rejection(
            r#"
            (module
                (memory (export "memory") 1)
                (func (export "alloc") (param i32) (result i32)
                    i32.const 0
                )
            )
            "#,
            "module missing process should be rejected",
        );
        assert!(err.contains("process"));
    }

    #[test]
    fn test_invalid_wasm_rejected() {
        let outcome = validate_plugin_module(b"not valid wasm bytes");
        assert!(outcome.is_err(), "invalid WASM should be rejected");
    }

    #[test]
    fn test_runtime_config_custom() {
        let custom = RuntimeConfig {
            max_memory: 512 * 1024 * 1024,
            max_time_secs: 60,
        };
        assert_eq!(custom.max_memory, 512 * 1024 * 1024);
        assert_eq!(custom.max_time_secs, 60);
    }
}