//! Driver plugin abstractions and the plugin registry (`suture_driver/plugin.rs`).

1use crate::SutureDriver;
2use std::collections::HashMap;
3use std::path::Path;
4use std::sync::Arc;
5
6#[derive(Debug, thiserror::Error)]
7pub enum PluginError {
8    #[error("failed to load plugin: {0}")]
9    LoadFailed(String),
10    #[error("missing required export: {0}")]
11    MissingExport(String),
12    #[error("plugin ABI version mismatch: expected {expected}, got {actual}")]
13    AbiVersionMismatch { expected: i32, actual: i32 },
14    #[cfg(feature = "wasm-plugins")]
15    #[error("wasmtime error: {0}")]
16    Wasmtime(#[from] wasmtime::Error),
17}
18
19pub trait DriverPlugin: Send + Sync {
20    fn name(&self) -> &str;
21    fn extensions(&self) -> &[&str];
22    fn description(&self) -> &str;
23    fn as_driver(&self) -> &dyn SutureDriver;
24}
25
/// A plugin that wraps a driver value `D` together with static metadata.
pub struct BuiltinDriverPlugin<D> {
    /// Name returned by `DriverPlugin::name`.
    name: &'static str,
    /// Extensions returned by `DriverPlugin::extensions`.
    extensions: Vec<&'static str>,
    /// Text returned by `DriverPlugin::description`.
    description: &'static str,
    /// The wrapped driver, handed out by `DriverPlugin::as_driver`.
    driver: D,
}
32
33impl<D: SutureDriver + Send + Sync + 'static> BuiltinDriverPlugin<D> {
34    pub fn new(
35        name: &'static str,
36        extensions: Vec<&'static str>,
37        description: &'static str,
38        driver: D,
39    ) -> Self {
40        Self {
41            name,
42            extensions,
43            description,
44            driver,
45        }
46    }
47}
48
49impl<D: SutureDriver + Send + Sync + 'static> DriverPlugin for BuiltinDriverPlugin<D> {
50    fn name(&self) -> &str {
51        self.name
52    }
53    fn extensions(&self) -> &[&str] {
54        &self.extensions
55    }
56    fn description(&self) -> &str {
57        self.description
58    }
59    fn as_driver(&self) -> &dyn SutureDriver {
60        &self.driver
61    }
62}
63
64pub struct PluginRegistry {
65    plugins: HashMap<String, Arc<dyn DriverPlugin>>,
66    extension_map: HashMap<String, String>,
67}
68
69impl PluginRegistry {
70    pub fn new() -> Self {
71        Self {
72            plugins: HashMap::new(),
73            extension_map: HashMap::new(),
74        }
75    }
76
77    pub fn register(&mut self, plugin: Arc<dyn DriverPlugin>) {
78        let name = plugin.name().to_string();
79        for ext in plugin.extensions() {
80            self.extension_map.insert(ext.to_string(), name.clone());
81        }
82        self.plugins.insert(name, plugin);
83    }
84
85    pub fn get(&self, name: &str) -> Option<&dyn DriverPlugin> {
86        self.plugins.get(name).map(|p| p.as_ref())
87    }
88
89    pub fn get_by_extension(&self, ext: &str) -> Option<&dyn DriverPlugin> {
90        let normalized = if ext.starts_with('.') {
91            ext.to_string()
92        } else {
93            format!(".{}", ext)
94        };
95        self.extension_map
96            .get(&normalized)
97            .and_then(|name| self.plugins.get(name).map(|p| p.as_ref()))
98    }
99
100    pub fn list_drivers(&self) -> Vec<&str> {
101        let mut names: Vec<&str> = self.plugins.keys().map(|s| s.as_str()).collect();
102        names.sort();
103        names
104    }
105
106    pub fn discover_plugins(&mut self, plugin_dir: &Path) {
107        if !plugin_dir.exists() {
108            return;
109        }
110
111        if let Ok(entries) = std::fs::read_dir(plugin_dir) {
112            let mut sorted_entries: Vec<_> = entries.flatten().collect();
113            sorted_entries.sort_by_key(|e| e.file_name());
114            for entry in sorted_entries {
115                let path = entry.path();
116                if path
117                    .extension()
118                    .map(|e| e == "suture-plugin")
119                    .unwrap_or(false)
120                    && let Ok(content) = std::fs::read_to_string(&path)
121                    && let Some(desc) = Self::parse_plugin_descriptor(&content)
122                {
123                    let _ = desc; // plugin descriptor found; future: dynamic loading
124                }
125            }
126        }
127    }
128
129    fn parse_plugin_descriptor(content: &str) -> Option<PluginDescriptor> {
130        let mut name = None;
131        let mut extensions = Vec::new();
132        let mut description = String::new();
133
134        for line in content.lines() {
135            let line = line.trim();
136            if let Some(val) = line
137                .strip_prefix("name")
138                .and_then(Self::extract_string_value)
139            {
140                name = Some(val);
141            } else if let Some(start) = line.find('[') {
142                if let Some(end) = line[start..].find(']') {
143                    let inner = &line[start + 1..start + end];
144                    for ext in inner.split(',') {
145                        let ext = ext.trim().trim_matches('"');
146                        if !ext.is_empty() {
147                            extensions.push(ext.to_string());
148                        }
149                    }
150                }
151            } else if let Some(val) = line
152                .strip_prefix("description")
153                .and_then(Self::extract_string_value)
154            {
155                description = val;
156            }
157        }
158
159        name.map(|name| PluginDescriptor {
160            name,
161            extensions,
162            description,
163        })
164    }
165
166    fn extract_string_value(line: &str) -> Option<String> {
167        if let Some(eq_pos) = line.find('=') {
168            let val = line[eq_pos + 1..].trim();
169            if val.starts_with('"') && val.ends_with('"') {
170                return Some(val[1..val.len() - 1].to_string());
171            }
172        }
173        None
174    }
175}
176
/// Metadata parsed from a `.suture-plugin` descriptor file.
// Descriptors are parsed but not yet consumed (dynamic loading is future work), so the
// fields are only read by tests; the struct-level allow covers every field.
#[allow(dead_code)]
#[derive(Debug, Clone, PartialEq, Eq)]
struct PluginDescriptor {
    /// Value of the descriptor's `name` line.
    name: String,
    /// Entries of the descriptor's bracketed extension list.
    extensions: Vec<String>,
    /// Value of the descriptor's `description` line; empty when absent.
    description: String,
}
184
185impl Default for PluginRegistry {
186    fn default() -> Self {
187        Self::new()
188    }
189}
190
191impl PluginRegistry {
192    #[cfg(feature = "wasm-plugins")]
193    pub fn load_wasm_plugin(&mut self, path: &Path) -> Result<(), PluginError> {
194        let plugin = WasmDriverPlugin::from_file(path)?;
195        let name = plugin.name.clone();
196        let extensions: Vec<String> = plugin.extensions_storage.clone();
197        let plugin_arc = Arc::new(plugin);
198        for ext in &extensions {
199            self.extension_map.insert(ext.to_string(), name.clone());
200        }
201        self.plugins.insert(name, plugin_arc);
202        Ok(())
203    }
204}
205
/// A driver plugin backed by an instantiated WASM module.
#[cfg(feature = "wasm-plugins")]
pub struct WasmDriverPlugin {
    // The engine, instance and store are stored at load time but not read by any code
    // in this file, hence the dead-code allowances.
    #[allow(dead_code)]
    engine: wasmtime::Engine,
    #[allow(dead_code)]
    instance: wasmtime::Instance,
    #[allow(dead_code)]
    store: wasmtime::Store<()>,
    /// Name read from the module's `plugin_name` export, or `"unknown"` if unavailable.
    name: String,
    /// Owned extension strings parsed from the module's `plugin_extensions` export.
    extensions_storage: Vec<String>,
    /// `'static` copies of `extensions_storage`, returned from the `&[&str]` accessors.
    extensions: Vec<&'static str>,
}
218
#[cfg(feature = "wasm-plugins")]
impl WasmDriverPlugin {
    /// ABI version this host accepts from a module's `plugin_version` export.
    const SUPPORTED_ABI_VERSION: i32 = 1;

    /// Loads, instantiates and validates the WASM plugin module at `path`.
    ///
    /// The module's `plugin_version` export must return the supported ABI version.
    /// The plugin name comes from `plugin_name` (falling back to `"unknown"`) and the
    /// extensions from `plugin_extensions` (falling back to none).
    ///
    /// # Errors
    ///
    /// * [`PluginError::LoadFailed`] if the module cannot be read, compiled or
    ///   instantiated.
    /// * [`PluginError::MissingExport`] if `plugin_version` is not exported with the
    ///   expected signature.
    /// * [`PluginError::AbiVersionMismatch`] if the reported version is unsupported.
    /// * [`PluginError::Wasmtime`] if calling `plugin_version` fails.
    pub fn from_file(path: &std::path::Path) -> Result<Self, PluginError> {
        let engine = wasmtime::Engine::default();
        let module = wasmtime::Module::from_file(&engine, path)
            .map_err(|e| PluginError::LoadFailed(e.to_string()))?;

        let mut store = wasmtime::Store::new(&engine, ());
        // TODO: Add `suture.alloc` host import for WASM memory allocation.
        // WASM modules that call `suture.alloc` will fail at link time until this is implemented.
        let linker = wasmtime::Linker::new(&engine);

        let instance = linker
            .instantiate(&mut store, &module)
            .map_err(|e| PluginError::LoadFailed(e.to_string()))?;

        let version = Self::call_version_export(&mut store, &instance)?;
        if version != Self::SUPPORTED_ABI_VERSION {
            return Err(PluginError::AbiVersionMismatch {
                expected: Self::SUPPORTED_ABI_VERSION,
                actual: version,
            });
        }

        let name = Self::call_string_export(&mut store, &instance, "plugin_name")
            .unwrap_or_else(|| "unknown".to_string());
        let extensions_storage = Self::call_extensions_export(&mut store, &instance);

        // The trait accessors return `&[&str]`, which cannot borrow from the sibling
        // `extensions_storage` field, so each extension is copied into a leaked
        // allocation. NOTE(review): this memory is never reclaimed, so every loaded
        // plugin leaks its extension strings for the life of the process.
        let extensions: Vec<&'static str> = extensions_storage
            .iter()
            .map(|s| {
                let leaked: &'static str = Box::leak(s.clone().into_boxed_str());
                leaked
            })
            .collect();

        Ok(Self {
            engine,
            instance,
            store,
            name,
            extensions_storage,
            extensions,
        })
    }

    /// Calls the module's `plugin_version` export and returns the reported ABI version.
    fn call_version_export(
        store: &mut wasmtime::Store<()>,
        instance: &wasmtime::Instance,
    ) -> Result<i32, PluginError> {
        let func = instance
            .get_typed_func::<(), i32>(&mut *store, "plugin_version")
            .map_err(|_| PluginError::MissingExport("plugin_version".to_string()))?;
        Ok(func.call(&mut *store, ())?)
    }

    /// Calls the zero-argument export `export_name`, treats its `i32` result as the
    /// address of a NUL-terminated string in the module's exported `memory`, and
    /// returns that string.
    ///
    /// Returns `None` if the export or the memory is missing, the call fails, the
    /// address is out of bounds, no NUL terminator is found before the end of memory,
    /// or the bytes are not valid UTF-8.
    fn call_string_export(
        store: &mut wasmtime::Store<()>,
        instance: &wasmtime::Instance,
        export_name: &str,
    ) -> Option<String> {
        let func = instance
            .get_typed_func::<(), i32>(&mut *store, export_name)
            .ok()?;
        let ptr = func.call(&mut *store, ()).ok()?;
        let memory = instance.get_memory(&mut *store, "memory")?;

        // wasm32 addresses are unsigned 32-bit values carried in an `i32`; reinterpret
        // the bits rather than sign-extending, which would turn every address at or
        // above 2 GiB into an out-of-range offset.
        let start = usize::try_from(ptr as u32).ok()?;

        // Borrow the linear memory once and scan it as a slice.
        let data = memory.data(&mut *store);
        let tail = data.get(start..)?;
        let len = tail.iter().position(|&byte| byte == 0)?;
        String::from_utf8(tail[..len].to_vec()).ok()
    }

    /// Reads the comma-separated `plugin_extensions` export into trimmed, non-empty
    /// entries. Returns an empty list when the export is unavailable.
    fn call_extensions_export(
        store: &mut wasmtime::Store<()>,
        instance: &wasmtime::Instance,
    ) -> Vec<String> {
        Self::call_string_export(store, instance, "plugin_extensions")
            .map(|csv| {
                csv.split(',')
                    .map(str::trim)
                    .filter(|item| !item.is_empty())
                    .map(str::to_string)
                    .collect()
            })
            .unwrap_or_default()
    }
}
322
#[cfg(feature = "wasm-plugins")]
impl DriverPlugin for WasmDriverPlugin {
    /// Returns the plugin name stored when the module was loaded.
    fn name(&self) -> &str {
        self.name.as_str()
    }

    /// Returns the extensions stored when the module was loaded.
    fn extensions(&self) -> &[&str] {
        self.extensions.as_slice()
    }

    /// Returns a fixed description shared by all WASM plugins.
    fn description(&self) -> &str {
        "WASM plugin driver"
    }

    /// The plugin is its own driver.
    fn as_driver(&self) -> &dyn SutureDriver {
        self
    }
}
341
#[cfg(feature = "wasm-plugins")]
impl SutureDriver for WasmDriverPlugin {
    /// Returns the plugin name stored when the module was loaded.
    fn name(&self) -> &str {
        self.name.as_str()
    }

    /// Returns the extensions stored when the module was loaded.
    fn supported_extensions(&self) -> &[&str] {
        self.extensions.as_slice()
    }

    /// Not implemented yet: always fails with a `ParseError`, ignoring both inputs.
    fn diff(
        &self,
        _base_content: Option<&str>,
        _new_content: &str,
    ) -> Result<Vec<crate::SemanticChange>, crate::DriverError> {
        let message = "WASM plugin diff is not yet implemented".to_string();
        Err(crate::DriverError::ParseError(message))
    }

    /// Not implemented yet: always fails with a `ParseError`, ignoring both inputs.
    fn format_diff(
        &self,
        _base_content: Option<&str>,
        _new_content: &str,
    ) -> Result<String, crate::DriverError> {
        let message = "WASM plugin format_diff is not yet implemented".to_string();
        Err(crate::DriverError::ParseError(message))
    }
}
372
#[cfg(test)]
mod tests {
    use super::*;
    use crate::DriverError;

    /// Minimal driver that reports a fixed name and extension list and produces
    /// empty diffs.
    struct MockDriver {
        driver_name: &'static str,
        driver_extensions: Vec<&'static str>,
    }

    impl MockDriver {
        fn new(name: &'static str, extensions: Vec<&'static str>) -> Self {
            Self {
                driver_name: name,
                driver_extensions: extensions,
            }
        }
    }

    impl SutureDriver for MockDriver {
        fn name(&self) -> &str {
            self.driver_name
        }
        fn supported_extensions(&self) -> &[&str] {
            &self.driver_extensions
        }
        fn diff(
            &self,
            _base_content: Option<&str>,
            _new_content: &str,
        ) -> Result<Vec<crate::SemanticChange>, DriverError> {
            Ok(vec![])
        }
        fn format_diff(
            &self,
            _base_content: Option<&str>,
            _new_content: &str,
        ) -> Result<String, DriverError> {
            Ok(String::new())
        }
    }

    /// Builds a built-in plugin around a [`MockDriver`] with matching metadata.
    fn make_plugin(name: &'static str, extensions: Vec<&'static str>) -> Arc<dyn DriverPlugin> {
        Arc::new(BuiltinDriverPlugin::new(
            name,
            extensions.clone(),
            "test driver",
            MockDriver::new(name, extensions),
        ))
    }

    #[test]
    fn register_and_get_by_name() {
        let mut reg = PluginRegistry::new();
        reg.register(make_plugin("json", vec![".json"]));
        assert!(reg.get("json").is_some());
        assert!(reg.get("yaml").is_none());
        assert_eq!(reg.get("json").unwrap().name(), "json");
    }

    #[test]
    fn get_by_extension_with_dot() {
        let mut reg = PluginRegistry::new();
        reg.register(make_plugin("json", vec![".json"]));
        assert!(reg.get_by_extension(".json").is_some());
        assert!(reg.get_by_extension(".yaml").is_none());
    }

    #[test]
    fn get_by_extension_without_dot() {
        let mut reg = PluginRegistry::new();
        reg.register(make_plugin("yaml", vec![".yaml", ".yml"]));
        assert!(reg.get_by_extension("yaml").is_some());
        assert!(reg.get_by_extension("yml").is_some());
    }

    #[test]
    fn list_drivers_sorted() {
        let mut reg = PluginRegistry::new();
        reg.register(make_plugin("csv", vec![".csv"]));
        reg.register(make_plugin("xml", vec![".xml"]));
        reg.register(make_plugin("json", vec![".json"]));
        assert_eq!(reg.list_drivers(), vec!["csv", "json", "xml"]);
    }

    #[test]
    fn discover_plugins_nonexistent_dir() {
        let mut reg = PluginRegistry::new();
        reg.discover_plugins(Path::new("/tmp/suture-test-nonexistent-12345"));
        assert!(reg.list_drivers().is_empty());
    }

    #[test]
    fn parse_plugin_descriptor_valid() {
        let content = r#"
name = "my-driver"
extensions = [".custom", ".ext"]
description = "A custom driver"
"#;
        let desc = PluginRegistry::parse_plugin_descriptor(content).unwrap();
        assert_eq!(desc.name, "my-driver");
        assert_eq!(desc.extensions, vec![".custom", ".ext"]);
        assert_eq!(desc.description, "A custom driver");
    }

    #[test]
    fn parse_plugin_descriptor_missing_name() {
        let content = r#"extensions = [".custom"]"#;
        assert!(PluginRegistry::parse_plugin_descriptor(content).is_none());
    }

    #[test]
    fn as_driver_returns_underlying_driver() {
        let mut reg = PluginRegistry::new();
        reg.register(make_plugin("json", vec![".json"]));
        let plugin = reg.get("json").unwrap();
        assert_eq!(plugin.as_driver().name(), "json");
        assert_eq!(plugin.as_driver().supported_extensions(), &[".json"]);
    }

    #[cfg(feature = "wasm-plugins")]
    #[test]
    fn test_wasm_plugin_abi_documentation() {
        let abi_path = Path::new(env!("CARGO_MANIFEST_DIR")).join("src/wasm_abi.md");
        let content = std::fs::read_to_string(&abi_path).expect("wasm_abi.md should exist");
        assert!(
            content.contains("plugin_name"),
            "ABI doc should define plugin_name export"
        );
        assert!(
            content.contains("plugin_extensions"),
            "ABI doc should define plugin_extensions export"
        );
        assert!(
            content.contains("plugin_version"),
            "ABI doc should define plugin_version export"
        );
        assert!(
            content.contains("merge"),
            "ABI doc should define merge function"
        );
        assert!(
            content.contains("diff"),
            "ABI doc should define diff function"
        );
        assert!(
            content.contains("ABI Version"),
            "ABI doc should specify version"
        );
        assert!(
            content.contains("Memory Layout"),
            "ABI doc should specify memory layout"
        );
        assert!(
            content.contains("Error Handling"),
            "ABI doc should specify error handling"
        );
    }

    #[cfg(feature = "wasm-plugins")]
    #[test]
    fn test_plugin_registry_load_wasm_missing_file() {
        let mut reg = PluginRegistry::new();
        let result = reg.load_wasm_plugin(Path::new("/tmp/nonexistent-plugin.wasm"));
        assert!(result.is_err());
        match result.unwrap_err() {
            PluginError::LoadFailed(msg) => {
                assert!(msg.contains("failed to read") || msg.contains("No such file"));
            }
            other => panic!("expected LoadFailed, got: {other}"),
        }
    }

    #[cfg(feature = "wasm-plugins")]
    #[test]
    fn test_plugin_registry_load_wasm_invalid_module() {
        // A per-process directory keeps concurrent test runs from deleting each
        // other's fixture.
        let dir = std::env::temp_dir().join(format!(
            "suture-wasm-test-invalid-{}",
            std::process::id()
        ));
        std::fs::create_dir_all(&dir).unwrap();
        let path = dir.join("invalid.wasm");
        std::fs::write(&path, b"not a valid wasm module").unwrap();

        let mut reg = PluginRegistry::new();
        let result = reg.load_wasm_plugin(&path);

        // Clean up before asserting so a failing assertion does not leave the
        // fixture behind.
        let _ = std::fs::remove_dir_all(&dir);

        assert!(result.is_err());
    }
}