1use crate::{custom_ops::CustomOpBuilder, JitError, JitResult};
12use std::collections::HashMap;
13use std::ffi::OsStr;
14use std::path::{Path, PathBuf};
15use std::sync::{Arc, RwLock};
16
/// Plugin API version implemented by this module.
///
/// Plugin metadata carries its own `api_version`; `DynamicPlugin::initialize`
/// and `discovery::validate_plugin` reject any value that differs from this one.
pub const PLUGIN_API_VERSION: u32 = 1;
19
20#[derive(Debug, Clone)]
22pub struct PluginMetadata {
23 pub name: String,
25
26 pub version: String,
28
29 pub description: String,
31
32 pub author: String,
34
35 pub api_version: u32,
37
38 pub dependencies: Vec<String>,
40
41 pub capabilities: Vec<PluginCapability>,
43}
44
/// Kinds of functionality a plugin can advertise in its metadata.
#[derive(Clone, Debug)]
pub enum PluginCapability {
    /// The plugin provides custom operators.
    CustomOperators,

    /// The plugin provides optimization passes.
    OptimizationPasses,

    /// The plugin provides a backend implementation.
    /// NOTE(review): the payload is presumably the backend name - confirm.
    BackendImplementation(String),

    /// The plugin extends the type system.
    TypeSystem,

    /// The plugin provides debugging tools.
    DebuggingTools,

    /// Any other capability, identified by a free-form string.
    Custom(String),
}
66
67pub trait Plugin: Send + Sync {
69 fn metadata(&self) -> &PluginMetadata;
71
72 fn initialize(&mut self, context: &PluginContext) -> JitResult<()>;
74
75 fn register(&self, registry: &mut PluginRegistry) -> JitResult<()>;
77
78 fn cleanup(&mut self) -> JitResult<()>;
80}
81
/// Host information handed to a plugin while it initializes.
#[derive(Debug)]
pub struct PluginContext {
    /// Version string of the host JIT.
    pub jit_version: String,

    /// Names of the features the host reports to the plugin.
    pub features: Vec<String>,

    /// Free-form string key/value configuration.
    pub config: HashMap<String, String>,
}
94
95pub struct DynamicPlugin {
97 metadata: PluginMetadata,
99
100 _lib_handle: String,
102
103 plugin: Box<dyn Plugin>,
105}
106
107impl DynamicPlugin {
108 pub fn load<P: AsRef<Path>>(path: P) -> JitResult<Self> {
110 let path = path.as_ref();
111
112 let lib_path = path.to_string_lossy().to_string();
115
116 if !path.exists() {
118 return Err(JitError::RuntimeError(format!(
119 "Plugin file not found: {}",
120 path.display()
121 )));
122 }
123
124 let metadata = Self::load_metadata(&lib_path)?;
126 let plugin = Self::create_plugin_instance(&lib_path, &metadata)?;
127
128 Ok(Self {
129 metadata,
130 _lib_handle: lib_path,
131 plugin,
132 })
133 }
134
135 fn load_metadata(lib_path: &str) -> JitResult<PluginMetadata> {
137 let name = Path::new(lib_path)
140 .file_stem()
141 .and_then(OsStr::to_str)
142 .unwrap_or("unknown")
143 .to_string();
144
145 Ok(PluginMetadata {
146 name,
147 version: "1.0.0".to_string(),
148 description: "Dynamically loaded plugin".to_string(),
149 author: "Unknown".to_string(),
150 api_version: PLUGIN_API_VERSION,
151 dependencies: vec![],
152 capabilities: vec![PluginCapability::CustomOperators],
153 })
154 }
155
156 fn create_plugin_instance(
158 _lib_path: &str,
159 metadata: &PluginMetadata,
160 ) -> JitResult<Box<dyn Plugin>> {
161 Ok(Box::new(ExamplePlugin::new(metadata.clone())))
164 }
165
166 pub fn metadata(&self) -> &PluginMetadata {
168 &self.metadata
169 }
170
171 pub fn initialize(&mut self, context: &PluginContext) -> JitResult<()> {
173 if self.metadata.api_version != PLUGIN_API_VERSION {
175 return Err(JitError::RuntimeError(format!(
176 "Plugin API version mismatch: expected {}, got {}",
177 PLUGIN_API_VERSION, self.metadata.api_version
178 )));
179 }
180
181 self.plugin.initialize(context)
182 }
183
184 pub fn register(&self, registry: &mut PluginRegistry) -> JitResult<()> {
186 self.plugin.register(registry)
187 }
188
189 pub fn cleanup(&mut self) -> JitResult<()> {
191 self.plugin.cleanup()
192 }
193}
194
195pub struct PluginRegistry {
197 plugins: HashMap<String, DynamicPlugin>,
199
200 custom_op_builders: Vec<Box<dyn Fn() -> JitResult<CustomOpBuilder> + Send + Sync>>,
202
203 optimization_passes: Vec<Box<dyn Fn() -> JitResult<Box<dyn OptimizationPass>> + Send + Sync>>,
205
206 backend_impls: HashMap<String, Box<dyn Backend + Send + Sync>>,
208
209 search_paths: Vec<PathBuf>,
211}
212
213pub trait OptimizationPass: Send + Sync {
215 fn name(&self) -> &str;
217
218 fn apply(&self, graph: &mut crate::ComputationGraph) -> JitResult<bool>;
220
221 fn dependencies(&self) -> Vec<String>;
223}
224
225pub trait Backend: Send + Sync {
227 fn name(&self) -> &str;
229
230 fn compile(&self, graph: &crate::ComputationGraph) -> JitResult<Box<dyn CompiledCode>>;
232
233 fn supports_operation(&self, op: &crate::graph::Operation) -> bool;
235}
236
237pub trait CompiledCode: Send + Sync {
239 fn execute(&self, inputs: &[crate::TensorRef]) -> JitResult<Vec<crate::TensorRef>>;
241
242 fn stats(&self) -> ExecutionStats;
244}
245
/// Statistics reported by [`CompiledCode::stats`].
#[derive(Clone, Debug)]
pub struct ExecutionStats {
    /// Time spent executing.
    pub execution_time: std::time::Duration,
    /// Memory used. NOTE(review): unit is presumably bytes - confirm.
    pub memory_usage: usize,
    /// Number of operations.
    pub operations_count: usize,
}
253
254impl Default for PluginRegistry {
255 fn default() -> Self {
256 Self::new()
257 }
258}
259
260impl PluginRegistry {
261 pub fn new() -> Self {
263 Self {
264 plugins: HashMap::new(),
265 custom_op_builders: Vec::new(),
266 optimization_passes: Vec::new(),
267 backend_impls: HashMap::new(),
268 search_paths: vec![
269 PathBuf::from("./plugins"),
270 PathBuf::from("/usr/local/lib/torsh/plugins"),
271 PathBuf::from("~/.torsh/plugins"),
272 ],
273 }
274 }
275
276 pub fn add_search_path<P: AsRef<Path>>(&mut self, path: P) {
278 self.search_paths.push(path.as_ref().to_path_buf());
279 }
280
281 pub fn load_plugin<P: AsRef<Path>>(&mut self, path: P) -> JitResult<()> {
283 let mut plugin = DynamicPlugin::load(path)?;
284
285 let context = PluginContext {
286 jit_version: "0.1.0".to_string(),
287 features: vec!["custom_ops".to_string(), "optimization".to_string()],
288 config: HashMap::new(),
289 };
290
291 plugin.initialize(&context)?;
292 plugin.register(self)?;
293
294 let plugin_name = plugin.metadata().name.clone();
295 self.plugins.insert(plugin_name, plugin);
296
297 Ok(())
298 }
299
300 pub fn load_all_plugins(&mut self) -> JitResult<Vec<String>> {
302 let mut loaded_plugins = Vec::new();
303
304 for search_path in &self.search_paths.clone() {
305 if let Ok(entries) = std::fs::read_dir(search_path) {
306 for entry in entries.flatten() {
307 let path = entry.path();
308 if self.is_plugin_file(&path) {
309 match self.load_plugin(&path) {
310 Ok(()) => {
311 if let Some(filename) = path.file_name() {
312 loaded_plugins.push(filename.to_string_lossy().to_string());
313 }
314 }
315 Err(e) => {
316 eprintln!("Failed to load plugin {}: {}", path.display(), e);
317 }
318 }
319 }
320 }
321 }
322 }
323
324 Ok(loaded_plugins)
325 }
326
327 fn is_plugin_file(&self, path: &Path) -> bool {
329 if let Some(extension) = path.extension() {
330 match extension.to_str() {
331 Some("so") | Some("dll") | Some("dylib") => true,
332 _ => false,
333 }
334 } else {
335 false
336 }
337 }
338
339 pub fn find_plugin(&self, name: &str) -> Option<&DynamicPlugin> {
341 self.plugins.get(name)
342 }
343
344 pub fn unload_plugin(&mut self, name: &str) -> JitResult<()> {
346 if let Some(mut plugin) = self.plugins.remove(name) {
347 plugin.cleanup()?;
348 }
349 Ok(())
350 }
351
352 pub fn list_plugins(&self) -> Vec<&PluginMetadata> {
354 self.plugins.values().map(|p| p.metadata()).collect()
355 }
356
357 pub fn register_custom_op_builder<F>(&mut self, builder: F)
359 where
360 F: Fn() -> JitResult<CustomOpBuilder> + Send + Sync + 'static,
361 {
362 self.custom_op_builders.push(Box::new(builder));
363 }
364
365 pub fn register_optimization_pass<F>(&mut self, factory: F)
367 where
368 F: Fn() -> JitResult<Box<dyn OptimizationPass>> + Send + Sync + 'static,
369 {
370 self.optimization_passes.push(Box::new(factory));
371 }
372
373 pub fn register_backend(&mut self, backend: Box<dyn Backend + Send + Sync>) {
375 let name = backend.name().to_string();
376 self.backend_impls.insert(name, backend);
377 }
378
379 pub fn get_custom_op_builders(
381 &self,
382 ) -> &[Box<dyn Fn() -> JitResult<CustomOpBuilder> + Send + Sync>] {
383 &self.custom_op_builders
384 }
385
386 pub fn get_optimization_passes(
388 &self,
389 ) -> &[Box<dyn Fn() -> JitResult<Box<dyn OptimizationPass>> + Send + Sync>] {
390 &self.optimization_passes
391 }
392
393 pub fn get_backend(&self, name: &str) -> Option<&(dyn Backend + Send + Sync)> {
395 self.backend_impls.get(name).map(|b| b.as_ref())
396 }
397
398 pub fn list_backends(&self) -> Vec<&str> {
400 self.backend_impls.keys().map(|s| s.as_str()).collect()
401 }
402}
403
404lazy_static::lazy_static! {
406 static ref GLOBAL_REGISTRY: Arc<RwLock<PluginRegistry>> =
407 Arc::new(RwLock::new(PluginRegistry::new()));
408}
409
410pub fn global_registry() -> Arc<RwLock<PluginRegistry>> {
412 GLOBAL_REGISTRY.clone()
413}
414
415pub fn load_plugin<P: AsRef<Path>>(path: P) -> JitResult<()> {
417 let binding = global_registry();
418 let mut registry = binding
419 .write()
420 .map_err(|_| JitError::RuntimeError("Failed to acquire registry lock".to_string()))?;
421 registry.load_plugin(path)
422}
423
424pub fn load_all_plugins() -> JitResult<Vec<String>> {
426 let binding = global_registry();
427 let mut registry = binding
428 .write()
429 .map_err(|_| JitError::RuntimeError("Failed to acquire registry lock".to_string()))?;
430 registry.load_all_plugins()
431}
432
433pub struct PluginManager {
435 registry: Arc<RwLock<PluginRegistry>>,
436 auto_load: bool,
437}
438
439impl Default for PluginManager {
440 fn default() -> Self {
441 Self::new()
442 }
443}
444
445impl PluginManager {
446 pub fn new() -> Self {
448 Self {
449 registry: global_registry(),
450 auto_load: true,
451 }
452 }
453
454 pub fn with_registry(registry: Arc<RwLock<PluginRegistry>>) -> Self {
456 Self {
457 registry,
458 auto_load: true,
459 }
460 }
461
462 pub fn set_auto_load(&mut self, auto_load: bool) {
464 self.auto_load = auto_load;
465 }
466
467 pub fn initialize(&self) -> JitResult<()> {
469 if self.auto_load {
470 self.load_all_plugins()?;
471 }
472 Ok(())
473 }
474
475 pub fn load_plugin<P: AsRef<Path>>(&self, path: P) -> JitResult<()> {
477 let mut registry = self
478 .registry
479 .write()
480 .map_err(|_| JitError::RuntimeError("Failed to acquire registry lock".to_string()))?;
481 registry.load_plugin(path)
482 }
483
484 pub fn load_all_plugins(&self) -> JitResult<Vec<String>> {
486 let mut registry = self
487 .registry
488 .write()
489 .map_err(|_| JitError::RuntimeError("Failed to acquire registry lock".to_string()))?;
490 registry.load_all_plugins()
491 }
492
493 pub fn unload_plugin(&self, name: &str) -> JitResult<()> {
495 let mut registry = self
496 .registry
497 .write()
498 .map_err(|_| JitError::RuntimeError("Failed to acquire registry lock".to_string()))?;
499 registry.unload_plugin(name)
500 }
501
502 pub fn list_plugins(&self) -> Vec<PluginMetadata> {
504 match self.registry.read() {
505 Ok(registry) => registry.list_plugins().into_iter().cloned().collect(),
506 Err(_) => vec![],
507 }
508 }
509
510 pub fn get_plugin_info(&self, name: &str) -> Option<PluginMetadata> {
512 let registry = self.registry.read().ok()?;
513 registry.find_plugin(name).map(|p| p.metadata().clone())
514 }
515
516 pub fn is_plugin_loaded(&self, name: &str) -> bool {
518 match self.registry.read() {
519 Ok(registry) => registry.find_plugin(name).is_some(),
520 Err(_) => false,
521 }
522 }
523}
524
525pub struct ExamplePlugin {
527 metadata: PluginMetadata,
528 initialized: bool,
529}
530
531impl ExamplePlugin {
532 pub fn new(metadata: PluginMetadata) -> Self {
533 Self {
534 metadata,
535 initialized: false,
536 }
537 }
538}
539
540impl Plugin for ExamplePlugin {
541 fn metadata(&self) -> &PluginMetadata {
542 &self.metadata
543 }
544
545 fn initialize(&mut self, _context: &PluginContext) -> JitResult<()> {
546 self.initialized = true;
547 Ok(())
548 }
549
550 fn register(&self, registry: &mut PluginRegistry) -> JitResult<()> {
551 if !self.initialized {
552 return Err(JitError::RuntimeError("Plugin not initialized".to_string()));
553 }
554
555 registry.register_custom_op_builder(|| {
557 Ok(CustomOpBuilder::new("plugin_add")
558 .namespace("example")
559 .forward(|inputs| {
560 if inputs.len() != 2 {
561 return Err(JitError::RuntimeError(
562 "plugin_add requires 2 inputs".to_string(),
563 ));
564 }
565
566 let a = &inputs[0];
567 let b = &inputs[1];
568 let mut result = a.clone();
569
570 for (i, &val_b) in b.data.iter().enumerate() {
571 if i < result.data.len() {
572 result.data[i] += val_b;
573 }
574 }
575
576 Ok(vec![result])
577 })
578 .vectorizable(true)
579 .parallelizable(true)
580 .elementwise(true))
581 });
582
583 Ok(())
584 }
585
586 fn cleanup(&mut self) -> JitResult<()> {
587 self.initialized = false;
588 Ok(())
589 }
590}
591
592pub mod discovery {
594 use super::*;
595
596 pub fn discover_plugins<P: AsRef<Path>>(path: P) -> JitResult<Vec<PathBuf>> {
598 let mut plugins = Vec::new();
599 let path = path.as_ref();
600
601 if !path.exists() {
602 return Ok(plugins);
603 }
604
605 for entry in std::fs::read_dir(path)
606 .map_err(|e| JitError::RuntimeError(format!("Failed to read directory: {}", e)))?
607 {
608 let entry = entry
609 .map_err(|e| JitError::RuntimeError(format!("Failed to read entry: {}", e)))?;
610 let path = entry.path();
611
612 if is_plugin_file(&path) {
613 plugins.push(path);
614 }
615 }
616
617 Ok(plugins)
618 }
619
620 fn is_plugin_file(path: &Path) -> bool {
622 if let Some(extension) = path.extension() {
623 matches!(extension.to_str(), Some("so") | Some("dll") | Some("dylib"))
624 } else {
625 false
626 }
627 }
628
629 pub fn validate_plugin(metadata: &PluginMetadata) -> JitResult<()> {
631 if metadata.api_version != PLUGIN_API_VERSION {
632 return Err(JitError::RuntimeError(format!(
633 "Incompatible plugin API version: expected {}, got {}",
634 PLUGIN_API_VERSION, metadata.api_version
635 )));
636 }
637
638 Ok(())
639 }
640}
641
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds version "1.0.0" metadata advertising custom operators.
    fn metadata_for(name: &str, description: &str, author: &str) -> PluginMetadata {
        PluginMetadata {
            name: name.to_string(),
            version: "1.0.0".to_string(),
            description: description.to_string(),
            author: author.to_string(),
            api_version: PLUGIN_API_VERSION,
            dependencies: vec![],
            capabilities: vec![PluginCapability::CustomOperators],
        }
    }

    #[test]
    fn test_plugin_metadata() {
        let metadata = metadata_for("test_plugin", "Test plugin", "Test Author");

        assert_eq!(metadata.name, "test_plugin");
        assert_eq!(metadata.api_version, PLUGIN_API_VERSION);
    }

    #[test]
    fn test_plugin_registry() {
        let mut registry = PluginRegistry::new();

        // A fresh registry has no plugins and no backends.
        assert_eq!(registry.list_plugins().len(), 0);
        assert_eq!(registry.list_backends().len(), 0);

        // Three default search paths plus the one added here.
        let extra_path = std::env::temp_dir().join("plugins").display().to_string();
        registry.add_search_path(&extra_path);
        assert_eq!(registry.search_paths.len(), 4);
    }

    #[test]
    fn test_example_plugin() {
        let mut plugin = ExamplePlugin::new(metadata_for("example", "Example plugin", "Test"));
        assert!(!plugin.initialized);

        let context = PluginContext {
            jit_version: "0.1.0".to_string(),
            features: vec![],
            config: HashMap::new(),
        };

        assert!(plugin.initialize(&context).is_ok());
        assert!(plugin.initialized);

        // Registering must add exactly one custom operator builder.
        let mut registry = PluginRegistry::new();
        assert!(plugin.register(&mut registry).is_ok());
        assert_eq!(registry.custom_op_builders.len(), 1);
    }

    #[test]
    fn test_plugin_manager() {
        let manager = PluginManager::new();
        assert!(manager.auto_load);

        // Nothing has been loaded into the global registry by these tests.
        let plugins = manager.list_plugins();
        assert!(plugins.is_empty());
    }
}