Skip to main content

suture_driver/
plugin.rs

1use crate::SutureDriver;
2use std::collections::HashMap;
3use std::path::Path;
4use std::sync::Arc;
5
/// Errors produced while loading or validating a driver plugin.
#[derive(Debug, thiserror::Error)]
pub enum PluginError {
    /// The plugin module could not be read/compiled or instantiated; the
    /// payload is the underlying error message.
    #[error("failed to load plugin: {0}")]
    LoadFailed(String),
    /// A required export (named in the payload) could not be obtained with
    /// the expected signature.
    #[error("missing required export: {0}")]
    MissingExport(String),
    /// The plugin reported an ABI version other than the one this host expects.
    #[error("plugin ABI version mismatch: expected {expected}, got {actual}")]
    AbiVersionMismatch { expected: i32, actual: i32 },
    /// An error returned by the wasmtime runtime while calling into a plugin
    /// (for example a trap inside an export).
    #[cfg(feature = "wasm-plugins")]
    #[error("wasmtime error: {0}")]
    Wasmtime(#[from] wasmtime::Error),
}
18
/// A pluggable unit that supplies a [`SutureDriver`] together with the
/// metadata [`PluginRegistry`] uses to look it up by name or file extension.
pub trait DriverPlugin: Send + Sync {
    /// Registry key. Registering another plugin with the same name replaces
    /// the earlier one.
    fn name(&self) -> &str;
    /// File extensions handled by this plugin, written with a leading `.`
    /// (e.g. `.json`).
    fn extensions(&self) -> &[&str];
    /// Human-readable description of the plugin.
    fn description(&self) -> &str;
    /// The driver that performs the actual diff/format work.
    fn as_driver(&self) -> &dyn SutureDriver;
}
25
26pub struct BuiltinDriverPlugin<D> {
27    name: &'static str,
28    extensions: Vec<&'static str>,
29    description: &'static str,
30    driver: D,
31}
32
33impl<D: SutureDriver + Send + Sync + 'static> BuiltinDriverPlugin<D> {
34    pub fn new(
35        name: &'static str,
36        extensions: Vec<&'static str>,
37        description: &'static str,
38        driver: D,
39    ) -> Self {
40        Self {
41            name,
42            extensions,
43            description,
44            driver,
45        }
46    }
47}
48
49impl<D: SutureDriver + Send + Sync + 'static> DriverPlugin for BuiltinDriverPlugin<D> {
50    fn name(&self) -> &str {
51        self.name
52    }
53    fn extensions(&self) -> &[&str] {
54        &self.extensions
55    }
56    fn description(&self) -> &str {
57        self.description
58    }
59    fn as_driver(&self) -> &dyn SutureDriver {
60        &self.driver
61    }
62}
63
64pub struct PluginRegistry {
65    plugins: HashMap<String, Arc<dyn DriverPlugin>>,
66    extension_map: HashMap<String, String>,
67}
68
69impl PluginRegistry {
70    pub fn new() -> Self {
71        Self {
72            plugins: HashMap::new(),
73            extension_map: HashMap::new(),
74        }
75    }
76
77    pub fn register(&mut self, plugin: Arc<dyn DriverPlugin>) {
78        let name = plugin.name().to_string();
79        for ext in plugin.extensions() {
80            self.extension_map.insert(ext.to_string(), name.clone());
81        }
82        self.plugins.insert(name, plugin);
83    }
84
85    pub fn get(&self, name: &str) -> Option<&dyn DriverPlugin> {
86        self.plugins.get(name).map(|p| p.as_ref())
87    }
88
89    pub fn get_by_extension(&self, ext: &str) -> Option<&dyn DriverPlugin> {
90        let normalized = if ext.starts_with('.') {
91            ext.to_string()
92        } else {
93            format!(".{}", ext)
94        };
95        self.extension_map
96            .get(&normalized)
97            .and_then(|name| self.plugins.get(name).map(|p| p.as_ref()))
98    }
99
100    pub fn list_drivers(&self) -> Vec<&str> {
101        let mut names: Vec<&str> = self.plugins.keys().map(|s| s.as_str()).collect();
102        names.sort();
103        names
104    }
105
106    pub fn discover_plugins(&mut self, plugin_dir: &Path) {
107        if !plugin_dir.exists() {
108            return;
109        }
110
111        if let Ok(entries) = std::fs::read_dir(plugin_dir) {
112            for entry in entries.flatten() {
113                let path = entry.path();
114                if path
115                    .extension()
116                    .map(|e| e == "suture-plugin")
117                    .unwrap_or(false)
118                    && let Ok(content) = std::fs::read_to_string(&path)
119                    && let Some(desc) = Self::parse_plugin_descriptor(&content)
120                {
121                    let _ = desc; // plugin descriptor found; future: dynamic loading
122                }
123            }
124        }
125    }
126
127    fn parse_plugin_descriptor(content: &str) -> Option<PluginDescriptor> {
128        let mut name = None;
129        let mut extensions = Vec::new();
130        let mut description = String::new();
131
132        for line in content.lines() {
133            let line = line.trim();
134            if let Some(val) = line
135                .strip_prefix("name")
136                .and_then(Self::extract_string_value)
137            {
138                name = Some(val);
139            } else if let Some(start) = line.find('[') {
140                if let Some(end) = line[start..].find(']') {
141                    let inner = &line[start + 1..start + end];
142                    for ext in inner.split(',') {
143                        let ext = ext.trim().trim_matches('"');
144                        if !ext.is_empty() {
145                            extensions.push(ext.to_string());
146                        }
147                    }
148                }
149            } else if let Some(val) = line
150                .strip_prefix("description")
151                .and_then(Self::extract_string_value)
152            {
153                description = val;
154            }
155        }
156
157        name.map(|name| PluginDescriptor {
158            name,
159            extensions,
160            description,
161        })
162    }
163
164    fn extract_string_value(line: &str) -> Option<String> {
165        if let Some(eq_pos) = line.find('=') {
166            let val = line[eq_pos + 1..].trim();
167            if val.starts_with('"') && val.ends_with('"') {
168                return Some(val[1..val.len() - 1].to_string());
169            }
170        }
171        None
172    }
173}
174
/// Metadata parsed from a `.suture-plugin` descriptor file.
// `discover_plugins` parses descriptors but does not act on them yet, so the
// fields are only read by tests; hence the `dead_code` allowances.
#[allow(dead_code)]
struct PluginDescriptor {
    /// Plugin name from the `name` key; a descriptor without it is rejected.
    name: String,
    /// Extensions listed by the descriptor, in file order.
    extensions: Vec<String>,
    /// Free-form description; empty when the descriptor provides none.
    #[allow(dead_code)]
    description: String,
}
182
183impl Default for PluginRegistry {
184    fn default() -> Self {
185        Self::new()
186    }
187}
188
189impl PluginRegistry {
190    #[cfg(feature = "wasm-plugins")]
191    pub fn load_wasm_plugin(&mut self, path: &Path) -> Result<(), PluginError> {
192        let plugin = WasmDriverPlugin::from_file(path)?;
193        let name = plugin.name.clone();
194        let extensions: Vec<String> = plugin.extensions_storage.clone();
195        let plugin_arc = Arc::new(plugin);
196        for ext in &extensions {
197            self.extension_map.insert(ext.to_string(), name.clone());
198        }
199        self.plugins.insert(name, plugin_arc);
200        Ok(())
201    }
202}
203
/// A driver plugin backed by a WebAssembly module executed with wasmtime.
///
/// The module is instantiated once, at load time. Its `plugin_version`,
/// `plugin_name` and `plugin_extensions` exports are read eagerly and cached.
#[cfg(feature = "wasm-plugins")]
pub struct WasmDriverPlugin {
    // Engine, instance and store are held for the plugin's lifetime but are not
    // read after loading. (A wasmtime `Instance` is a handle into its `Store`.)
    #[allow(dead_code)]
    engine: wasmtime::Engine,
    #[allow(dead_code)]
    instance: wasmtime::Instance,
    #[allow(dead_code)]
    store: wasmtime::Store<()>,
    name: String,
    // `extensions` points into `extensions_storage` (see `from_file`). It is
    // declared first so that it is dropped before the strings it refers to.
    extensions: Vec<&'static str>,
    extensions_storage: Vec<String>,
}

#[cfg(feature = "wasm-plugins")]
impl WasmDriverPlugin {
    /// ABI version implemented by this host; `plugin_version` must return it.
    const ABI_VERSION: i32 = 1;

    /// Loads, instantiates and queries the WASM module at `path`.
    ///
    /// If `plugin_name` is missing or invalid, the name falls back to
    /// `"unknown"`. If `plugin_extensions` is missing, the plugin has no
    /// extensions.
    ///
    /// # Errors
    /// - [`PluginError::LoadFailed`] if the module cannot be read/compiled, or
    ///   cannot be instantiated. No host imports are provided, so any import
    ///   makes instantiation fail.
    /// - [`PluginError::MissingExport`] if `plugin_version` is unavailable.
    /// - [`PluginError::Wasmtime`] if calling `plugin_version` fails.
    /// - [`PluginError::AbiVersionMismatch`] if it reports another ABI version.
    pub fn from_file(path: &std::path::Path) -> Result<Self, PluginError> {
        let engine = wasmtime::Engine::default();
        let module = wasmtime::Module::from_file(&engine, path)
            .map_err(|e| PluginError::LoadFailed(e.to_string()))?;

        let mut store = wasmtime::Store::new(&engine, ());
        // Empty linker: the host exposes no imports to plugins.
        let linker = wasmtime::Linker::new(&engine);
        let instance = linker
            .instantiate(&mut store, &module)
            .map_err(|e| PluginError::LoadFailed(e.to_string()))?;

        let version = Self::call_version_export(&mut store, &instance)?;
        if version != Self::ABI_VERSION {
            return Err(PluginError::AbiVersionMismatch {
                expected: Self::ABI_VERSION,
                actual: version,
            });
        }

        let name = Self::call_string_export(&mut store, &instance, "plugin_name")
            .unwrap_or_else(|| "unknown".to_string());
        let extensions_storage = Self::call_extensions_export(&mut store, &instance);

        // The trait API needs `&[&str]`, which cannot borrow from a sibling
        // field in safe Rust.
        let extensions: Vec<&'static str> = extensions_storage
            .iter()
            .map(|s| {
                let ptr: *const str = s.as_str();
                // SAFETY: `ptr` points into the heap buffer of a `String` owned
                // by `extensions_storage`. Moving the Vec or the struct does not
                // move those buffers. The field is private and never mutated
                // after construction, and `extensions` is dropped before it.
                // The fake `'static` must not escape: it is only exposed
                // re-borrowed for the lifetime of `&self`.
                unsafe { &*ptr }
            })
            .collect();

        Ok(Self {
            engine,
            instance,
            store,
            name,
            extensions,
            extensions_storage,
        })
    }

    /// Calls the `plugin_version` export, whose signature must be `() -> i32`.
    ///
    /// # Errors
    /// - [`PluginError::MissingExport`] if the export is absent or has another
    ///   signature.
    /// - [`PluginError::Wasmtime`] if the call itself fails.
    fn call_version_export(
        store: &mut wasmtime::Store<()>,
        instance: &wasmtime::Instance,
    ) -> Result<i32, PluginError> {
        let func = instance
            .get_typed_func::<(), i32>(&mut *store, "plugin_version")
            .map_err(|_| PluginError::MissingExport("plugin_version".to_string()))?;
        Ok(func.call(&mut *store, ())?)
    }

    /// Calls the nullary export `export_name` and reads its result as a
    /// NUL-terminated UTF-8 string.
    ///
    /// The export's `i32` result is treated as an address in the module's
    /// exported `memory`, and the string is read from there.
    ///
    /// Returns `None` in any of these cases:
    /// - the export or `memory` is missing;
    /// - the call fails;
    /// - no NUL byte occurs before the end of memory;
    /// - the bytes are not valid UTF-8.
    fn call_string_export(
        store: &mut wasmtime::Store<()>,
        instance: &wasmtime::Instance,
        export_name: &str,
    ) -> Option<String> {
        let func = instance
            .get_typed_func::<(), i32>(&mut *store, export_name)
            .ok()?;
        let ptr = func.call(&mut *store, ()).ok()?;
        let memory = instance.get_memory(&mut *store, "memory")?;

        // wasm32 addresses are unsigned. Reinterpret the bits as `u32` so that
        // addresses at or above 2 GiB are not sign-extended into a huge offset.
        let start = ptr as u32 as usize;
        let bytes = memory.data(&mut *store).get(start..)?;
        let len = bytes.iter().position(|&b| b == 0)?;
        String::from_utf8(bytes[..len].to_vec()).ok()
    }

    /// Reads `plugin_extensions` as a comma-separated list.
    ///
    /// Entries are trimmed and empty ones are dropped. Returns an empty list
    /// when the export cannot be read.
    fn call_extensions_export(
        store: &mut wasmtime::Store<()>,
        instance: &wasmtime::Instance,
    ) -> Vec<String> {
        Self::call_string_export(store, instance, "plugin_extensions")
            .map(|csv| {
                csv.split(',')
                    .map(str::trim)
                    .filter(|s| !s.is_empty())
                    .map(str::to_string)
                    .collect()
            })
            .unwrap_or_default()
    }
}
325
/// Exposes a loaded WASM module through the registry's plugin interface.
#[cfg(feature = "wasm-plugins")]
impl DriverPlugin for WasmDriverPlugin {
    fn name(&self) -> &str {
        self.name.as_str()
    }

    fn extensions(&self) -> &[&str] {
        self.extensions.as_slice()
    }

    /// No description is read from the module, so a fixed label is reported.
    fn description(&self) -> &str {
        "WASM plugin driver"
    }

    /// The plugin object doubles as its own driver.
    fn as_driver(&self) -> &dyn SutureDriver {
        self as &dyn SutureDriver
    }
}
344
/// Driver interface for WASM plugins.
///
/// Diffing is not delegated to the module yet. `diff` always reports no
/// changes and `format_diff` always returns an empty string, whatever the input.
#[cfg(feature = "wasm-plugins")]
impl SutureDriver for WasmDriverPlugin {
    fn name(&self) -> &str {
        self.name.as_str()
    }

    fn supported_extensions(&self) -> &[&str] {
        self.extensions.as_slice()
    }

    fn diff(
        &self,
        _base_content: Option<&str>,
        _new_content: &str,
    ) -> Result<Vec<crate::SemanticChange>, crate::DriverError> {
        // NOTE(review): stub. It presumably should call into the module's diff
        // export once that part of the ABI is wired up.
        Ok(Vec::new())
    }

    fn format_diff(
        &self,
        _base_content: Option<&str>,
        _new_content: &str,
    ) -> Result<String, crate::DriverError> {
        Ok(String::default())
    }
}
371
// Unit tests for the registry, descriptor parsing and (behind the
// `wasm-plugins` feature) WASM plugin loading.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::DriverError;

    // Minimal `SutureDriver`: diff/format_diff are no-ops; only its name and
    // extensions are observed by these tests.
    struct MockDriver {
        driver_name: &'static str,
        driver_extensions: Vec<&'static str>,
    }

    impl MockDriver {
        fn new(name: &'static str, extensions: Vec<&'static str>) -> Self {
            Self {
                driver_name: name,
                driver_extensions: extensions,
            }
        }
    }

    impl SutureDriver for MockDriver {
        fn name(&self) -> &str {
            self.driver_name
        }
        fn supported_extensions(&self) -> &[&str] {
            &self.driver_extensions
        }
        fn diff(
            &self,
            _base_content: Option<&str>,
            _new_content: &str,
        ) -> Result<Vec<crate::SemanticChange>, DriverError> {
            Ok(vec![])
        }
        fn format_diff(
            &self,
            _base_content: Option<&str>,
            _new_content: &str,
        ) -> Result<String, DriverError> {
            Ok(String::new())
        }
    }

    // Builds a built-in plugin whose metadata matches its wrapped MockDriver.
    fn make_plugin(name: &'static str, extensions: Vec<&'static str>) -> Arc<dyn DriverPlugin> {
        Arc::new(BuiltinDriverPlugin::new(
            name,
            extensions.clone(),
            "test driver",
            MockDriver::new(name, extensions),
        ))
    }

    #[test]
    fn register_and_get_by_name() {
        let mut reg = PluginRegistry::new();
        reg.register(make_plugin("json", vec![".json"]));
        assert!(reg.get("json").is_some());
        assert!(reg.get("yaml").is_none());
        assert_eq!(reg.get("json").unwrap().name(), "json");
    }

    #[test]
    fn get_by_extension_with_dot() {
        let mut reg = PluginRegistry::new();
        reg.register(make_plugin("json", vec![".json"]));
        assert!(reg.get_by_extension(".json").is_some());
        assert!(reg.get_by_extension(".yaml").is_none());
    }

    // Lookups accept an extension without its leading dot.
    #[test]
    fn get_by_extension_without_dot() {
        let mut reg = PluginRegistry::new();
        reg.register(make_plugin("yaml", vec![".yaml", ".yml"]));
        assert!(reg.get_by_extension("yaml").is_some());
        assert!(reg.get_by_extension("yml").is_some());
    }

    #[test]
    fn list_drivers_sorted() {
        let mut reg = PluginRegistry::new();
        reg.register(make_plugin("csv", vec![".csv"]));
        reg.register(make_plugin("xml", vec![".xml"]));
        reg.register(make_plugin("json", vec![".json"]));
        assert_eq!(reg.list_drivers(), vec!["csv", "json", "xml"]);
    }

    // A missing plugin directory is ignored rather than reported.
    #[test]
    fn discover_plugins_nonexistent_dir() {
        let mut reg = PluginRegistry::new();
        reg.discover_plugins(Path::new("/tmp/suture-test-nonexistent-12345"));
        assert!(reg.list_drivers().is_empty());
    }

    #[test]
    fn parse_plugin_descriptor_valid() {
        let content = r#"
name = "my-driver"
extensions = [".custom", ".ext"]
description = "A custom driver"
"#;
        let desc = PluginRegistry::parse_plugin_descriptor(content).unwrap();
        assert_eq!(desc.name, "my-driver");
        assert_eq!(desc.extensions, vec![".custom", ".ext"]);
        assert_eq!(desc.description, "A custom driver");
    }

    // `name` is mandatory; without it the descriptor is rejected.
    #[test]
    fn parse_plugin_descriptor_missing_name() {
        let content = r#"extensions = [".custom"]"#;
        assert!(PluginRegistry::parse_plugin_descriptor(content).is_none());
    }

    #[test]
    fn as_driver_returns_underlying_driver() {
        let mut reg = PluginRegistry::new();
        reg.register(make_plugin("json", vec![".json"]));
        let plugin = reg.get("json").unwrap();
        assert_eq!(plugin.as_driver().name(), "json");
        assert_eq!(plugin.as_driver().supported_extensions(), &[".json"]);
    }

    // Guards that the ABI reference document exists in `src/` and mentions
    // each required export and section.
    #[cfg(feature = "wasm-plugins")]
    #[test]
    fn test_wasm_plugin_abi_documentation() {
        let abi_path = Path::new(env!("CARGO_MANIFEST_DIR")).join("src/wasm_abi.md");
        let content = std::fs::read_to_string(&abi_path)
            .expect("wasm_abi.md should exist");
        assert!(content.contains("plugin_name"), "ABI doc should define plugin_name export");
        assert!(content.contains("plugin_extensions"), "ABI doc should define plugin_extensions export");
        assert!(content.contains("plugin_version"), "ABI doc should define plugin_version export");
        assert!(content.contains("merge"), "ABI doc should define merge function");
        assert!(content.contains("diff"), "ABI doc should define diff function");
        assert!(content.contains("ABI Version"), "ABI doc should specify version");
        assert!(content.contains("Memory Layout"), "ABI doc should specify memory layout");
        assert!(content.contains("Error Handling"), "ABI doc should specify error handling");
    }

    // The message text comes from wasmtime's `Module::from_file` error, so
    // either of two known wordings is accepted.
    #[cfg(feature = "wasm-plugins")]
    #[test]
    fn test_plugin_registry_load_wasm_missing_file() {
        let mut reg = PluginRegistry::new();
        let result = reg.load_wasm_plugin(Path::new("/tmp/nonexistent-plugin.wasm"));
        assert!(result.is_err());
        match result.unwrap_err() {
            PluginError::LoadFailed(msg) => {
                assert!(msg.contains("failed to read") || msg.contains("No such file"));
            }
            other => panic!("expected LoadFailed, got: {other}"),
        }
    }

    // Bytes that are not a WASM module must be rejected with an error.
    #[cfg(feature = "wasm-plugins")]
    #[test]
    fn test_plugin_registry_load_wasm_invalid_module() {
        let dir = std::env::temp_dir().join("suture-wasm-test-invalid");
        std::fs::create_dir_all(&dir).unwrap();
        let path = dir.join("invalid.wasm");
        std::fs::write(&path, b"not a valid wasm module").unwrap();

        let mut reg = PluginRegistry::new();
        let result = reg.load_wasm_plugin(&path);
        assert!(result.is_err());

        let _ = std::fs::remove_dir_all(&dir);
    }
}