Skip to main content

torsh_jit/
plugin_system.rs

1//! Plugin System for ToRSh JIT Compilation
2//!
3//! This module provides a dynamic plugin system that allows loading and
4//! registering custom functionality at runtime. Plugins can provide:
5//! - Custom operators
6//! - Optimization passes
7//! - Backend implementations
8//! - Type systems
9//! - Debug tools
10
11use crate::{custom_ops::CustomOpBuilder, JitError, JitResult};
12use std::collections::HashMap;
13use std::ffi::OsStr;
14use std::path::{Path, PathBuf};
15use std::sync::{Arc, RwLock};
16
/// Plugin interface version used for compatibility checking.
///
/// A plugin whose `PluginMetadata::api_version` differs from this value is
/// rejected by `DynamicPlugin::initialize` and by `discovery::validate_plugin`.
pub const PLUGIN_API_VERSION: u32 = 1;
19
/// Descriptive information about a plugin.
///
/// The registry keys loaded plugins by `name`, and `api_version` is compared
/// against `PLUGIN_API_VERSION` before a plugin is initialized.
#[derive(Debug, Clone)]
pub struct PluginMetadata {
    /// Plugin name; used as the key under which the registry stores the plugin
    pub name: String,

    /// Plugin version string
    pub version: String,

    /// Human-readable plugin description
    pub description: String,

    /// Plugin author
    pub author: String,

    /// Plugin API version the plugin was built against; must equal
    /// `PLUGIN_API_VERSION` for the plugin to be accepted
    pub api_version: u32,

    /// Plugin dependencies.
    /// NOTE(review): presumably names of other plugins; the loading code in
    /// this module does not resolve or check them.
    pub dependencies: Vec<String>,

    /// Capabilities the plugin advertises
    pub capabilities: Vec<PluginCapability>,
}
44
/// Kinds of functionality a plugin can advertise in its metadata.
///
/// The type is comparable and hashable so callers can test for a capability
/// with `capabilities.contains(&PluginCapability::CustomOperators)` or collect
/// capabilities into a `HashSet`.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum PluginCapability {
    /// Provides custom operators
    CustomOperators,

    /// Provides optimization passes
    OptimizationPasses,

    /// Provides a backend implementation; the payload is a backend identifier
    BackendImplementation(String),

    /// Provides type systems
    TypeSystem,

    /// Provides debugging tools
    DebuggingTools,

    /// Any capability not covered by the variants above, described by a
    /// free-form label
    Custom(String),
}
66
/// Interface that every plugin must implement.
///
/// `PluginRegistry::load_plugin` drives a plugin (through its `DynamicPlugin`
/// wrapper) through `initialize` and then `register`;
/// `PluginRegistry::unload_plugin` calls `cleanup`. Implementors must be
/// `Send + Sync`.
pub trait Plugin: Send + Sync {
    /// Get plugin metadata
    fn metadata(&self) -> &PluginMetadata;

    /// Initialize the plugin with the context supplied by the host.
    ///
    /// Called before `register` when the plugin is loaded through the registry.
    fn initialize(&mut self, context: &PluginContext) -> JitResult<()>;

    /// Register the plugin's functionality (operator builders, passes,
    /// backends) with `registry`.
    fn register(&self, registry: &mut PluginRegistry) -> JitResult<()>;

    /// Release plugin state; called when the plugin is unloaded
    fn cleanup(&mut self) -> JitResult<()>;
}
81
/// Host information handed to a plugin during initialization.
///
/// `PluginRegistry::load_plugin` builds one of these for every plugin it
/// loads, currently with an empty `config`.
#[derive(Debug)]
pub struct PluginContext {
    /// JIT compiler version string
    pub jit_version: String,

    /// Names of the features the host makes available
    pub features: Vec<String>,

    /// Configuration parameters as string key/value pairs
    pub config: HashMap<String, String>,
}
94
/// Wrapper around a plugin loaded from a dynamic library.
///
/// Library loading is currently simulated: no shared object is actually
/// opened, and the wrapped plugin is always an `ExamplePlugin`.
pub struct DynamicPlugin {
    /// Metadata produced when the plugin was loaded
    metadata: PluginMetadata,

    /// Dynamic library handle (conceptual - would use libloading in real implementation).
    /// At present this holds the library path as a lossily converted string.
    _lib_handle: String,

    /// The plugin instance that the wrapper's methods forward to
    plugin: Box<dyn Plugin>,
}
106
107impl DynamicPlugin {
108    /// Load plugin from dynamic library
109    pub fn load<P: AsRef<Path>>(path: P) -> JitResult<Self> {
110        let path = path.as_ref();
111
112        // In a real implementation, this would use libloading or similar
113        // For now, we'll simulate it
114        let lib_path = path.to_string_lossy().to_string();
115
116        // Validate file exists and is a valid library
117        if !path.exists() {
118            return Err(JitError::RuntimeError(format!(
119                "Plugin file not found: {}",
120                path.display()
121            )));
122        }
123
124        // Load library symbols (simulated)
125        let metadata = Self::load_metadata(&lib_path)?;
126        let plugin = Self::create_plugin_instance(&lib_path, &metadata)?;
127
128        Ok(Self {
129            metadata,
130            _lib_handle: lib_path,
131            plugin,
132        })
133    }
134
135    /// Load plugin metadata from library
136    fn load_metadata(lib_path: &str) -> JitResult<PluginMetadata> {
137        // Simulated metadata loading
138        // In real implementation, would load from library symbols
139        let name = Path::new(lib_path)
140            .file_stem()
141            .and_then(OsStr::to_str)
142            .unwrap_or("unknown")
143            .to_string();
144
145        Ok(PluginMetadata {
146            name,
147            version: "1.0.0".to_string(),
148            description: "Dynamically loaded plugin".to_string(),
149            author: "Unknown".to_string(),
150            api_version: PLUGIN_API_VERSION,
151            dependencies: vec![],
152            capabilities: vec![PluginCapability::CustomOperators],
153        })
154    }
155
156    /// Create plugin instance from library
157    fn create_plugin_instance(
158        _lib_path: &str,
159        metadata: &PluginMetadata,
160    ) -> JitResult<Box<dyn Plugin>> {
161        // Simulated plugin instantiation
162        // In real implementation, would call library constructor
163        Ok(Box::new(ExamplePlugin::new(metadata.clone())))
164    }
165
166    /// Get plugin metadata
167    pub fn metadata(&self) -> &PluginMetadata {
168        &self.metadata
169    }
170
171    /// Initialize plugin
172    pub fn initialize(&mut self, context: &PluginContext) -> JitResult<()> {
173        // Check API compatibility
174        if self.metadata.api_version != PLUGIN_API_VERSION {
175            return Err(JitError::RuntimeError(format!(
176                "Plugin API version mismatch: expected {}, got {}",
177                PLUGIN_API_VERSION, self.metadata.api_version
178            )));
179        }
180
181        self.plugin.initialize(context)
182    }
183
184    /// Register plugin functionality
185    pub fn register(&self, registry: &mut PluginRegistry) -> JitResult<()> {
186        self.plugin.register(registry)
187    }
188
189    /// Cleanup plugin
190    pub fn cleanup(&mut self) -> JitResult<()> {
191        self.plugin.cleanup()
192    }
193}
194
/// Registry that owns loaded plugins and everything they have registered.
pub struct PluginRegistry {
    /// Loaded plugins, keyed by the `name` in their metadata
    plugins: HashMap<String, DynamicPlugin>,

    /// Custom operator builder factories, in registration order
    custom_op_builders: Vec<Box<dyn Fn() -> JitResult<CustomOpBuilder> + Send + Sync>>,

    /// Optimization pass factories, in registration order
    optimization_passes: Vec<Box<dyn Fn() -> JitResult<Box<dyn OptimizationPass>> + Send + Sync>>,

    /// Backend implementations, keyed by the value of `Backend::name`
    backend_impls: HashMap<String, Box<dyn Backend + Send + Sync>>,

    /// Directories scanned by `load_all_plugins`
    search_paths: Vec<PathBuf>,
}
212
/// Optimization pass that a plugin can contribute.
///
/// Passes are registered with the registry as factories that produce boxed
/// trait objects of this type.
pub trait OptimizationPass: Send + Sync {
    /// Pass name
    fn name(&self) -> &str;

    /// Apply the pass to `graph`, possibly mutating it.
    ///
    /// NOTE(review): the `bool` presumably reports whether the graph was
    /// changed — confirm against the pass scheduler.
    fn apply(&self, graph: &mut crate::ComputationGraph) -> JitResult<bool>;

    /// Pass dependencies (other passes that must run before this one)
    fn dependencies(&self) -> Vec<String>;
}
224
/// Compilation backend that a plugin can contribute.
///
/// The registry stores backends under the string returned by `name`.
pub trait Backend: Send + Sync {
    /// Backend name; used as the registry key for this backend
    fn name(&self) -> &str;

    /// Compile graph to backend-specific representation
    fn compile(&self, graph: &crate::ComputationGraph) -> JitResult<Box<dyn CompiledCode>>;

    /// Check if operation is supported
    fn supports_operation(&self, op: &crate::graph::Operation) -> bool;
}
236
/// Executable artifact produced by `Backend::compile`.
pub trait CompiledCode: Send + Sync {
    /// Execute the compiled code on `inputs` and return the output tensors
    fn execute(&self, inputs: &[crate::TensorRef]) -> JitResult<Vec<crate::TensorRef>>;

    /// Get execution statistics
    fn stats(&self) -> ExecutionStats;
}
245
/// Execution statistics reported by `CompiledCode::stats`.
///
/// `Default` yields an all-zero record, and the type is comparable so that
/// statistics can be checked directly in assertions.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct ExecutionStats {
    /// Time spent executing
    pub execution_time: std::time::Duration,
    /// Memory used during execution.
    /// NOTE(review): presumably in bytes — confirm with backend implementors.
    pub memory_usage: usize,
    /// Number of operations executed
    pub operations_count: usize,
}
253
254impl Default for PluginRegistry {
255    fn default() -> Self {
256        Self::new()
257    }
258}
259
260impl PluginRegistry {
261    /// Create a new plugin registry
262    pub fn new() -> Self {
263        Self {
264            plugins: HashMap::new(),
265            custom_op_builders: Vec::new(),
266            optimization_passes: Vec::new(),
267            backend_impls: HashMap::new(),
268            search_paths: vec![
269                PathBuf::from("./plugins"),
270                PathBuf::from("/usr/local/lib/torsh/plugins"),
271                PathBuf::from("~/.torsh/plugins"),
272            ],
273        }
274    }
275
276    /// Add plugin search path
277    pub fn add_search_path<P: AsRef<Path>>(&mut self, path: P) {
278        self.search_paths.push(path.as_ref().to_path_buf());
279    }
280
281    /// Load plugin from file
282    pub fn load_plugin<P: AsRef<Path>>(&mut self, path: P) -> JitResult<()> {
283        let mut plugin = DynamicPlugin::load(path)?;
284
285        let context = PluginContext {
286            jit_version: "0.1.0".to_string(),
287            features: vec!["custom_ops".to_string(), "optimization".to_string()],
288            config: HashMap::new(),
289        };
290
291        plugin.initialize(&context)?;
292        plugin.register(self)?;
293
294        let plugin_name = plugin.metadata().name.clone();
295        self.plugins.insert(plugin_name, plugin);
296
297        Ok(())
298    }
299
300    /// Load all plugins from search paths
301    pub fn load_all_plugins(&mut self) -> JitResult<Vec<String>> {
302        let mut loaded_plugins = Vec::new();
303
304        for search_path in &self.search_paths.clone() {
305            if let Ok(entries) = std::fs::read_dir(search_path) {
306                for entry in entries.flatten() {
307                    let path = entry.path();
308                    if self.is_plugin_file(&path) {
309                        match self.load_plugin(&path) {
310                            Ok(()) => {
311                                if let Some(filename) = path.file_name() {
312                                    loaded_plugins.push(filename.to_string_lossy().to_string());
313                                }
314                            }
315                            Err(e) => {
316                                eprintln!("Failed to load plugin {}: {}", path.display(), e);
317                            }
318                        }
319                    }
320                }
321            }
322        }
323
324        Ok(loaded_plugins)
325    }
326
327    /// Check if file is a plugin
328    fn is_plugin_file(&self, path: &Path) -> bool {
329        if let Some(extension) = path.extension() {
330            match extension.to_str() {
331                Some("so") | Some("dll") | Some("dylib") => true,
332                _ => false,
333            }
334        } else {
335            false
336        }
337    }
338
339    /// Find plugin by name
340    pub fn find_plugin(&self, name: &str) -> Option<&DynamicPlugin> {
341        self.plugins.get(name)
342    }
343
344    /// Unload plugin
345    pub fn unload_plugin(&mut self, name: &str) -> JitResult<()> {
346        if let Some(mut plugin) = self.plugins.remove(name) {
347            plugin.cleanup()?;
348        }
349        Ok(())
350    }
351
352    /// List loaded plugins
353    pub fn list_plugins(&self) -> Vec<&PluginMetadata> {
354        self.plugins.values().map(|p| p.metadata()).collect()
355    }
356
357    /// Register custom operator builder
358    pub fn register_custom_op_builder<F>(&mut self, builder: F)
359    where
360        F: Fn() -> JitResult<CustomOpBuilder> + Send + Sync + 'static,
361    {
362        self.custom_op_builders.push(Box::new(builder));
363    }
364
365    /// Register optimization pass
366    pub fn register_optimization_pass<F>(&mut self, factory: F)
367    where
368        F: Fn() -> JitResult<Box<dyn OptimizationPass>> + Send + Sync + 'static,
369    {
370        self.optimization_passes.push(Box::new(factory));
371    }
372
373    /// Register backend implementation
374    pub fn register_backend(&mut self, backend: Box<dyn Backend + Send + Sync>) {
375        let name = backend.name().to_string();
376        self.backend_impls.insert(name, backend);
377    }
378
379    /// Get custom operator builders
380    pub fn get_custom_op_builders(
381        &self,
382    ) -> &[Box<dyn Fn() -> JitResult<CustomOpBuilder> + Send + Sync>] {
383        &self.custom_op_builders
384    }
385
386    /// Get optimization passes
387    pub fn get_optimization_passes(
388        &self,
389    ) -> &[Box<dyn Fn() -> JitResult<Box<dyn OptimizationPass>> + Send + Sync>] {
390        &self.optimization_passes
391    }
392
393    /// Get backend implementation
394    pub fn get_backend(&self, name: &str) -> Option<&(dyn Backend + Send + Sync)> {
395        self.backend_impls.get(name).map(|b| b.as_ref())
396    }
397
398    /// List available backends
399    pub fn list_backends(&self) -> Vec<&str> {
400        self.backend_impls.keys().map(|s| s.as_str()).collect()
401    }
402}
403
// Process-wide plugin registry, created on first access and shared behind an
// `Arc<RwLock<..>>`. Obtain a handle through `global_registry()`.
lazy_static::lazy_static! {
    static ref GLOBAL_REGISTRY: Arc<RwLock<PluginRegistry>> =
        Arc::new(RwLock::new(PluginRegistry::new()));
}
409
410/// Get global plugin registry
411pub fn global_registry() -> Arc<RwLock<PluginRegistry>> {
412    GLOBAL_REGISTRY.clone()
413}
414
415/// Load plugin into global registry
416pub fn load_plugin<P: AsRef<Path>>(path: P) -> JitResult<()> {
417    let binding = global_registry();
418    let mut registry = binding
419        .write()
420        .map_err(|_| JitError::RuntimeError("Failed to acquire registry lock".to_string()))?;
421    registry.load_plugin(path)
422}
423
424/// Load all plugins from search paths
425pub fn load_all_plugins() -> JitResult<Vec<String>> {
426    let binding = global_registry();
427    let mut registry = binding
428        .write()
429        .map_err(|_| JitError::RuntimeError("Failed to acquire registry lock".to_string()))?;
430    registry.load_all_plugins()
431}
432
/// High-level facade over a shared `PluginRegistry`.
///
/// Every operation acquires the registry's lock for the duration of the call.
pub struct PluginManager {
    /// Registry this manager operates on (the global one unless built with
    /// `with_registry`)
    registry: Arc<RwLock<PluginRegistry>>,
    /// When true, `initialize` loads all plugins from the search paths
    auto_load: bool,
}
438
439impl Default for PluginManager {
440    fn default() -> Self {
441        Self::new()
442    }
443}
444
445impl PluginManager {
446    /// Create a new plugin manager
447    pub fn new() -> Self {
448        Self {
449            registry: global_registry(),
450            auto_load: true,
451        }
452    }
453
454    /// Create plugin manager with custom registry
455    pub fn with_registry(registry: Arc<RwLock<PluginRegistry>>) -> Self {
456        Self {
457            registry,
458            auto_load: true,
459        }
460    }
461
462    /// Enable/disable auto-loading plugins
463    pub fn set_auto_load(&mut self, auto_load: bool) {
464        self.auto_load = auto_load;
465    }
466
467    /// Initialize plugin system
468    pub fn initialize(&self) -> JitResult<()> {
469        if self.auto_load {
470            self.load_all_plugins()?;
471        }
472        Ok(())
473    }
474
475    /// Load plugin
476    pub fn load_plugin<P: AsRef<Path>>(&self, path: P) -> JitResult<()> {
477        let mut registry = self
478            .registry
479            .write()
480            .map_err(|_| JitError::RuntimeError("Failed to acquire registry lock".to_string()))?;
481        registry.load_plugin(path)
482    }
483
484    /// Load all plugins
485    pub fn load_all_plugins(&self) -> JitResult<Vec<String>> {
486        let mut registry = self
487            .registry
488            .write()
489            .map_err(|_| JitError::RuntimeError("Failed to acquire registry lock".to_string()))?;
490        registry.load_all_plugins()
491    }
492
493    /// Unload plugin
494    pub fn unload_plugin(&self, name: &str) -> JitResult<()> {
495        let mut registry = self
496            .registry
497            .write()
498            .map_err(|_| JitError::RuntimeError("Failed to acquire registry lock".to_string()))?;
499        registry.unload_plugin(name)
500    }
501
502    /// List plugins
503    pub fn list_plugins(&self) -> Vec<PluginMetadata> {
504        match self.registry.read() {
505            Ok(registry) => registry.list_plugins().into_iter().cloned().collect(),
506            Err(_) => vec![],
507        }
508    }
509
510    /// Get plugin info
511    pub fn get_plugin_info(&self, name: &str) -> Option<PluginMetadata> {
512        let registry = self.registry.read().ok()?;
513        registry.find_plugin(name).map(|p| p.metadata().clone())
514    }
515
516    /// Check if plugin is loaded
517    pub fn is_plugin_loaded(&self, name: &str) -> bool {
518        match self.registry.read() {
519            Ok(registry) => registry.find_plugin(name).is_some(),
520            Err(_) => false,
521        }
522    }
523}
524
/// Example plugin implementation.
///
/// This is also the plugin type that the simulated `DynamicPlugin` loader
/// instantiates for every library.
pub struct ExamplePlugin {
    /// Metadata handed to the constructor
    metadata: PluginMetadata,
    /// Set by `initialize`, cleared by `cleanup`; `register` requires it
    initialized: bool,
}
530
531impl ExamplePlugin {
532    pub fn new(metadata: PluginMetadata) -> Self {
533        Self {
534            metadata,
535            initialized: false,
536        }
537    }
538}
539
540impl Plugin for ExamplePlugin {
541    fn metadata(&self) -> &PluginMetadata {
542        &self.metadata
543    }
544
545    fn initialize(&mut self, _context: &PluginContext) -> JitResult<()> {
546        self.initialized = true;
547        Ok(())
548    }
549
550    fn register(&self, registry: &mut PluginRegistry) -> JitResult<()> {
551        if !self.initialized {
552            return Err(JitError::RuntimeError("Plugin not initialized".to_string()));
553        }
554
555        // Register a sample custom operator
556        registry.register_custom_op_builder(|| {
557            Ok(CustomOpBuilder::new("plugin_add")
558                .namespace("example")
559                .forward(|inputs| {
560                    if inputs.len() != 2 {
561                        return Err(JitError::RuntimeError(
562                            "plugin_add requires 2 inputs".to_string(),
563                        ));
564                    }
565
566                    let a = &inputs[0];
567                    let b = &inputs[1];
568                    let mut result = a.clone();
569
570                    for (i, &val_b) in b.data.iter().enumerate() {
571                        if i < result.data.len() {
572                            result.data[i] += val_b;
573                        }
574                    }
575
576                    Ok(vec![result])
577                })
578                .vectorizable(true)
579                .parallelizable(true)
580                .elementwise(true))
581        });
582
583        Ok(())
584    }
585
586    fn cleanup(&mut self) -> JitResult<()> {
587        self.initialized = false;
588        Ok(())
589    }
590}
591
592/// Plugin discovery utilities
593pub mod discovery {
594    use super::*;
595
596    /// Discover plugins in a directory
597    pub fn discover_plugins<P: AsRef<Path>>(path: P) -> JitResult<Vec<PathBuf>> {
598        let mut plugins = Vec::new();
599        let path = path.as_ref();
600
601        if !path.exists() {
602            return Ok(plugins);
603        }
604
605        for entry in std::fs::read_dir(path)
606            .map_err(|e| JitError::RuntimeError(format!("Failed to read directory: {}", e)))?
607        {
608            let entry = entry
609                .map_err(|e| JitError::RuntimeError(format!("Failed to read entry: {}", e)))?;
610            let path = entry.path();
611
612            if is_plugin_file(&path) {
613                plugins.push(path);
614            }
615        }
616
617        Ok(plugins)
618    }
619
620    /// Check if file is a plugin
621    fn is_plugin_file(path: &Path) -> bool {
622        if let Some(extension) = path.extension() {
623            matches!(extension.to_str(), Some("so") | Some("dll") | Some("dylib"))
624        } else {
625            false
626        }
627    }
628
629    /// Validate plugin compatibility
630    pub fn validate_plugin(metadata: &PluginMetadata) -> JitResult<()> {
631        if metadata.api_version != PLUGIN_API_VERSION {
632            return Err(JitError::RuntimeError(format!(
633                "Incompatible plugin API version: expected {}, got {}",
634                PLUGIN_API_VERSION, metadata.api_version
635            )));
636        }
637
638        Ok(())
639    }
640}
641
#[cfg(test)]
mod tests {
    use super::*;

    /// Build metadata for the current API version that advertises custom
    /// operators and has no dependencies.
    fn metadata_for(name: &str, description: &str, author: &str) -> PluginMetadata {
        PluginMetadata {
            name: name.to_string(),
            version: "1.0.0".to_string(),
            description: description.to_string(),
            author: author.to_string(),
            api_version: PLUGIN_API_VERSION,
            dependencies: vec![],
            capabilities: vec![PluginCapability::CustomOperators],
        }
    }

    #[test]
    fn test_plugin_metadata() {
        let metadata = metadata_for("test_plugin", "Test plugin", "Test Author");

        assert_eq!(metadata.name, "test_plugin");
        assert_eq!(metadata.api_version, PLUGIN_API_VERSION);
    }

    #[test]
    fn test_plugin_registry() {
        let mut registry = PluginRegistry::new();

        // A fresh registry holds no plugins and no backends.
        assert_eq!(registry.list_plugins().len(), 0);
        assert_eq!(registry.list_backends().len(), 0);

        // Adding a search path extends the three defaults.
        let extra = std::env::temp_dir().join("plugins").display().to_string();
        registry.add_search_path(&extra);
        assert_eq!(registry.search_paths.len(), 4); // 3 default + 1 added
    }

    #[test]
    fn test_example_plugin() {
        let mut plugin = ExamplePlugin::new(metadata_for("example", "Example plugin", "Test"));
        assert!(!plugin.initialized);

        let context = PluginContext {
            jit_version: "0.1.0".to_string(),
            features: vec![],
            config: HashMap::new(),
        };

        // Initialization flips the flag that `register` checks.
        assert!(plugin.initialize(&context).is_ok());
        assert!(plugin.initialized);

        // Registration contributes exactly one custom operator builder.
        let mut registry = PluginRegistry::new();
        assert!(plugin.register(&mut registry).is_ok());
        assert_eq!(registry.custom_op_builders.len(), 1);
    }

    #[test]
    fn test_plugin_manager() {
        let manager = PluginManager::new();
        assert!(manager.auto_load);

        // No plugins loaded initially
        let plugins = manager.list_plugins();
        assert!(plugins.is_empty());
    }
}