Skip to main content

suture_driver/
plugin.rs

1use crate::SutureDriver;
2use std::collections::HashMap;
3use std::path::Path;
4use std::sync::Arc;
5
6#[derive(Debug, thiserror::Error)]
7pub enum PluginError {
8    #[error("failed to load plugin: {0}")]
9    LoadFailed(String),
10    #[error("missing required export: {0}")]
11    MissingExport(String),
12    #[error("plugin ABI version mismatch: expected {expected}, got {actual}")]
13    AbiVersionMismatch { expected: i32, actual: i32 },
14    #[cfg(feature = "wasm-plugins")]
15    #[error("wasmtime error: {0}")]
16    Wasmtime(#[from] wasmtime::Error),
17}
18
19pub trait DriverPlugin: Send + Sync {
20    fn name(&self) -> &str;
21    fn extensions(&self) -> &[&str];
22    fn description(&self) -> &str;
23    fn as_driver(&self) -> &dyn SutureDriver;
24}
25
/// Pairs a statically known driver with compile-time metadata so it can be
/// exposed as a [`DriverPlugin`].
pub struct BuiltinDriverPlugin<D> {
    /// Name the plugin is registered under.
    name: &'static str,
    /// Extensions reported through `DriverPlugin::extensions`.
    extensions: Vec<&'static str>,
    /// Human-readable summary.
    description: &'static str,
    /// The wrapped driver implementation.
    driver: D,
}
32
33impl<D: SutureDriver + Send + Sync + 'static> BuiltinDriverPlugin<D> {
34    pub fn new(
35        name: &'static str,
36        extensions: Vec<&'static str>,
37        description: &'static str,
38        driver: D,
39    ) -> Self {
40        Self {
41            name,
42            extensions,
43            description,
44            driver,
45        }
46    }
47}
48
49impl<D: SutureDriver + Send + Sync + 'static> DriverPlugin for BuiltinDriverPlugin<D> {
50    fn name(&self) -> &str {
51        self.name
52    }
53    fn extensions(&self) -> &[&str] {
54        &self.extensions
55    }
56    fn description(&self) -> &str {
57        self.description
58    }
59    fn as_driver(&self) -> &dyn SutureDriver {
60        &self.driver
61    }
62}
63
64pub struct PluginRegistry {
65    plugins: HashMap<String, Arc<dyn DriverPlugin>>,
66    extension_map: HashMap<String, String>,
67}
68
69impl PluginRegistry {
70    pub fn new() -> Self {
71        Self {
72            plugins: HashMap::new(),
73            extension_map: HashMap::new(),
74        }
75    }
76
77    pub fn register(&mut self, plugin: Arc<dyn DriverPlugin>) {
78        let name = plugin.name().to_string();
79        for ext in plugin.extensions() {
80            self.extension_map.insert(ext.to_string(), name.clone());
81        }
82        self.plugins.insert(name, plugin);
83    }
84
85    pub fn get(&self, name: &str) -> Option<&dyn DriverPlugin> {
86        self.plugins.get(name).map(|p| p.as_ref())
87    }
88
89    pub fn get_by_extension(&self, ext: &str) -> Option<&dyn DriverPlugin> {
90        let normalized = if ext.starts_with('.') {
91            ext.to_string()
92        } else {
93            format!(".{}", ext)
94        };
95        self.extension_map
96            .get(&normalized)
97            .and_then(|name| self.plugins.get(name).map(|p| p.as_ref()))
98    }
99
100    pub fn list_drivers(&self) -> Vec<&str> {
101        let mut names: Vec<&str> = self.plugins.keys().map(|s| s.as_str()).collect();
102        names.sort();
103        names
104    }
105
106    pub fn discover_plugins(&mut self, plugin_dir: &Path) {
107        if !plugin_dir.exists() {
108            return;
109        }
110
111        if let Ok(entries) = std::fs::read_dir(plugin_dir) {
112            for entry in entries.flatten() {
113                let path = entry.path();
114                if path
115                    .extension()
116                    .map(|e| e == "suture-plugin")
117                    .unwrap_or(false)
118                    && let Ok(content) = std::fs::read_to_string(&path)
119                    && let Some(desc) = Self::parse_plugin_descriptor(&content)
120                {
121                    let _ = desc; // plugin descriptor found; future: dynamic loading
122                }
123            }
124        }
125    }
126
127    fn parse_plugin_descriptor(content: &str) -> Option<PluginDescriptor> {
128        let mut name = None;
129        let mut extensions = Vec::new();
130        let mut description = String::new();
131
132        for line in content.lines() {
133            let line = line.trim();
134            if let Some(val) = line
135                .strip_prefix("name")
136                .and_then(Self::extract_string_value)
137            {
138                name = Some(val);
139            } else if let Some(start) = line.find('[') {
140                if let Some(end) = line[start..].find(']') {
141                    let inner = &line[start + 1..start + end];
142                    for ext in inner.split(',') {
143                        let ext = ext.trim().trim_matches('"');
144                        if !ext.is_empty() {
145                            extensions.push(ext.to_string());
146                        }
147                    }
148                }
149            } else if let Some(val) = line
150                .strip_prefix("description")
151                .and_then(Self::extract_string_value)
152            {
153                description = val;
154            }
155        }
156
157        name.map(|name| PluginDescriptor {
158            name,
159            extensions,
160            description,
161        })
162    }
163
164    fn extract_string_value(line: &str) -> Option<String> {
165        if let Some(eq_pos) = line.find('=') {
166            let val = line[eq_pos + 1..].trim();
167            if val.starts_with('"') && val.ends_with('"') {
168                return Some(val[1..val.len() - 1].to_string());
169            }
170        }
171        None
172    }
173}
174
/// Metadata parsed from a `*.suture-plugin` descriptor file.
// Discovery parses descriptors but does not load them yet, so several fields are
// never read; the lint is silenced once for the whole struct (covering its fields).
#[allow(dead_code)]
struct PluginDescriptor {
    /// Plugin name; the only key a descriptor must provide.
    name: String,
    /// Extensions the plugin says it handles.
    extensions: Vec<String>,
    /// Free-form description; empty when the descriptor omits it.
    description: String,
}
182
183impl Default for PluginRegistry {
184    fn default() -> Self {
185        Self::new()
186    }
187}
188
189impl PluginRegistry {
190    #[cfg(feature = "wasm-plugins")]
191    pub fn load_wasm_plugin(&mut self, path: &Path) -> Result<(), PluginError> {
192        let plugin = WasmDriverPlugin::from_file(path)?;
193        let name = plugin.name.clone();
194        let extensions: Vec<String> = plugin.extensions_storage.clone();
195        let plugin_arc = Arc::new(plugin);
196        for ext in &extensions {
197            self.extension_map.insert(ext.to_string(), name.clone());
198        }
199        self.plugins.insert(name, plugin_arc);
200        Ok(())
201    }
202}
203
#[cfg(feature = "wasm-plugins")]
/// A driver plugin backed by a WebAssembly module run with wasmtime.
pub struct WasmDriverPlugin {
    // The runtime handles are kept with the plugin but not read yet: the
    // `diff`/`format_diff` implementations for WASM plugins are still stubs.
    #[allow(dead_code)]
    engine: wasmtime::Engine,
    #[allow(dead_code)]
    instance: wasmtime::Instance,
    #[allow(dead_code)]
    store: wasmtime::Store<()>,
    /// Name reported by the module's `plugin_name` export.
    name: String,
    /// Owned copies of the extensions reported by `plugin_extensions`.
    extensions_storage: Vec<String>,
    /// Leaked `'static` copies of `extensions_storage`, handed out by `extensions()`.
    extensions: Vec<&'static str>,
}
216
#[cfg(feature = "wasm-plugins")]
impl WasmDriverPlugin {
    /// The only plugin ABI version this host understands.
    const SUPPORTED_ABI_VERSION: i32 = 1;

    /// Compiles and instantiates the WASM module at `path` and reads its metadata.
    ///
    /// Exports used:
    /// - `plugin_version` (`() -> i32`) is required and must equal the supported ABI version.
    /// - `plugin_name` is optional and defaults to `"unknown"`.
    /// - `plugin_extensions` is optional and is a comma-separated list.
    ///
    /// Both optional exports return a pointer to a NUL-terminated UTF-8 string in
    /// the module's exported `memory`.
    ///
    /// # Errors
    ///
    /// - [`PluginError::LoadFailed`] if the module cannot be read, compiled or instantiated.
    /// - [`PluginError::MissingExport`] if `plugin_version` is absent or has the wrong type.
    /// - [`PluginError::Wasmtime`] if calling `plugin_version` traps.
    /// - [`PluginError::AbiVersionMismatch`] if the reported version is unsupported.
    pub fn from_file(path: &std::path::Path) -> Result<Self, PluginError> {
        let engine = wasmtime::Engine::default();
        let module = wasmtime::Module::from_file(&engine, path)
            .map_err(|e| PluginError::LoadFailed(e.to_string()))?;

        let mut store = wasmtime::Store::new(&engine, ());
        // No host functions are registered yet.
        // TODO: Add `suture.alloc` host import for WASM memory allocation.
        // WASM modules that import `suture.alloc` will fail to instantiate until this is implemented.
        let linker = wasmtime::Linker::new(&engine);

        let instance = linker
            .instantiate(&mut store, &module)
            .map_err(|e| PluginError::LoadFailed(e.to_string()))?;

        let version = Self::call_version_export(&mut store, &instance)?;
        if version != Self::SUPPORTED_ABI_VERSION {
            return Err(PluginError::AbiVersionMismatch {
                expected: Self::SUPPORTED_ABI_VERSION,
                actual: version,
            });
        }

        let name = Self::call_string_export(&mut store, &instance, "plugin_name")
            .unwrap_or_else(|| "unknown".to_string());
        let extensions_storage = Self::call_extensions_export(&mut store, &instance);

        // `DriverPlugin::extensions` must return `&[&str]` borrowed from `self`,
        // which cannot point into `extensions_storage` without a self-referential
        // struct. Each extension is therefore leaked; this costs a few bytes per
        // call to `from_file` that are never reclaimed.
        let extensions: Vec<&'static str> = extensions_storage
            .iter()
            .map(|ext| -> &'static str { Box::leak(ext.clone().into_boxed_str()) })
            .collect();

        Ok(Self {
            engine,
            instance,
            store,
            name,
            extensions_storage,
            extensions,
        })
    }

    /// Calls the required `plugin_version` export (`() -> i32`).
    ///
    /// # Errors
    ///
    /// [`PluginError::MissingExport`] if the export is absent or has another
    /// signature; [`PluginError::Wasmtime`] if the call traps.
    fn call_version_export(
        store: &mut wasmtime::Store<()>,
        instance: &wasmtime::Instance,
    ) -> Result<i32, PluginError> {
        let func = instance
            .get_typed_func::<(), i32>(&mut *store, "plugin_version")
            .map_err(|_| PluginError::MissingExport("plugin_version".to_string()))?;
        let version = func.call(&mut *store, ())?;
        Ok(version)
    }

    /// Calls the `() -> i32` export `export_name` and decodes the NUL-terminated
    /// UTF-8 string at the returned address in the instance's `memory` export.
    ///
    /// Returns `None` in any of these cases:
    /// - the export or `memory` is missing, or the call traps;
    /// - the address is out of bounds, or no NUL terminator follows it;
    /// - the bytes are not valid UTF-8.
    fn call_string_export(
        store: &mut wasmtime::Store<()>,
        instance: &wasmtime::Instance,
        export_name: &str,
    ) -> Option<String> {
        let func = instance
            .get_typed_func::<(), i32>(&mut *store, export_name)
            .ok()?;
        let ptr = func.call(&mut *store, ()).ok()?;
        let memory = instance.get_memory(&mut *store, "memory")?;

        // Linear-memory addresses are unsigned 32-bit values carried in an `i32`:
        // reinterpret the bits rather than sign-extending, so addresses >= 2 GiB
        // are not mistaken for out-of-range ones.
        let start = ptr as u32 as usize;
        let bytes = memory.data(&*store).get(start..)?;
        let len = bytes.iter().position(|&b| b == 0)?;
        String::from_utf8(bytes[..len].to_vec()).ok()
    }

    /// Reads `plugin_extensions` as a comma-separated list, trimming entries and
    /// dropping empty ones. Returns an empty list if the export is unavailable.
    fn call_extensions_export(
        store: &mut wasmtime::Store<()>,
        instance: &wasmtime::Instance,
    ) -> Vec<String> {
        match Self::call_string_export(store, instance, "plugin_extensions") {
            Some(csv) => csv
                .split(',')
                .map(|s| s.trim().to_string())
                .filter(|s| !s.is_empty())
                .collect(),
            None => vec![],
        }
    }
}
320
#[cfg(feature = "wasm-plugins")]
/// Plugin metadata comes from the module's exports; the plugin acts as its own driver.
impl DriverPlugin for WasmDriverPlugin {
    fn name(&self) -> &str {
        self.name.as_str()
    }

    fn extensions(&self) -> &[&str] {
        self.extensions.as_slice()
    }

    fn description(&self) -> &str {
        "WASM plugin driver"
    }

    fn as_driver(&self) -> &dyn SutureDriver {
        self
    }
}
339
#[cfg(feature = "wasm-plugins")]
/// Driver interface for WASM plugins.
///
/// Diffing is not implemented yet, so both diff operations always fail with
/// `DriverError::ParseError`.
impl SutureDriver for WasmDriverPlugin {
    fn name(&self) -> &str {
        self.name.as_str()
    }

    fn supported_extensions(&self) -> &[&str] {
        self.extensions.as_slice()
    }

    fn diff(
        &self,
        _base_content: Option<&str>,
        _new_content: &str,
    ) -> Result<Vec<crate::SemanticChange>, crate::DriverError> {
        let message = String::from("WASM plugin diff is not yet implemented");
        Err(crate::DriverError::ParseError(message))
    }

    fn format_diff(
        &self,
        _base_content: Option<&str>,
        _new_content: &str,
    ) -> Result<String, crate::DriverError> {
        let message = String::from("WASM plugin format_diff is not yet implemented");
        Err(crate::DriverError::ParseError(message))
    }
}
370
#[cfg(test)]
mod tests {
    use super::*;
    use crate::DriverError;

    /// Stub driver whose diff operations always succeed with empty output.
    struct MockDriver {
        driver_name: &'static str,
        driver_extensions: Vec<&'static str>,
    }

    impl MockDriver {
        fn new(name: &'static str, extensions: Vec<&'static str>) -> Self {
            MockDriver {
                driver_name: name,
                driver_extensions: extensions,
            }
        }
    }

    impl SutureDriver for MockDriver {
        fn name(&self) -> &str {
            self.driver_name
        }

        fn supported_extensions(&self) -> &[&str] {
            self.driver_extensions.as_slice()
        }

        fn diff(
            &self,
            _base_content: Option<&str>,
            _new_content: &str,
        ) -> Result<Vec<crate::SemanticChange>, DriverError> {
            Ok(Vec::new())
        }

        fn format_diff(
            &self,
            _base_content: Option<&str>,
            _new_content: &str,
        ) -> Result<String, DriverError> {
            Ok(String::new())
        }
    }

    /// Wraps a `MockDriver` advertising `extensions` in a builtin plugin named `name`.
    fn make_plugin(name: &'static str, extensions: Vec<&'static str>) -> Arc<dyn DriverPlugin> {
        let driver = MockDriver::new(name, extensions.clone());
        Arc::new(BuiltinDriverPlugin::new(
            name,
            extensions,
            "test driver",
            driver,
        ))
    }

    /// Builds a registry with every plugin in `plugins` registered in order.
    fn registry_with(plugins: Vec<Arc<dyn DriverPlugin>>) -> PluginRegistry {
        let mut reg = PluginRegistry::new();
        plugins.into_iter().for_each(|plugin| reg.register(plugin));
        reg
    }

    #[test]
    fn register_and_get_by_name() {
        let reg = registry_with(vec![make_plugin("json", vec![".json"])]);
        assert!(reg.get("yaml").is_none());
        assert_eq!(reg.get("json").map(|p| p.name()), Some("json"));
    }

    #[test]
    fn get_by_extension_with_dot() {
        let reg = registry_with(vec![make_plugin("json", vec![".json"])]);
        assert!(reg.get_by_extension(".json").is_some());
        assert!(reg.get_by_extension(".yaml").is_none());
    }

    #[test]
    fn get_by_extension_without_dot() {
        let reg = registry_with(vec![make_plugin("yaml", vec![".yaml", ".yml"])]);
        for ext in ["yaml", "yml"] {
            assert!(reg.get_by_extension(ext).is_some());
        }
    }

    #[test]
    fn list_drivers_sorted() {
        let reg = registry_with(vec![
            make_plugin("csv", vec![".csv"]),
            make_plugin("xml", vec![".xml"]),
            make_plugin("json", vec![".json"]),
        ]);
        assert_eq!(reg.list_drivers(), ["csv", "json", "xml"]);
    }

    #[test]
    fn discover_plugins_nonexistent_dir() {
        let mut reg = PluginRegistry::new();
        reg.discover_plugins(Path::new("/tmp/suture-test-nonexistent-12345"));
        assert!(reg.list_drivers().is_empty());
    }

    #[test]
    fn parse_plugin_descriptor_valid() {
        let content = [
            r#"name = "my-driver""#,
            r#"extensions = [".custom", ".ext"]"#,
            r#"description = "A custom driver""#,
        ]
        .join("\n");
        let desc = PluginRegistry::parse_plugin_descriptor(&content).unwrap();
        assert_eq!(desc.name, "my-driver");
        assert_eq!(desc.extensions, vec![".custom", ".ext"]);
        assert_eq!(desc.description, "A custom driver");
    }

    #[test]
    fn parse_plugin_descriptor_missing_name() {
        let parsed = PluginRegistry::parse_plugin_descriptor(r#"extensions = [".custom"]"#);
        assert!(parsed.is_none());
    }

    #[test]
    fn as_driver_returns_underlying_driver() {
        let reg = registry_with(vec![make_plugin("json", vec![".json"])]);
        let driver = reg.get("json").unwrap().as_driver();
        assert_eq!(driver.name(), "json");
        assert_eq!(driver.supported_extensions(), &[".json"]);
    }

    #[cfg(feature = "wasm-plugins")]
    #[test]
    fn test_wasm_plugin_abi_documentation() {
        let abi_path = Path::new(env!("CARGO_MANIFEST_DIR")).join("src/wasm_abi.md");
        let content = std::fs::read_to_string(&abi_path).expect("wasm_abi.md should exist");
        let expectations = [
            ("plugin_name", "ABI doc should define plugin_name export"),
            ("plugin_extensions", "ABI doc should define plugin_extensions export"),
            ("plugin_version", "ABI doc should define plugin_version export"),
            ("merge", "ABI doc should define merge function"),
            ("diff", "ABI doc should define diff function"),
            ("ABI Version", "ABI doc should specify version"),
            ("Memory Layout", "ABI doc should specify memory layout"),
            ("Error Handling", "ABI doc should specify error handling"),
        ];
        for &(needle, message) in &expectations {
            assert!(content.contains(needle), "{}", message);
        }
    }

    #[cfg(feature = "wasm-plugins")]
    #[test]
    fn test_plugin_registry_load_wasm_missing_file() {
        let mut reg = PluginRegistry::new();
        let result = reg.load_wasm_plugin(Path::new("/tmp/nonexistent-plugin.wasm"));
        assert!(result.is_err());
        let msg = match result.unwrap_err() {
            PluginError::LoadFailed(msg) => msg,
            other => panic!("expected LoadFailed, got: {other}"),
        };
        assert!(msg.contains("failed to read") || msg.contains("No such file"));
    }

    #[cfg(feature = "wasm-plugins")]
    #[test]
    fn test_plugin_registry_load_wasm_invalid_module() {
        let dir = std::env::temp_dir().join("suture-wasm-test-invalid");
        std::fs::create_dir_all(&dir).unwrap();
        let path = dir.join("invalid.wasm");
        std::fs::write(&path, b"not a valid wasm module").unwrap();

        let outcome = PluginRegistry::new().load_wasm_plugin(&path);
        assert!(outcome.is_err());

        let _ = std::fs::remove_dir_all(&dir);
    }
}