1use std::collections::HashMap;
26use std::path::{Path, PathBuf};
27use std::sync::Arc;
28use std::time::SystemTime;
29
30use anyhow::{Context, Result};
31use wasmtime::{Config, Engine, Linker, Module, Store};
32
33use crate::types::{DirectiveWrapper, PluginInput, PluginOp, PluginOutput};
34
35fn materialize_ops(input: &[DirectiveWrapper], output: &PluginOutput) -> Vec<DirectiveWrapper> {
44 let mut out = Vec::with_capacity(output.ops.len());
45 for op in &output.ops {
46 match op {
47 PluginOp::Keep(i) => {
48 if let Some(w) = input.get(*i) {
49 out.push(w.clone());
50 }
51 }
52 PluginOp::Modify(_, w) | PluginOp::Insert(w) => out.push(w.clone()),
53 PluginOp::Delete(_) => {}
54 }
55 }
56 out
57}
58
/// Resource limits applied when running plugins.
#[derive(Debug, Clone)]
pub struct RuntimeConfig {
    /// Intended cap on a plugin's linear memory, in bytes (default: 256 MiB).
    pub max_memory: usize,
    /// Execution budget. `Plugin::execute` turns this into
    /// `max_time_secs * 1_000_000` units of wasmtime fuel, so it bounds the
    /// amount of work performed rather than measuring wall-clock time
    /// (default: 30).
    pub max_time_secs: u64,
}

impl Default for RuntimeConfig {
    fn default() -> Self {
        const MIB: usize = 1 << 20;
        let max_memory = 256 * MIB;
        let max_time_secs = 30;
        Self {
            max_memory,
            max_time_secs,
        }
    }
}
76
77pub fn validate_plugin_module(bytes: &[u8]) -> Result<()> {
88 let engine = Engine::default();
89 let module = Module::new(&engine, bytes)?;
90
91 if let Some(import) = module.imports().next() {
93 anyhow::bail!(
94 "plugin has forbidden import: {}::{}",
95 import.module(),
96 import.name()
97 );
98 }
99
100 let exports: Vec<_> = module.exports().map(|e| e.name()).collect();
102
103 if !exports.contains(&"memory") {
104 anyhow::bail!("plugin must export 'memory'");
105 }
106 if !exports.contains(&"alloc") {
107 anyhow::bail!("plugin must export 'alloc' function");
108 }
109 if !exports.contains(&"process") {
110 anyhow::bail!("plugin must export 'process' function");
111 }
112
113 Ok(())
114}
115
/// A compiled plugin module; each call to `Plugin::execute` instantiates it in
/// a fresh store.
pub struct Plugin {
    /// Display name: the file stem (or `"unknown"`) in `Plugin::load`, or the
    /// caller-supplied name in `Plugin::load_bytes`.
    name: String,
    /// The compiled module, built against `engine`.
    module: Module,
    /// Engine the module was compiled with; created with fuel consumption enabled.
    engine: Arc<Engine>,
}
125
126impl Plugin {
127 pub fn load(path: &Path, _config: &RuntimeConfig) -> Result<Self> {
129 let name = path
130 .file_stem()
131 .and_then(|s| s.to_str())
132 .unwrap_or("unknown")
133 .to_string();
134
135 let mut engine_config = Config::new();
137 engine_config.consume_fuel(true); let engine = Arc::new(Engine::new(&engine_config)?);
140
141 let wasm_bytes =
143 std::fs::read(path).with_context(|| format!("failed to read {}", path.display()))?;
144
145 let module = Module::new(&engine, &wasm_bytes)
146 .map_err(anyhow::Error::from)
147 .with_context(|| format!("failed to compile {}", path.display()))?;
148
149 Ok(Self {
150 name,
151 module,
152 engine,
153 })
154 }
155
156 pub fn load_bytes(
158 name: impl Into<String>,
159 bytes: &[u8],
160 _config: &RuntimeConfig,
161 ) -> Result<Self> {
162 let name = name.into();
163
164 let mut engine_config = Config::new();
165 engine_config.consume_fuel(true);
166
167 let engine = Arc::new(Engine::new(&engine_config)?);
168 let module = Module::new(&engine, bytes)?;
169
170 Ok(Self {
171 name,
172 module,
173 engine,
174 })
175 }
176
177 pub fn name(&self) -> &str {
179 &self.name
180 }
181
182 pub fn execute(&self, input: &PluginInput, config: &RuntimeConfig) -> Result<PluginOutput> {
184 let mut store = Store::new(&self.engine, ());
186
187 let fuel = config.max_time_secs * 1_000_000;
189 store.set_fuel(fuel)?;
190
191 let linker = Linker::new(&self.engine);
194
195 let instance = linker.instantiate(&mut store, &self.module)?;
197
198 let input_bytes = rmp_serde::to_vec(input)?;
200
201 let memory = instance
203 .get_memory(&mut store, "memory")
204 .ok_or_else(|| anyhow::anyhow!("plugin must export 'memory'"))?;
205
206 let alloc = instance
208 .get_typed_func::<u32, u32>(&mut store, "alloc")
209 .map_err(anyhow::Error::from)
210 .context("plugin must export 'alloc' function")?;
211
212 let input_ptr = alloc.call(&mut store, input_bytes.len() as u32)?;
214
215 memory.write(&mut store, input_ptr as usize, &input_bytes)?;
217
218 let process = instance
220 .get_typed_func::<(u32, u32), u64>(&mut store, "process")
221 .map_err(anyhow::Error::from)
222 .context("plugin must export 'process' function")?;
223
224 let result = process.call(&mut store, (input_ptr, input_bytes.len() as u32))?;
225
226 let output_ptr = (result >> 32) as u32;
228 let output_len = (result & 0xFFFF_FFFF) as u32;
229
230 let mut output_bytes = vec![0u8; output_len as usize];
232 memory.read(&store, output_ptr as usize, &mut output_bytes)?;
233
234 let output: PluginOutput = rmp_serde::from_slice(&output_bytes)?;
236
237 Ok(output)
238 }
239}
240
/// Owns a list of plugins that share one [`RuntimeConfig`]. Plugins can be run
/// one at a time by index, or chained in load order via `execute_all`.
pub struct PluginManager {
    /// Limits passed to every plugin load and execution.
    config: RuntimeConfig,
    /// Plugins in load order; the index returned by `load`/`load_bytes` is the
    /// plugin's position in this list.
    plugins: Vec<Plugin>,
}
248
249impl PluginManager {
250 pub fn new() -> Self {
252 Self::with_config(RuntimeConfig::default())
253 }
254
255 pub const fn with_config(config: RuntimeConfig) -> Self {
257 Self {
258 config,
259 plugins: Vec::new(),
260 }
261 }
262
263 pub fn load(&mut self, path: &Path) -> Result<usize> {
265 let plugin = Plugin::load(path, &self.config)?;
266 let index = self.plugins.len();
267 self.plugins.push(plugin);
268 Ok(index)
269 }
270
271 pub fn load_bytes(&mut self, name: impl Into<String>, bytes: &[u8]) -> Result<usize> {
273 let plugin = Plugin::load_bytes(name, bytes, &self.config)?;
274 let index = self.plugins.len();
275 self.plugins.push(plugin);
276 Ok(index)
277 }
278
279 pub fn execute(&self, index: usize, input: &PluginInput) -> Result<PluginOutput> {
281 let plugin = self
282 .plugins
283 .get(index)
284 .context("plugin index out of bounds")?;
285 plugin.execute(input, &self.config)
286 }
287
288 pub fn execute_all(&self, mut input: PluginInput) -> Result<PluginOutput> {
301 let mut all_errors = Vec::new();
302 let n_original = input.directives.len();
303
304 for plugin in &self.plugins {
305 let output = plugin.execute(&input, &self.config)?;
306 input.directives = materialize_ops(&input.directives, &output);
308 all_errors.extend(output.errors);
309 }
310
311 let mut ops: Vec<PluginOp> = (0..n_original).map(PluginOp::Delete).collect();
315 for w in input.directives {
316 ops.push(PluginOp::Insert(w));
317 }
318
319 Ok(PluginOutput {
320 ops,
321 errors: all_errors,
322 })
323 }
324
325 pub const fn len(&self) -> usize {
327 self.plugins.len()
328 }
329
330 pub const fn is_empty(&self) -> bool {
332 self.plugins.is_empty()
333 }
334}
335
336impl Default for PluginManager {
337 fn default() -> Self {
338 Self::new()
339 }
340}
341
/// A plugin loaded from disk, together with the data needed to notice when its
/// file changes.
struct TrackedPlugin {
    /// The currently active compiled plugin.
    plugin: Plugin,
    /// Path the plugin was loaded from (canonicalized by `load` when possible).
    path: PathBuf,
    /// File modification time recorded at the last successful (re)load.
    modified: SystemTime,
}
351
/// A plugin manager that remembers each plugin's source file so that it can
/// reload plugins whose files change on disk.
///
/// Nothing happens automatically: changes are only picked up when
/// `check_and_reload` or `reload_all` is called.
pub struct WatchingPluginManager {
    /// Limits passed to every plugin load and execution.
    config: RuntimeConfig,
    /// Tracked plugins in load order; indices returned by `load` refer to this list.
    plugins: Vec<TrackedPlugin>,
    /// Maps a plugin's name to its index in `plugins`. A later load with the
    /// same name overwrites the earlier entry.
    name_index: HashMap<String, usize>,
    /// Optional hook invoked with a plugin's name after `check_and_reload`
    /// successfully reloads it.
    on_reload: Option<Box<dyn Fn(&str) + Send + Sync>>,
}
381
382impl WatchingPluginManager {
383 pub fn new() -> Self {
385 Self::with_config(RuntimeConfig::default())
386 }
387
388 pub fn with_config(config: RuntimeConfig) -> Self {
390 Self {
391 config,
392 plugins: Vec::new(),
393 name_index: HashMap::new(),
394 on_reload: None,
395 }
396 }
397
398 pub fn on_reload<F>(&mut self, callback: F)
400 where
401 F: Fn(&str) + Send + Sync + 'static,
402 {
403 self.on_reload = Some(Box::new(callback));
404 }
405
406 pub fn load(&mut self, path: impl AsRef<Path>) -> Result<usize> {
408 let path = path.as_ref();
409 let abs_path = path.canonicalize().unwrap_or_else(|_| path.to_path_buf());
411
412 let metadata = std::fs::metadata(&abs_path)
414 .with_context(|| format!("failed to stat {}", abs_path.display()))?;
415 let modified = metadata.modified()?;
416
417 let plugin = Plugin::load(&abs_path, &self.config)?;
419 let name = plugin.name().to_string();
420 let index = self.plugins.len();
421
422 self.plugins.push(TrackedPlugin {
424 plugin,
425 path: abs_path,
426 modified,
427 });
428 self.name_index.insert(name, index);
429
430 Ok(index)
431 }
432
433 pub fn check_and_reload(&mut self) -> Result<bool> {
437 let mut reloaded = false;
438
439 for tracked in &mut self.plugins {
440 let metadata = match std::fs::metadata(&tracked.path) {
442 Ok(m) => m,
443 Err(_) => continue, };
445
446 let current_modified = match metadata.modified() {
447 Ok(m) => m,
448 Err(_) => continue,
449 };
450
451 if current_modified > tracked.modified {
453 match Plugin::load(&tracked.path, &self.config) {
455 Ok(new_plugin) => {
456 let name = tracked.plugin.name().to_string();
457 tracked.plugin = new_plugin;
458 tracked.modified = current_modified;
459 reloaded = true;
460
461 if let Some(ref callback) = self.on_reload {
463 callback(&name);
464 }
465 }
466 Err(e) => {
467 eprintln!(
469 "warning: failed to reload plugin {}: {}",
470 tracked.path.display(),
471 e
472 );
473 }
474 }
475 }
476 }
477
478 Ok(reloaded)
479 }
480
481 pub fn reload_all(&mut self) -> Result<()> {
483 for tracked in &mut self.plugins {
484 let new_plugin = Plugin::load(&tracked.path, &self.config)?;
485 let metadata = std::fs::metadata(&tracked.path)?;
486 tracked.plugin = new_plugin;
487 tracked.modified = metadata.modified()?;
488 }
489 Ok(())
490 }
491
492 pub fn get(&self, name: &str) -> Option<&Plugin> {
494 self.name_index.get(name).map(|&i| &self.plugins[i].plugin)
495 }
496
497 pub fn execute(&self, index: usize, input: &PluginInput) -> Result<PluginOutput> {
499 let tracked = self
500 .plugins
501 .get(index)
502 .context("plugin index out of bounds")?;
503 tracked.plugin.execute(input, &self.config)
504 }
505
506 pub fn execute_by_name(&self, name: &str, input: &PluginInput) -> Result<PluginOutput> {
508 let index = self
509 .name_index
510 .get(name)
511 .with_context(|| format!("plugin '{name}' not found"))?;
512 self.execute(*index, input)
513 }
514
515 pub fn execute_all(&self, mut input: PluginInput) -> Result<PluginOutput> {
520 let mut all_errors = Vec::new();
521 let n_original = input.directives.len();
522
523 for tracked in &self.plugins {
524 let output = tracked.plugin.execute(&input, &self.config)?;
525 input.directives = materialize_ops(&input.directives, &output);
526 all_errors.extend(output.errors);
527 }
528
529 let mut ops: Vec<PluginOp> = (0..n_original).map(PluginOp::Delete).collect();
530 for w in input.directives {
531 ops.push(PluginOp::Insert(w));
532 }
533
534 Ok(PluginOutput {
535 ops,
536 errors: all_errors,
537 })
538 }
539
540 pub const fn len(&self) -> usize {
542 self.plugins.len()
543 }
544
545 pub const fn is_empty(&self) -> bool {
547 self.plugins.is_empty()
548 }
549
550 pub fn plugin_info(&self) -> Vec<(&Path, SystemTime)> {
552 self.plugins
553 .iter()
554 .map(|t| (t.path.as_path(), t.modified))
555 .collect()
556 }
557}
558
559impl Default for WatchingPluginManager {
560 fn default() -> Self {
561 Self::new()
562 }
563}
564
#[cfg(test)]
mod tests {
    use super::*;

    /// Assembles a module implementing the full plugin ABI (`memory`, `alloc`,
    /// `process`). `prelude` is placed first, so it may contain imports.
    fn abi_module(prelude: &str) -> Vec<u8> {
        let source = format!(
            r#"
            (module
                {prelude}
                (memory (export "memory") 1)
                (func (export "alloc") (param i32) (result i32)
                    i32.const 0
                )
                (func (export "process") (param i32 i32) (result i64)
                    i64.const 0
                )
            )
            "#
        );
        assemble(&source)
    }

    /// The smallest module that satisfies the plugin ABI.
    fn minimal_plugin() -> Vec<u8> {
        abi_module("")
    }

    /// Parses WAT text, panicking on syntax errors.
    fn assemble(source: &str) -> Vec<u8> {
        wat::parse_str(source).expect("valid wat")
    }

    /// A plugin input with no directives, default options and no config.
    fn empty_input() -> crate::types::PluginInput {
        crate::types::PluginInput {
            directives: Vec::new(),
            options: crate::types::PluginOptions::default(),
            config: None,
        }
    }

    #[test]
    fn test_valid_plugin_validation() {
        let outcome = validate_plugin_module(&minimal_plugin());
        assert!(
            outcome.is_ok(),
            "valid plugin should pass validation: {:?}",
            outcome.err()
        );
    }

    #[test]
    fn test_wasi_import_rejected() {
        let wasm = abi_module(
            r#"(import "wasi_snapshot_preview1" "fd_write"
                (func $fd_write (param i32 i32 i32 i32) (result i32)))"#,
        );
        let message = validate_plugin_module(&wasm)
            .expect_err("module with WASI import should be rejected")
            .to_string();
        assert!(
            message.contains("forbidden import"),
            "error should mention forbidden import: {message}"
        );
        assert!(
            message.contains("wasi_snapshot_preview1"),
            "error should mention WASI: {message}"
        );
    }

    #[test]
    fn test_env_import_rejected() {
        let wasm = abi_module(r#"(import "env" "some_func" (func $some_func))"#);
        assert!(
            validate_plugin_module(&wasm).is_err(),
            "module with env import should be rejected"
        );
    }

    #[test]
    fn test_missing_exports_rejected() {
        let wasm = assemble(
            r#"
            (module
                (memory (export "memory") 1)
                (func (export "process") (param i32 i32) (result i64)
                    i64.const 0
                )
            )
            "#,
        );
        let message = validate_plugin_module(&wasm)
            .expect_err("module missing alloc should be rejected")
            .to_string();
        assert!(message.contains("alloc"));
    }

    #[test]
    fn test_runtime_config_defaults() {
        let defaults = RuntimeConfig::default();
        assert_eq!(defaults.max_memory, 256 * 1024 * 1024);
        assert_eq!(defaults.max_time_secs, 30);
    }

    #[test]
    fn test_missing_memory_rejected() {
        let wasm = assemble(
            r#"
            (module
                (func (export "alloc") (param i32) (result i32)
                    i32.const 0
                )
                (func (export "process") (param i32 i32) (result i64)
                    i64.const 0
                )
            )
            "#,
        );
        let message = validate_plugin_module(&wasm)
            .expect_err("module missing memory should be rejected")
            .to_string();
        assert!(message.contains("memory"));
    }

    #[test]
    fn test_missing_process_rejected() {
        let wasm = assemble(
            r#"
            (module
                (memory (export "memory") 1)
                (func (export "alloc") (param i32) (result i32)
                    i32.const 0
                )
            )
            "#,
        );
        let message = validate_plugin_module(&wasm)
            .expect_err("module missing process should be rejected")
            .to_string();
        assert!(message.contains("process"));
    }

    #[test]
    fn test_invalid_wasm_rejected() {
        assert!(
            validate_plugin_module(b"not valid wasm bytes").is_err(),
            "invalid WASM should be rejected"
        );
    }

    #[test]
    fn test_runtime_config_custom() {
        let custom = RuntimeConfig {
            max_memory: 512 * 1024 * 1024,
            max_time_secs: 60,
        };
        assert_eq!(custom.max_memory, 512 * 1024 * 1024);
        assert_eq!(custom.max_time_secs, 60);
    }

    #[test]
    fn test_plugin_manager_new() {
        let manager = PluginManager::new();
        assert!(manager.is_empty());
        assert_eq!(manager.len(), 0);
    }

    #[test]
    fn test_plugin_manager_with_config() {
        let manager = PluginManager::with_config(RuntimeConfig {
            max_memory: 128 * 1024 * 1024,
            max_time_secs: 10,
        });
        assert!(manager.is_empty());
    }

    #[test]
    fn test_plugin_manager_default() {
        let manager = PluginManager::default();
        assert!(manager.is_empty());
        assert_eq!(manager.len(), 0);
    }

    #[test]
    fn test_watching_plugin_manager_new() {
        let manager = WatchingPluginManager::new();
        assert!(manager.is_empty());
        assert_eq!(manager.len(), 0);
        assert!(manager.plugin_info().is_empty());
    }

    #[test]
    fn test_watching_plugin_manager_with_config() {
        let manager = WatchingPluginManager::with_config(RuntimeConfig {
            max_memory: 64 * 1024 * 1024,
            max_time_secs: 5,
        });
        assert!(manager.is_empty());
    }

    #[test]
    fn test_watching_plugin_manager_default() {
        let manager = WatchingPluginManager::default();
        assert!(manager.is_empty());
        assert_eq!(manager.len(), 0);
    }

    #[test]
    fn test_watching_plugin_manager_get_unknown() {
        assert!(WatchingPluginManager::new().get("nonexistent").is_none());
    }

    #[test]
    fn test_plugin_manager_execute_out_of_bounds() {
        let manager = PluginManager::new();
        let err = manager.execute(0, &empty_input()).unwrap_err();
        assert!(err.to_string().contains("out of bounds"));
    }

    #[test]
    fn test_watching_plugin_manager_execute_out_of_bounds() {
        let manager = WatchingPluginManager::new();
        let err = manager.execute(0, &empty_input()).unwrap_err();
        assert!(err.to_string().contains("out of bounds"));
    }

    #[test]
    fn test_watching_plugin_manager_execute_by_name_unknown() {
        let manager = WatchingPluginManager::new();
        let err = manager.execute_by_name("unknown", &empty_input()).unwrap_err();
        assert!(err.to_string().contains("not found"));
    }

    #[test]
    fn test_plugin_manager_execute_all_empty() {
        let output = PluginManager::new()
            .execute_all(empty_input())
            .expect("running zero plugins should succeed");
        assert!(output.ops.is_empty());
        assert!(output.errors.is_empty());
    }

    #[test]
    fn test_watching_plugin_manager_execute_all_empty() {
        let output = WatchingPluginManager::new()
            .execute_all(empty_input())
            .expect("running zero plugins should succeed");
        assert!(output.ops.is_empty());
        assert!(output.errors.is_empty());
    }

    #[test]
    fn test_watching_plugin_manager_check_reload_empty() {
        let mut manager = WatchingPluginManager::new();
        let reloaded = manager
            .check_and_reload()
            .expect("checking zero plugins should succeed");
        assert!(!reloaded);
    }

    #[test]
    fn test_watching_plugin_manager_reload_all_empty() {
        let mut manager = WatchingPluginManager::new();
        assert!(manager.reload_all().is_ok());
    }

    #[test]
    fn test_plugin_load_bytes() {
        let plugin = Plugin::load_bytes("test_plugin", &minimal_plugin(), &RuntimeConfig::default())
            .expect("minimal plugin should load");
        assert_eq!(plugin.name(), "test_plugin");
    }

    #[test]
    fn test_plugin_manager_load_bytes() {
        let mut manager = PluginManager::new();
        let index = manager
            .load_bytes("my_plugin", &minimal_plugin())
            .expect("minimal plugin should load");
        assert_eq!(index, 0);
        assert_eq!(manager.len(), 1);
        assert!(!manager.is_empty());
    }

    #[test]
    fn test_plugin_manager_multiple_plugins() {
        let wasm = minimal_plugin();
        let mut manager = PluginManager::new();
        for name in ["plugin1", "plugin2", "plugin3"] {
            manager.load_bytes(name, &wasm).unwrap();
        }
        assert_eq!(manager.len(), 3);
    }

    #[test]
    fn test_validate_truncated_wasm() {
        // Magic number only; the version field is missing.
        let truncated: &[u8] = &[0x00, 0x61, 0x73, 0x6d];
        assert!(validate_plugin_module(truncated).is_err());
    }

    #[test]
    fn test_validate_wrong_magic() {
        let wrong_magic: &[u8] = &[0xFF, 0xFF, 0xFF, 0xFF];
        assert!(validate_plugin_module(wrong_magic).is_err());
    }

    #[test]
    fn test_validate_empty_wasm() {
        let empty: &[u8] = &[];
        assert!(validate_plugin_module(empty).is_err());
    }
}