// suture_driver/plugin.rs

1use crate::SutureDriver;
2use std::collections::HashMap;
3use std::path::Path;
4use std::sync::Arc;
5
6#[derive(Debug, thiserror::Error)]
7pub enum PluginError {
8    #[error("failed to load plugin: {0}")]
9    LoadFailed(String),
10    #[error("missing required export: {0}")]
11    MissingExport(String),
12    #[error("plugin ABI version mismatch: expected {expected}, got {actual}")]
13    AbiVersionMismatch { expected: i32, actual: i32 },
14    #[cfg(feature = "wasm-plugins")]
15    #[error("wasmtime error: {0}")]
16    Wasmtime(#[from] wasmtime::Error),
17}
18
19pub trait DriverPlugin: Send + Sync {
20    fn name(&self) -> &str;
21    fn extensions(&self) -> &[&str];
22    fn description(&self) -> &str;
23    fn as_driver(&self) -> &dyn SutureDriver;
24}
25
/// Adapter that exposes a compiled-in driver `D` as a [`DriverPlugin`], with
/// the plugin metadata supplied by the caller rather than by the driver.
pub struct BuiltinDriverPlugin<D> {
    /// Registry key returned from `DriverPlugin::name`.
    name: &'static str,
    /// Extension list returned from `DriverPlugin::extensions`.
    extensions: Vec<&'static str>,
    /// Text returned from `DriverPlugin::description`.
    description: &'static str,
    /// The wrapped driver, returned from `DriverPlugin::as_driver`.
    driver: D,
}
32
33impl<D: SutureDriver + Send + Sync + 'static> BuiltinDriverPlugin<D> {
34    pub fn new(
35        name: &'static str,
36        extensions: Vec<&'static str>,
37        description: &'static str,
38        driver: D,
39    ) -> Self {
40        Self {
41            name,
42            extensions,
43            description,
44            driver,
45        }
46    }
47}
48
49impl<D: SutureDriver + Send + Sync + 'static> DriverPlugin for BuiltinDriverPlugin<D> {
50    fn name(&self) -> &str {
51        self.name
52    }
53    fn extensions(&self) -> &[&str] {
54        &self.extensions
55    }
56    fn description(&self) -> &str {
57        self.description
58    }
59    fn as_driver(&self) -> &dyn SutureDriver {
60        &self.driver
61    }
62}
63
64pub struct PluginRegistry {
65    plugins: HashMap<String, Arc<dyn DriverPlugin>>,
66    extension_map: HashMap<String, String>,
67}
68
69impl PluginRegistry {
70    pub fn new() -> Self {
71        Self {
72            plugins: HashMap::new(),
73            extension_map: HashMap::new(),
74        }
75    }
76
77    pub fn register(&mut self, plugin: Arc<dyn DriverPlugin>) {
78        let name = plugin.name().to_string();
79        for ext in plugin.extensions() {
80            self.extension_map.insert(ext.to_string(), name.clone());
81        }
82        self.plugins.insert(name, plugin);
83    }
84
85    pub fn get(&self, name: &str) -> Option<&dyn DriverPlugin> {
86        self.plugins.get(name).map(|p| p.as_ref())
87    }
88
89    pub fn get_by_extension(&self, ext: &str) -> Option<&dyn DriverPlugin> {
90        let normalized = if ext.starts_with('.') {
91            ext.to_string()
92        } else {
93            format!(".{}", ext)
94        };
95        self.extension_map
96            .get(&normalized)
97            .and_then(|name| self.plugins.get(name).map(|p| p.as_ref()))
98    }
99
100    pub fn list_drivers(&self) -> Vec<&str> {
101        let mut names: Vec<&str> = self.plugins.keys().map(|s| s.as_str()).collect();
102        names.sort();
103        names
104    }
105
106    pub fn discover_plugins(&mut self, plugin_dir: &Path) {
107        if !plugin_dir.exists() {
108            return;
109        }
110
111        if let Ok(entries) = std::fs::read_dir(plugin_dir) {
112            for entry in entries.flatten() {
113                let path = entry.path();
114                if path
115                    .extension()
116                    .map(|e| e == "suture-plugin")
117                    .unwrap_or(false)
118                    && let Ok(content) = std::fs::read_to_string(&path)
119                    && let Some(desc) = Self::parse_plugin_descriptor(&content)
120                {
121                    let _ = desc; // plugin descriptor found; future: dynamic loading
122                }
123            }
124        }
125    }
126
127    fn parse_plugin_descriptor(content: &str) -> Option<PluginDescriptor> {
128        let mut name = None;
129        let mut extensions = Vec::new();
130        let mut description = String::new();
131
132        for line in content.lines() {
133            let line = line.trim();
134            if let Some(val) = line
135                .strip_prefix("name")
136                .and_then(Self::extract_string_value)
137            {
138                name = Some(val);
139            } else if let Some(start) = line.find('[') {
140                if let Some(end) = line[start..].find(']') {
141                    let inner = &line[start + 1..start + end];
142                    for ext in inner.split(',') {
143                        let ext = ext.trim().trim_matches('"');
144                        if !ext.is_empty() {
145                            extensions.push(ext.to_string());
146                        }
147                    }
148                }
149            } else if let Some(val) = line
150                .strip_prefix("description")
151                .and_then(Self::extract_string_value)
152            {
153                description = val;
154            }
155        }
156
157        name.map(|name| PluginDescriptor {
158            name,
159            extensions,
160            description,
161        })
162    }
163
164    fn extract_string_value(line: &str) -> Option<String> {
165        if let Some(eq_pos) = line.find('=') {
166            let val = line[eq_pos + 1..].trim();
167            if val.starts_with('"') && val.ends_with('"') {
168                return Some(val[1..val.len() - 1].to_string());
169            }
170        }
171        None
172    }
173}
174
/// Fields parsed from a `.suture-plugin` descriptor file.
///
/// `discover_plugins` currently parses descriptors and then discards them, so
/// outside of tests no field is read. The struct-level `allow(dead_code)`
/// silences that warning for the struct and all of its fields.
#[allow(dead_code)]
struct PluginDescriptor {
    /// Value of the required `name = "..."` line.
    name: String,
    /// Items collected from `extensions = [...]` lines.
    extensions: Vec<String>,
    /// Value of the `description = "..."` line, or empty if absent.
    description: String,
}
182
183impl Default for PluginRegistry {
184    fn default() -> Self {
185        Self::new()
186    }
187}
188
189impl PluginRegistry {
190    #[cfg(feature = "wasm-plugins")]
191    pub fn load_wasm_plugin(&mut self, path: &Path) -> Result<(), PluginError> {
192        let plugin = WasmDriverPlugin::from_file(path)?;
193        let name = plugin.name.clone();
194        let extensions: Vec<String> = plugin.extensions_storage.clone();
195        let plugin_arc = Arc::new(plugin);
196        for ext in &extensions {
197            self.extension_map.insert(ext.to_string(), name.clone());
198        }
199        self.plugins.insert(name, plugin_arc);
200        Ok(())
201    }
202}
203
/// A driver plugin backed by an instantiated WebAssembly module.
///
/// `diff`/`format_diff` are not wired to the module yet (they return
/// errors). The engine, instance, and store are therefore held but not read.
/// NOTE(review): presumably they are kept so future calls into the module
/// can reuse this instance; confirm.
#[cfg(feature = "wasm-plugins")]
pub struct WasmDriverPlugin {
    #[allow(dead_code)] // not read yet; see the type-level note
    engine: wasmtime::Engine,
    #[allow(dead_code)] // not read yet; see the type-level note
    instance: wasmtime::Instance,
    #[allow(dead_code)] // not read yet; see the type-level note
    store: wasmtime::Store<()>,
    /// Name from the module's `plugin_name` export, or `"unknown"`.
    name: String,
    /// Owned copies of the extensions from `plugin_extensions`.
    extensions_storage: Vec<String>,
    /// The same extensions as leaked `&'static str`s, so trait methods can
    /// return `&[&str]` without a self-referential borrow. The leak happens
    /// once per loaded module.
    extensions: Vec<&'static str>,
}
216
#[cfg(feature = "wasm-plugins")]
impl WasmDriverPlugin {
    /// The only `plugin_version` value this host accepts.
    const ABI_VERSION: i32 = 1;

    /// Compiles, instantiates, and validates the WebAssembly module at `path`.
    ///
    /// Required export:
    /// - `plugin_version() -> i32`, which must return 1.
    ///
    /// Optional exports:
    /// - `plugin_name() -> i32`
    /// - `plugin_extensions() -> i32`
    ///
    /// Each optional export returns a pointer to a NUL-terminated UTF-8 string
    /// in the exported `memory`. The extensions string is comma-separated.
    /// A missing or unreadable name falls back to `"unknown"`, and missing
    /// extensions yield an empty list.
    ///
    /// Every successful call leaks its extension strings via `Box::leak`, so
    /// they can be returned as `&'static str`.
    ///
    /// # Errors
    /// - [`PluginError::LoadFailed`]: the file cannot be read or compiled, or
    ///   the module cannot be instantiated. No host imports are provided.
    /// - [`PluginError::MissingExport`]: `plugin_version` is absent or has the
    ///   wrong signature.
    /// - [`PluginError::Wasmtime`]: calling `plugin_version` fails, e.g. traps.
    /// - [`PluginError::AbiVersionMismatch`]: the version is not `ABI_VERSION`.
    pub fn from_file(path: &std::path::Path) -> Result<Self, PluginError> {
        let engine = wasmtime::Engine::default();
        let module = wasmtime::Module::from_file(&engine, path)
            .map_err(|e| PluginError::LoadFailed(e.to_string()))?;

        let mut store = wasmtime::Store::new(&engine, ());
        // The linker defines no host functions yet, so any module with imports
        // fails to instantiate below.
        // TODO: Add `suture.alloc` host import for WASM memory allocation.
        // WASM modules that call `suture.alloc` will fail at link time until this is implemented.
        let linker = wasmtime::Linker::new(&engine);

        let instance = linker
            .instantiate(&mut store, &module)
            .map_err(|e| PluginError::LoadFailed(e.to_string()))?;

        let version = Self::call_version_export(&mut store, &instance)?;
        if version != Self::ABI_VERSION {
            return Err(PluginError::AbiVersionMismatch {
                expected: Self::ABI_VERSION,
                actual: version,
            });
        }

        let name = Self::call_string_export(&mut store, &instance, "plugin_name")
            .unwrap_or_else(|| "unknown".to_string());
        let extensions_storage = Self::call_extensions_export(&mut store, &instance);

        // `DriverPlugin::extensions` returns `&[&str]` borrowed from `self`.
        // Owned `String`s cannot be viewed that way without a self-referential
        // struct, so each string is leaked once per load.
        //
        // The closure's return type is spelled out so the leaked `&mut str`
        // coerces to `&'static str` before collecting.
        let extensions: Vec<&'static str> = extensions_storage
            .iter()
            .map(|s| -> &'static str { Box::leak(s.clone().into_boxed_str()) })
            .collect();

        Ok(Self {
            engine,
            instance,
            store,
            name,
            extensions_storage,
            extensions,
        })
    }

    /// Calls the required `plugin_version() -> i32` export.
    ///
    /// Fails with [`PluginError::MissingExport`] when the export cannot be
    /// resolved with that signature. A failing call maps to
    /// [`PluginError::Wasmtime`].
    fn call_version_export(
        store: &mut wasmtime::Store<()>,
        instance: &wasmtime::Instance,
    ) -> Result<i32, PluginError> {
        let func = instance
            .get_typed_func::<(), i32>(&mut *store, "plugin_version")
            .map_err(|_| PluginError::MissingExport("plugin_version".to_string()))?;
        let version = func.call(&mut *store, ())?;
        Ok(version)
    }

    /// Calls a `() -> i32` export that returns a pointer to a NUL-terminated
    /// string in the exported `memory`, and decodes that string as UTF-8.
    ///
    /// Returns `None` in any of these cases:
    /// - the export or `memory` is missing, or the call fails;
    /// - the pointer is out of bounds;
    /// - no NUL terminator appears before the end of memory;
    /// - the bytes are not valid UTF-8.
    fn call_string_export(
        store: &mut wasmtime::Store<()>,
        instance: &wasmtime::Instance,
        export_name: &str,
    ) -> Option<String> {
        let func = instance
            .get_typed_func::<(), i32>(&mut *store, export_name)
            .ok()?;
        let ptr = func.call(&mut *store, ()).ok()?;
        let memory = instance.get_memory(&mut *store, "memory")?;
        let data = memory.data(&mut *store);

        // wasm32 addresses are unsigned 32-bit values carried in an `i32`
        // return slot. Reinterpret the bits so pointers at or above 2 GiB are
        // not sign-extended into an out-of-range `usize`.
        let start = ptr as u32 as usize;
        let tail = data.get(start..)?;
        let len = tail.iter().position(|&b| b == 0)?;
        std::str::from_utf8(&tail[..len]).ok().map(str::to_owned)
    }

    /// Reads the comma-separated `plugin_extensions` string.
    ///
    /// Items are trimmed and empty ones are dropped. An absent or unreadable
    /// export yields an empty list.
    fn call_extensions_export(
        store: &mut wasmtime::Store<()>,
        instance: &wasmtime::Instance,
    ) -> Vec<String> {
        Self::call_string_export(store, instance, "plugin_extensions")
            .map(|csv| {
                csv.split(',')
                    .map(str::trim)
                    .filter(|s| !s.is_empty())
                    .map(str::to_string)
                    .collect()
            })
            .unwrap_or_default()
    }
}
319
#[cfg(feature = "wasm-plugins")]
impl DriverPlugin for WasmDriverPlugin {
    /// Name taken from the module's `plugin_name` export (or `"unknown"`).
    fn name(&self) -> &str {
        self.name.as_str()
    }

    /// Extensions taken from the module's `plugin_extensions` export.
    fn extensions(&self) -> &[&str] {
        self.extensions.as_slice()
    }

    /// Fixed text; no description export is queried from the module.
    fn description(&self) -> &str {
        "WASM plugin driver"
    }

    /// This type implements `SutureDriver` itself, so it is its own driver.
    fn as_driver(&self) -> &dyn SutureDriver {
        self
    }
}
338
/// Driver entry points for WASM plugins.
///
/// Metadata is served from the values read at load time. The diff operations
/// are stubs that always fail with `DriverError::ParseError`.
#[cfg(feature = "wasm-plugins")]
impl SutureDriver for WasmDriverPlugin {
    fn name(&self) -> &str {
        self.name.as_str()
    }

    fn supported_extensions(&self) -> &[&str] {
        self.extensions.as_slice()
    }

    /// Not implemented yet; always returns an error.
    fn diff(
        &self,
        _base_content: Option<&str>,
        _new_content: &str,
    ) -> Result<Vec<crate::SemanticChange>, crate::DriverError> {
        let message = String::from("WASM plugin diff is not yet implemented");
        Err(crate::DriverError::ParseError(message))
    }

    /// Not implemented yet; always returns an error.
    fn format_diff(
        &self,
        _base_content: Option<&str>,
        _new_content: &str,
    ) -> Result<String, crate::DriverError> {
        let message = String::from("WASM plugin format_diff is not yet implemented");
        Err(crate::DriverError::ParseError(message))
    }
}
369
#[cfg(test)]
mod tests {
    use super::*;
    use crate::DriverError;

    /// Stand-in driver with a fixed name and extension list whose diffs are
    /// always empty.
    struct MockDriver {
        driver_name: &'static str,
        driver_extensions: Vec<&'static str>,
    }

    impl MockDriver {
        fn new(name: &'static str, extensions: Vec<&'static str>) -> Self {
            Self {
                driver_name: name,
                driver_extensions: extensions,
            }
        }
    }

    impl SutureDriver for MockDriver {
        fn name(&self) -> &str {
            self.driver_name
        }
        fn supported_extensions(&self) -> &[&str] {
            &self.driver_extensions
        }
        fn diff(
            &self,
            _base_content: Option<&str>,
            _new_content: &str,
        ) -> Result<Vec<crate::SemanticChange>, DriverError> {
            Ok(vec![])
        }
        fn format_diff(
            &self,
            _base_content: Option<&str>,
            _new_content: &str,
        ) -> Result<String, DriverError> {
            Ok(String::new())
        }
    }

    /// Builds a `BuiltinDriverPlugin` around a `MockDriver` whose own metadata
    /// matches the plugin's.
    fn make_plugin(name: &'static str, extensions: Vec<&'static str>) -> Arc<dyn DriverPlugin> {
        Arc::new(BuiltinDriverPlugin::new(
            name,
            extensions.clone(),
            "test driver",
            MockDriver::new(name, extensions),
        ))
    }

    #[test]
    fn register_and_get_by_name() {
        let mut reg = PluginRegistry::new();
        reg.register(make_plugin("json", vec![".json"]));
        assert!(reg.get("json").is_some());
        assert!(reg.get("yaml").is_none());
        assert_eq!(reg.get("json").unwrap().name(), "json");
    }

    #[test]
    fn get_by_extension_with_dot() {
        let mut reg = PluginRegistry::new();
        reg.register(make_plugin("json", vec![".json"]));
        assert!(reg.get_by_extension(".json").is_some());
        assert!(reg.get_by_extension(".yaml").is_none());
    }

    #[test]
    fn get_by_extension_without_dot() {
        let mut reg = PluginRegistry::new();
        reg.register(make_plugin("yaml", vec![".yaml", ".yml"]));
        assert!(reg.get_by_extension("yaml").is_some());
        assert!(reg.get_by_extension("yml").is_some());
    }

    #[test]
    fn list_drivers_sorted() {
        let mut reg = PluginRegistry::new();
        reg.register(make_plugin("csv", vec![".csv"]));
        reg.register(make_plugin("xml", vec![".xml"]));
        reg.register(make_plugin("json", vec![".json"]));
        assert_eq!(reg.list_drivers(), vec!["csv", "json", "xml"]);
    }

    #[test]
    fn discover_plugins_nonexistent_dir() {
        let mut reg = PluginRegistry::new();
        reg.discover_plugins(Path::new("/tmp/suture-test-nonexistent-12345"));
        assert!(reg.list_drivers().is_empty());
    }

    #[test]
    fn parse_plugin_descriptor_valid() {
        let content = r#"
name = "my-driver"
extensions = [".custom", ".ext"]
description = "A custom driver"
"#;
        let desc = PluginRegistry::parse_plugin_descriptor(content).unwrap();
        assert_eq!(desc.name, "my-driver");
        assert_eq!(desc.extensions, vec![".custom", ".ext"]);
        assert_eq!(desc.description, "A custom driver");
    }

    #[test]
    fn parse_plugin_descriptor_missing_name() {
        let content = r#"extensions = [".custom"]"#;
        assert!(PluginRegistry::parse_plugin_descriptor(content).is_none());
    }

    #[test]
    fn as_driver_returns_underlying_driver() {
        let mut reg = PluginRegistry::new();
        reg.register(make_plugin("json", vec![".json"]));
        let plugin = reg.get("json").unwrap();
        assert_eq!(plugin.as_driver().name(), "json");
        assert_eq!(plugin.as_driver().supported_extensions(), &[".json"]);
    }

    #[cfg(feature = "wasm-plugins")]
    #[test]
    fn test_wasm_plugin_abi_documentation() {
        let abi_path = Path::new(env!("CARGO_MANIFEST_DIR")).join("src/wasm_abi.md");
        let content = std::fs::read_to_string(&abi_path)
            .expect("wasm_abi.md should exist");
        assert!(content.contains("plugin_name"), "ABI doc should define plugin_name export");
        assert!(content.contains("plugin_extensions"), "ABI doc should define plugin_extensions export");
        assert!(content.contains("plugin_version"), "ABI doc should define plugin_version export");
        assert!(content.contains("merge"), "ABI doc should define merge function");
        assert!(content.contains("diff"), "ABI doc should define diff function");
        assert!(content.contains("ABI Version"), "ABI doc should specify version");
        assert!(content.contains("Memory Layout"), "ABI doc should specify memory layout");
        assert!(content.contains("Error Handling"), "ABI doc should specify error handling");
    }

    #[cfg(feature = "wasm-plugins")]
    #[test]
    fn test_plugin_registry_load_wasm_missing_file() {
        let mut reg = PluginRegistry::new();
        let result = reg.load_wasm_plugin(Path::new("/tmp/nonexistent-plugin.wasm"));
        assert!(result.is_err());
        match result.unwrap_err() {
            PluginError::LoadFailed(msg) => {
                assert!(msg.contains("failed to read") || msg.contains("No such file"));
            }
            other => panic!("expected LoadFailed, got: {other}"),
        }
    }

    #[cfg(feature = "wasm-plugins")]
    #[test]
    fn test_plugin_registry_load_wasm_invalid_module() {
        // Include the process id so concurrently running test binaries (e.g.
        // two checkouts on one host) never share, or delete, each other's
        // directory.
        let dir = std::env::temp_dir()
            .join(format!("suture-wasm-test-invalid-{}", std::process::id()));
        std::fs::create_dir_all(&dir).unwrap();
        let path = dir.join("invalid.wasm");
        std::fs::write(&path, b"not a valid wasm module").unwrap();

        let mut reg = PluginRegistry::new();
        let result = reg.load_wasm_plugin(&path);

        // Clean up before asserting so a failing assertion doesn't leave the
        // fixture behind.
        let _ = std::fs::remove_dir_all(&dir);
        assert!(result.is_err());
    }
}