Skip to main content

trustformers_core/plugins/
registry.rs

1//! Plugin registry for discovery and management.
2
3use crate::errors::{Result, TrustformersError};
4use crate::plugins::{Plugin, PluginInfo, PluginLoader, PluginManager};
5use serde::{Deserialize, Serialize};
6use std::collections::HashMap;
7use std::path::{Path, PathBuf};
8use std::sync::{Arc, RwLock};
9
10/// Central registry for plugin discovery and management.
11///
12/// The `PluginRegistry` maintains a catalog of available plugins,
13/// handles plugin loading and unloading, and provides compatibility
14/// checking and version management.
15///
16/// # Thread Safety
17///
18/// The registry is thread-safe and can be shared across multiple threads.
19/// It uses read-write locks to allow concurrent reads while ensuring
20/// exclusive access for modifications.
21///
22/// # Example
23///
24/// ```no_run
25/// use trustformers_core::plugins::{PluginRegistry, PluginInfo, PluginManager};
26/// use std::path::Path;
27///
28/// # fn main() -> Result<(), Box<dyn std::error::Error>> {
29/// let registry = PluginRegistry::new();
30///
31/// // Register a plugin
32/// let info = PluginInfo::new(
33///     "custom_attention",
34///     "1.0.0",
35///     "Custom attention mechanism",
36///     &["trustformers-core >= 0.1.0"]
37/// );
38/// registry.register("custom_attention", info)?;
39///
40/// // List all available plugins
41/// let plugins = registry.list_plugins();
42/// assert!(plugins.contains(&"custom_attention".to_string()));
43/// # Ok(())
44/// # }
45/// ```
46#[derive(Debug)]
47pub struct PluginRegistry {
48    /// Plugin information cache
49    plugins: Arc<RwLock<HashMap<String, PluginInfo>>>,
50    /// Loaded plugin instances
51    loaded: Arc<RwLock<HashMap<String, Box<dyn Plugin>>>>,
52    /// Plugin search paths
53    search_paths: Arc<RwLock<Vec<PathBuf>>>,
54    /// Plugin loader
55    loader: Arc<PluginLoader>,
56    /// Registry configuration
57    config: RegistryConfig,
58}
59
60impl PluginRegistry {
61    /// Creates a new plugin registry.
62    ///
63    /// # Returns
64    ///
65    /// A new registry instance with default configuration.
66    pub fn new() -> Self {
67        Self {
68            plugins: Arc::new(RwLock::new(HashMap::new())),
69            loaded: Arc::new(RwLock::new(HashMap::new())),
70            search_paths: Arc::new(RwLock::new(Vec::new())),
71            loader: Arc::new(PluginLoader::new()),
72            config: RegistryConfig::default(),
73        }
74    }
75
76    /// Creates a new plugin registry with custom configuration.
77    ///
78    /// # Arguments
79    ///
80    /// * `config` - Registry configuration
81    ///
82    /// # Returns
83    ///
84    /// A new registry instance.
85    pub fn with_config(config: RegistryConfig) -> Self {
86        Self {
87            plugins: Arc::new(RwLock::new(HashMap::new())),
88            loaded: Arc::new(RwLock::new(HashMap::new())),
89            search_paths: Arc::new(RwLock::new(Vec::new())),
90            loader: Arc::new(PluginLoader::new()),
91            config,
92        }
93    }
94
95    /// Registers a plugin in the registry.
96    ///
97    /// # Arguments
98    ///
99    /// * `name` - Plugin name (must be unique)
100    /// * `info` - Plugin metadata
101    ///
102    /// # Returns
103    ///
104    /// `Ok(())` on success, error if registration fails.
105    ///
106    /// # Errors
107    ///
108    /// - Plugin name already exists
109    /// - Invalid plugin information
110    pub fn register(&self, name: &str, info: PluginInfo) -> Result<()> {
111        info.validate()?;
112
113        let mut plugins = self.plugins.write().map_err(|_| {
114            TrustformersError::lock_error("Failed to acquire write lock".to_string())
115        })?;
116
117        if plugins.contains_key(name) {
118            return Err(TrustformersError::plugin_error(format!(
119                "Plugin '{}' is already registered",
120                name
121            )));
122        }
123
124        plugins.insert(name.to_string(), info);
125        Ok(())
126    }
127
128    /// Unregisters a plugin from the registry.
129    ///
130    /// # Arguments
131    ///
132    /// * `name` - Plugin name to unregister
133    ///
134    /// # Returns
135    ///
136    /// `Ok(())` on success, error if the plugin is not found.
137    pub fn unregister(&self, name: &str) -> Result<()> {
138        // Unload the plugin if it's currently loaded
139        if self.is_loaded(name) {
140            self.unload_plugin(name)?;
141        }
142
143        let mut plugins = self.plugins.write().map_err(|_| {
144            TrustformersError::lock_error("Failed to acquire write lock".to_string())
145        })?;
146
147        plugins.remove(name).ok_or_else(|| {
148            TrustformersError::plugin_error(format!("Plugin '{}' not found", name))
149        })?;
150
151        Ok(())
152    }
153
154    /// Gets information about a registered plugin.
155    ///
156    /// # Arguments
157    ///
158    /// * `name` - Plugin name
159    ///
160    /// # Returns
161    ///
162    /// Plugin information if found.
163    pub fn get_plugin_info(&self, name: &str) -> Result<PluginInfo> {
164        let plugins = self.plugins.read().map_err(|_| {
165            TrustformersError::lock_error("Failed to acquire read lock".to_string())
166        })?;
167
168        plugins
169            .get(name)
170            .cloned()
171            .ok_or_else(|| TrustformersError::plugin_error(format!("Plugin '{}' not found", name)))
172    }
173
174    /// Checks if a plugin is currently loaded.
175    ///
176    /// # Arguments
177    ///
178    /// * `name` - Plugin name
179    ///
180    /// # Returns
181    ///
182    /// `true` if the plugin is loaded, `false` otherwise.
183    pub fn is_loaded(&self, name: &str) -> bool {
184        self.loaded.read().map(|loaded| loaded.contains_key(name)).unwrap_or(false)
185    }
186
187    /// Unloads a plugin from memory.
188    ///
189    /// # Arguments
190    ///
191    /// * `name` - Plugin name to unload
192    ///
193    /// # Returns
194    ///
195    /// `Ok(())` on success.
196    pub fn unload_plugin(&self, name: &str) -> Result<()> {
197        let mut loaded = self.loaded.write().map_err(|_| {
198            TrustformersError::lock_error("Failed to acquire write lock".to_string())
199        })?;
200
201        if let Some(mut plugin) = loaded.remove(name) {
202            plugin.cleanup()?;
203        }
204
205        Ok(())
206    }
207
208    /// Adds a directory to the plugin search path.
209    ///
210    /// # Arguments
211    ///
212    /// * `path` - Directory path to add
213    pub fn add_search_path<P: AsRef<Path>>(&self, path: P) {
214        if let Ok(mut paths) = self.search_paths.write() {
215            paths.push(path.as_ref().to_path_buf());
216        }
217    }
218
219    /// Removes a directory from the plugin search path.
220    ///
221    /// # Arguments
222    ///
223    /// * `path` - Directory path to remove
224    pub fn remove_search_path<P: AsRef<Path>>(&self, path: P) {
225        if let Ok(mut paths) = self.search_paths.write() {
226            paths.retain(|p| p != path.as_ref());
227        }
228    }
229
230    /// Scans all search paths for plugins and registers them.
231    ///
232    /// # Returns
233    ///
234    /// Number of plugins discovered and registered.
235    pub fn scan_for_plugins(&self) -> Result<usize> {
236        let search_paths = self
237            .search_paths
238            .read()
239            .map_err(|_| TrustformersError::lock_error("Failed to acquire read lock".to_string()))?
240            .clone();
241
242        let mut count = 0;
243
244        for path in &search_paths {
245            if let Ok(entries) = std::fs::read_dir(path) {
246                for entry in entries.flatten() {
247                    let path = entry.path();
248                    if path.is_file() && self.is_plugin_file(&path) {
249                        if let Ok(info) = self.loader.load_plugin_info(&path) {
250                            let name = info.name().to_string();
251                            if self.register(&name, info).is_ok() {
252                                count += 1;
253                            }
254                        }
255                    }
256                }
257            }
258        }
259
260        Ok(count)
261    }
262
263    /// Checks if a file is a plugin file based on extension.
264    ///
265    /// # Arguments
266    ///
267    /// * `path` - File path to check
268    ///
269    /// # Returns
270    ///
271    /// `true` if it's a plugin file.
272    fn is_plugin_file(&self, path: &Path) -> bool {
273        if let Some(ext) = path.extension() {
274            let ext = ext.to_string_lossy().to_lowercase();
275            matches!(ext.as_str(), "so" | "dll" | "dylib" | "wasm")
276        } else {
277            false
278        }
279    }
280
281    /// Validates plugin dependencies.
282    ///
283    /// # Arguments
284    ///
285    /// * `name` - Plugin name to validate
286    ///
287    /// # Returns
288    ///
289    /// `Ok(())` if all dependencies are satisfied.
290    pub fn validate_dependencies(&self, name: &str) -> Result<()> {
291        let info = self.get_plugin_info(name)?;
292
293        for dep in info.dependencies() {
294            if !dep.optional {
295                let dep_info = self.get_plugin_info(&dep.name)?;
296                if !dep.requirement.matches(dep_info.version()) {
297                    return Err(TrustformersError::plugin_error(format!(
298                        "Plugin '{}' requires '{}' {} but found {}",
299                        name,
300                        dep.name,
301                        dep.requirement,
302                        dep_info.version()
303                    )));
304                }
305            }
306        }
307
308        Ok(())
309    }
310
311    /// Exports the registry to a configuration file.
312    ///
313    /// # Arguments
314    ///
315    /// * `path` - File path to write the configuration
316    ///
317    /// # Returns
318    ///
319    /// `Ok(())` on success.
320    pub fn export_config<P: AsRef<Path>>(&self, path: P) -> Result<()> {
321        let plugins = self.plugins.read().map_err(|_| {
322            TrustformersError::lock_error("Failed to acquire read lock".to_string())
323        })?;
324
325        let config = RegistryConfig {
326            plugins: plugins.clone(),
327            ..self.config.clone()
328        };
329
330        let json = serde_json::to_string_pretty(&config)
331            .map_err(|e| TrustformersError::serialization_error(e.to_string()))?;
332
333        std::fs::write(path, json).map_err(|e| TrustformersError::io_error(e.to_string()))?;
334
335        Ok(())
336    }
337
338    /// Imports registry configuration from a file.
339    ///
340    /// # Arguments
341    ///
342    /// * `path` - File path to read the configuration from
343    ///
344    /// # Returns
345    ///
346    /// `Ok(())` on success.
347    pub fn import_config<P: AsRef<Path>>(&self, path: P) -> Result<()> {
348        let json = std::fs::read_to_string(path)
349            .map_err(|e| TrustformersError::io_error(e.to_string()))?;
350
351        let config: RegistryConfig = serde_json::from_str(&json)
352            .map_err(|e| TrustformersError::serialization_error(e.to_string()))?;
353
354        let mut plugins = self.plugins.write().map_err(|_| {
355            TrustformersError::lock_error("Failed to acquire write lock".to_string())
356        })?;
357
358        for (name, info) in config.plugins {
359            plugins.insert(name, info);
360        }
361
362        Ok(())
363    }
364
365    /// Gets registry statistics.
366    ///
367    /// # Returns
368    ///
369    /// Registry statistics.
370    pub fn stats(&self) -> Result<RegistryStats> {
371        let plugins = self.plugins.read().map_err(|_| {
372            TrustformersError::lock_error("Failed to acquire read lock".to_string())
373        })?;
374        let loaded = self.loaded.read().map_err(|_| {
375            TrustformersError::lock_error("Failed to acquire read lock".to_string())
376        })?;
377
378        Ok(RegistryStats {
379            total_plugins: plugins.len(),
380            loaded_plugins: loaded.len(),
381            search_paths: self.search_paths.read().map(|paths| paths.len()).unwrap_or(0),
382        })
383    }
384}
385
386impl PluginManager for PluginRegistry {
387    fn discover_plugins(&self) -> Result<HashMap<String, PluginInfo>> {
388        let plugins = self.plugins.read().map_err(|_| {
389            TrustformersError::lock_error("Failed to acquire read lock".to_string())
390        })?;
391        Ok(plugins.clone())
392    }
393
394    fn is_compatible(&self, name: &str, version: &str) -> Result<bool> {
395        let info = self.get_plugin_info(name)?;
396        Ok(info.is_compatible_with("trustformers-core", version))
397    }
398
399    fn load_plugin(&self, name: &str) -> Result<Box<dyn Plugin>> {
400        // Check if already loaded
401        {
402            let loaded = self.loaded.read().map_err(|_| {
403                TrustformersError::lock_error("Failed to acquire read lock".to_string())
404            })?;
405            if let Some(plugin) = loaded.get(name) {
406                return Ok(plugin.clone());
407            }
408        }
409
410        // Validate dependencies
411        self.validate_dependencies(name)?;
412
413        // Get plugin info
414        let info = self.get_plugin_info(name)?;
415
416        // Load the plugin
417        let mut plugin = self.loader.load_plugin(&info)?;
418        plugin.initialize()?;
419
420        // Store in loaded cache
421        let plugin_clone = plugin.clone();
422        {
423            let mut loaded = self.loaded.write().map_err(|_| {
424                TrustformersError::lock_error("Failed to acquire write lock".to_string())
425            })?;
426            loaded.insert(name.to_string(), plugin);
427        }
428
429        Ok(plugin_clone)
430    }
431
432    fn list_plugins(&self) -> Vec<String> {
433        self.plugins
434            .read()
435            .map(|plugins| plugins.keys().cloned().collect())
436            .unwrap_or_default()
437    }
438}
439
440impl Default for PluginRegistry {
441    fn default() -> Self {
442        Self::new()
443    }
444}
445
446/// Registry configuration.
447#[derive(Debug, Clone, Serialize, Deserialize)]
448pub struct RegistryConfig {
449    /// Registered plugins
450    #[serde(default)]
451    pub plugins: HashMap<String, PluginInfo>,
452    /// Maximum number of loaded plugins
453    #[serde(default = "default_max_loaded")]
454    pub max_loaded_plugins: usize,
455    /// Enable automatic plugin discovery
456    #[serde(default = "default_auto_discovery")]
457    pub auto_discovery: bool,
458    /// Plugin cache directory
459    #[serde(default)]
460    pub cache_dir: Option<PathBuf>,
461    /// Plugin load timeout in seconds
462    #[serde(default = "default_load_timeout")]
463    pub load_timeout_secs: u64,
464}
465
466impl Default for RegistryConfig {
467    fn default() -> Self {
468        Self {
469            plugins: HashMap::new(),
470            max_loaded_plugins: default_max_loaded(),
471            auto_discovery: default_auto_discovery(),
472            cache_dir: None,
473            load_timeout_secs: default_load_timeout(),
474        }
475    }
476}
477
/// Default value of `RegistryConfig::max_loaded_plugins`.
const DEFAULT_MAX_LOADED_PLUGINS: usize = 100;
/// Default value of `RegistryConfig::auto_discovery`.
const DEFAULT_AUTO_DISCOVERY: bool = true;
/// Default value of `RegistryConfig::load_timeout_secs`, in seconds.
const DEFAULT_LOAD_TIMEOUT_SECS: u64 = 30;

/// Serde/`Default` fallback for `max_loaded_plugins`.
fn default_max_loaded() -> usize {
    DEFAULT_MAX_LOADED_PLUGINS
}

/// Serde/`Default` fallback for `auto_discovery`.
fn default_auto_discovery() -> bool {
    DEFAULT_AUTO_DISCOVERY
}

/// Serde/`Default` fallback for `load_timeout_secs`.
fn default_load_timeout() -> u64 {
    DEFAULT_LOAD_TIMEOUT_SECS
}
487
488/// Registry statistics.
489#[derive(Debug, Clone, Serialize, Deserialize)]
490pub struct RegistryStats {
491    /// Total number of registered plugins
492    pub total_plugins: usize,
493    /// Number of currently loaded plugins
494    pub loaded_plugins: usize,
495    /// Number of search paths
496    pub search_paths: usize,
497}