Skip to main content

trustformers_core/plugins/
loader.rs

1//! Plugin loading infrastructure.
2
3use crate::errors::{Result, TrustformersError};
4use crate::plugins::{Plugin, PluginInfo};
5use std::collections::HashMap;
6use std::path::Path;
7use std::sync::{Arc, Mutex};
8
/// Plugin loader for dynamic loading and instantiation.
///
/// The `PluginLoader` handles the runtime loading of plugin libraries,
/// symbol resolution, and plugin instantiation. It supports various
/// plugin formats and provides caching for performance.
///
/// # Supported Formats
///
/// - Dynamic libraries (.so, .dll, .dylib)
/// - WebAssembly modules (.wasm)
/// - Static plugins (compiled-in)
///
/// # Current limitations
///
/// In this file the dynamic path is a placeholder: the private library
/// handle's `create_plugin` always returns a plugin error. As written, only
/// plugins registered through [`PluginLoader::register_static_plugin`] can
/// actually be instantiated by [`PluginLoader::load_plugin`].
///
/// # Example
///
/// ```no_run
/// use trustformers_core::plugins::{PluginLoader, PluginInfo};
/// use std::path::Path;
///
/// let loader = PluginLoader::new();
///
/// // Load plugin info from metadata
/// let info = loader.load_plugin_info(Path::new("plugins/custom_attention.so")).unwrap();
///
/// // Load the actual plugin
/// let plugin = loader.load_plugin(&info).unwrap();
/// ```
#[derive(Debug)]
pub struct PluginLoader {
    /// Handles of libraries loaded so far, keyed by plugin name.
    library_cache: Arc<Mutex<HashMap<String, LibraryHandle>>>,
    /// Factories for compiled-in plugins, keyed by the name given at
    /// registration time.
    static_plugins: Arc<Mutex<HashMap<String, StaticPluginFactory>>>,
    /// Number of dynamic loads that found a handle in `library_cache`.
    cache_hits: Arc<Mutex<u64>>,
    /// Number of dynamic loads that did not find a cached handle.
    cache_misses: Arc<Mutex<u64>>,
    /// Configuration supplied at construction time (the defaults when the
    /// loader is built with `new`).
    #[allow(dead_code)]
    config: LoaderConfig,
}
49
50impl PluginLoader {
51    /// Creates a new plugin loader.
52    ///
53    /// # Returns
54    ///
55    /// A new loader instance with default configuration.
56    pub fn new() -> Self {
57        Self {
58            library_cache: Arc::new(Mutex::new(HashMap::new())),
59            static_plugins: Arc::new(Mutex::new(HashMap::new())),
60            cache_hits: Arc::new(Mutex::new(0)),
61            cache_misses: Arc::new(Mutex::new(0)),
62            config: LoaderConfig::default(),
63        }
64    }
65
66    /// Creates a plugin loader with custom configuration.
67    ///
68    /// # Arguments
69    ///
70    /// * `config` - Loader configuration
71    ///
72    /// # Returns
73    ///
74    /// A new loader instance.
75    pub fn with_config(config: LoaderConfig) -> Self {
76        Self {
77            library_cache: Arc::new(Mutex::new(HashMap::new())),
78            static_plugins: Arc::new(Mutex::new(HashMap::new())),
79            cache_hits: Arc::new(Mutex::new(0)),
80            cache_misses: Arc::new(Mutex::new(0)),
81            config,
82        }
83    }
84
85    /// Loads plugin information from a file or metadata.
86    ///
87    /// # Arguments
88    ///
89    /// * `path` - Path to the plugin file or metadata
90    ///
91    /// # Returns
92    ///
93    /// Plugin information if successfully loaded.
94    ///
95    /// # Errors
96    ///
97    /// - File not found
98    /// - Invalid plugin format
99    /// - Metadata parsing errors
100    pub fn load_plugin_info<P: AsRef<Path>>(&self, path: P) -> Result<PluginInfo> {
101        let path = path.as_ref();
102
103        // Check for companion metadata file
104        let metadata_path = path.with_extension("json");
105        if metadata_path.exists() {
106            return self.load_metadata_file(&metadata_path);
107        }
108
109        // Try to load embedded metadata from the plugin file
110        self.load_embedded_metadata(path)
111    }
112
113    /// Loads a plugin instance from plugin information.
114    ///
115    /// # Arguments
116    ///
117    /// * `info` - Plugin information containing loading details
118    ///
119    /// # Returns
120    ///
121    /// A boxed plugin instance ready for use.
122    ///
123    /// # Errors
124    ///
125    /// - Plugin file not found
126    /// - Symbol resolution failures
127    /// - Plugin initialization errors
128    pub fn load_plugin(&self, info: &PluginInfo) -> Result<Box<dyn Plugin>> {
129        // Check if it's a static plugin first
130        if let Ok(static_plugins) = self.static_plugins.lock() {
131            if let Some(factory) = static_plugins.get(info.name()) {
132                return factory();
133            }
134        }
135
136        // Load as dynamic library
137        self.load_dynamic_plugin(info)
138    }
139
140    /// Registers a static plugin factory.
141    ///
142    /// Static plugins are compiled into the binary and don't require
143    /// dynamic loading. This method registers a factory function
144    /// that can create instances of the plugin.
145    ///
146    /// # Arguments
147    ///
148    /// * `name` - Plugin name
149    /// * `factory` - Factory function for creating plugin instances
150    ///
151    /// # Returns
152    ///
153    /// `Ok(())` on successful registration.
154    pub fn register_static_plugin(&self, name: &str, factory: StaticPluginFactory) -> Result<()> {
155        let mut static_plugins = self
156            .static_plugins
157            .lock()
158            .map_err(|_| TrustformersError::lock_error("Failed to acquire lock".to_string()))?;
159
160        static_plugins.insert(name.to_string(), factory);
161        Ok(())
162    }
163
164    /// Unloads a plugin library from the cache.
165    ///
166    /// # Arguments
167    ///
168    /// * `name` - Plugin name to unload
169    ///
170    /// # Returns
171    ///
172    /// `Ok(())` on success.
173    pub fn unload_library(&self, name: &str) -> Result<()> {
174        let mut cache = self
175            .library_cache
176            .lock()
177            .map_err(|_| TrustformersError::lock_error("Failed to acquire lock".to_string()))?;
178
179        cache.remove(name);
180        Ok(())
181    }
182
183    /// Clears all cached libraries.
184    pub fn clear_cache(&self) -> Result<()> {
185        let mut cache = self
186            .library_cache
187            .lock()
188            .map_err(|_| TrustformersError::lock_error("Failed to acquire lock".to_string()))?;
189
190        cache.clear();
191        Ok(())
192    }
193
194    /// Gets loader statistics.
195    ///
196    /// # Returns
197    ///
198    /// Loader statistics including cache information.
199    pub fn stats(&self) -> Result<LoaderStats> {
200        let cache = self
201            .library_cache
202            .lock()
203            .map_err(|_| TrustformersError::lock_error("Failed to acquire lock".to_string()))?;
204        let static_plugins = self
205            .static_plugins
206            .lock()
207            .map_err(|_| TrustformersError::lock_error("Failed to acquire lock".to_string()))?;
208
209        let cache_hits = self
210            .cache_hits
211            .lock()
212            .map_err(|_| TrustformersError::lock_error("Failed to acquire lock".to_string()))?;
213        let cache_misses = self
214            .cache_misses
215            .lock()
216            .map_err(|_| TrustformersError::lock_error("Failed to acquire lock".to_string()))?;
217
218        Ok(LoaderStats {
219            cached_libraries: cache.len(),
220            static_plugins: static_plugins.len(),
221            cache_hits: *cache_hits,
222            cache_misses: *cache_misses,
223        })
224    }
225
226    /// Loads metadata from a JSON file.
227    fn load_metadata_file<P: AsRef<Path>>(&self, path: P) -> Result<PluginInfo> {
228        let content = std::fs::read_to_string(path)
229            .map_err(|e| TrustformersError::io_error(format!("Failed to read metadata: {}", e)))?;
230
231        serde_json::from_str(&content)
232            .map_err(|e| TrustformersError::serialization_error(format!("Invalid metadata: {}", e)))
233    }
234
235    /// Loads embedded metadata from a plugin file.
236    fn load_embedded_metadata<P: AsRef<Path>>(&self, path: P) -> Result<PluginInfo> {
237        // This is a simplified implementation
238        // In a real implementation, you would read metadata from the plugin file
239        // For now, we'll create basic info from the filename
240        let path = path.as_ref();
241        let name = path.file_stem().and_then(|s| s.to_str()).ok_or_else(|| {
242            TrustformersError::plugin_error("Invalid plugin filename".to_string())
243        })?;
244
245        Ok(PluginInfo::new(
246            name,
247            "1.0.0",
248            "Dynamically loaded plugin",
249            &[],
250        ))
251    }
252
253    /// Loads a plugin as a dynamic library.
254    fn load_dynamic_plugin(&self, info: &PluginInfo) -> Result<Box<dyn Plugin>> {
255        // Check cache first
256        {
257            let cache = self
258                .library_cache
259                .lock()
260                .map_err(|_| TrustformersError::lock_error("Failed to acquire lock".to_string()))?;
261
262            if let Some(handle) = cache.get(info.name()) {
263                // Increment cache hit counter
264                if let Ok(mut hits) = self.cache_hits.lock() {
265                    *hits += 1;
266                }
267                return handle.create_plugin();
268            }
269        }
270
271        // Cache miss - increment counter
272        if let Ok(mut misses) = self.cache_misses.lock() {
273            *misses += 1;
274        }
275
276        // Load the library
277        let handle = LibraryHandle::load(info)?;
278        let plugin = handle.create_plugin()?;
279
280        // Cache the handle
281        {
282            let mut cache = self
283                .library_cache
284                .lock()
285                .map_err(|_| TrustformersError::lock_error("Failed to acquire lock".to_string()))?;
286            cache.insert(info.name().to_string(), handle);
287        }
288
289        Ok(plugin)
290    }
291}
292
293impl Default for PluginLoader {
294    fn default() -> Self {
295        Self::new()
296    }
297}
298
/// Type alias for static plugin factory functions.
///
/// A factory takes no arguments and returns either a freshly boxed plugin or
/// an error. Because this is a plain `fn` pointer, only free functions and
/// non-capturing closures can be used as factories.
pub type StaticPluginFactory = fn() -> Result<Box<dyn Plugin>>;
301
/// Handle to a loaded dynamic library.
///
/// This struct manages the lifetime of a loaded plugin library
/// and provides symbol resolution for plugin creation.
///
/// As defined here the handle holds only two strings (the plugin name and
/// its entry point); no operating-system library handle is stored.
#[derive(Debug)]
struct LibraryHandle {
    /// Plugin name the handle was created for.
    #[allow(dead_code)]
    name: String,
    /// Entry point recorded from the plugin information.
    _entry_point: String,
}
314
315impl LibraryHandle {
316    /// Loads a plugin library.
317    ///
318    /// # Arguments
319    ///
320    /// * `info` - Plugin information
321    ///
322    /// # Returns
323    ///
324    /// A library handle if loading succeeds.
325    fn load(info: &PluginInfo) -> Result<Self> {
326        // This is a simplified implementation
327        // In a real implementation, you would use libloading or similar
328        // to actually load the dynamic library
329
330        Ok(Self {
331            name: info.name().to_string(),
332            _entry_point: info.entry_point().to_string(),
333        })
334    }
335
336    /// Creates a plugin instance from this library.
337    ///
338    /// # Returns
339    ///
340    /// A boxed plugin instance.
341    fn create_plugin(&self) -> Result<Box<dyn Plugin>> {
342        // This is a simplified implementation
343        // In a real implementation, you would resolve the plugin factory symbol
344        // and call it to create the plugin instance
345
346        Err(TrustformersError::plugin_error(
347            "Dynamic plugin loading not implemented in this example".to_string(),
348        ))
349    }
350}
351
/// Tunable settings for a `PluginLoader`.
#[derive(Debug, Clone)]
pub struct LoaderConfig {
    /// Whether loaded libraries are cached.
    pub cache_enabled: bool,
    /// Upper bound on the number of cached libraries.
    pub max_cached_libraries: usize,
    /// Plugin load timeout, in seconds.
    pub load_timeout_secs: u64,
    /// Whether lazy loading is enabled.
    pub lazy_loading: bool,
    /// Prefix of the symbol name used for plugin factories.
    pub symbol_prefix: String,
}

/// Defaults: caching on with at most 50 libraries, a 30 second load timeout,
/// lazy loading on, and `"create_plugin"` as the factory symbol prefix.
impl Default for LoaderConfig {
    fn default() -> Self {
        LoaderConfig {
            symbol_prefix: String::from("create_plugin"),
            lazy_loading: true,
            load_timeout_secs: 30,
            max_cached_libraries: 50,
            cache_enabled: true,
        }
    }
}
378
/// Plugin loader statistics.
///
/// A plain-data snapshot; the values do not change after it is created.
#[derive(Debug, Clone)]
pub struct LoaderStats {
    /// Number of library handles in the cache when the snapshot was taken.
    pub cached_libraries: usize,
    /// Number of static plugin factories registered when the snapshot was taken.
    pub static_plugins: usize,
    /// Value of the cache hit counter.
    pub cache_hits: u64,
    /// Value of the cache miss counter.
    pub cache_misses: u64,
}
391
/// Reasons a plugin can fail to load.
///
/// Every variant carries a free-form string that is appended to a fixed
/// label when the error is displayed.
#[derive(Debug, Clone)]
pub enum LoadError {
    /// The library file could not be found; carries the path.
    LibraryNotFound(String),
    /// A required symbol is missing from the library; carries the symbol name.
    SymbolNotFound(String),
    /// The plugin failed to initialize; carries a message.
    InitializationFailed(String),
    /// The plugin file has an invalid format; carries a message.
    InvalidFormat(String),
    /// The plugin version is incompatible; carries a message.
    VersionMismatch(String),
    /// A dependency is not satisfied; carries the dependency.
    DependencyNotSatisfied(String),
}

/// Renders each variant as `"<label>: <payload>"`.
impl std::fmt::Display for LoadError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::LibraryNotFound(location) => write!(f, "Library not found: {}", location),
            Self::SymbolNotFound(name) => write!(f, "Symbol not found: {}", name),
            Self::InitializationFailed(why) => write!(f, "Initialization failed: {}", why),
            Self::InvalidFormat(why) => write!(f, "Invalid format: {}", why),
            Self::VersionMismatch(why) => write!(f, "Version mismatch: {}", why),
            Self::DependencyNotSatisfied(missing) => {
                write!(f, "Dependency not satisfied: {}", missing)
            },
        }
    }
}

impl std::error::Error for LoadError {}
425
/// Macro for registering static plugins.
///
/// For a plugin type and a name, this macro expands to two items in the
/// invoking scope:
///
/// - `pub fn register_plugin()`, a factory returning
///   `Ok(Box::new(<Type>::default()))`, so the plugin type must implement
///   `Default`;
/// - `fn register()`, compiled only with the `static-plugins` feature and
///   marked `#[ctor::ctor]`, which passes the name and the factory to
///   `PluginLoader::register_static_plugin`.
///
/// # Limitations
///
/// - `register()` registers the factory on a `PluginLoader` it constructs
///   locally; that loader is dropped when the function returns and shares no
///   state with any other loader. To make the plugin available to a loader
///   you own, call `loader.register_static_plugin(name, register_plugin)`
///   yourself.
/// - The generated function names are fixed (`register_plugin` and
///   `register`), so the macro can be invoked at most once per module.
/// - The `ctor::ctor` path is not `$crate`-qualified, so the `ctor` crate
///   must be resolvable wherever the macro is expanded.
///
/// # Example
///
/// ```no_run
/// use trustformers_core::register_static_plugin;
/// use trustformers_core::plugins::Plugin;
/// use trustformers_core::tensor::Tensor;
/// use trustformers_core::errors::Result;
/// use std::collections::HashMap;
///
/// #[derive(Debug, Clone, Default)]
/// struct MyPlugin {
///     config: HashMap<String, serde_json::Value>,
/// }
/// impl Plugin for MyPlugin {
///     fn name(&self) -> &str { "my_plugin" }
///     fn version(&self) -> &str { "1.0.0" }
///     fn description(&self) -> &str { "My custom plugin" }
///     fn configure(&mut self, config: HashMap<String, serde_json::Value>) -> Result<()> {
///         self.config = config; Ok(())
///     }
///     fn get_config(&self) -> &HashMap<String, serde_json::Value> { &self.config }
///     fn as_any(&self) -> &dyn std::any::Any { self }
///     fn forward(&self, input: Tensor) -> Result<Tensor> { Ok(input) }
/// }
///
/// register_static_plugin!(MyPlugin, "my_plugin");
/// ```
#[macro_export]
macro_rules! register_static_plugin {
    ($plugin_type:ty, $name:expr) => {
        // Factory with the `StaticPluginFactory` signature.
        pub fn register_plugin() -> $crate::errors::Result<Box<dyn $crate::plugins::Plugin>> {
            Ok(Box::new(<$plugin_type>::default()))
        }

        #[cfg(feature = "static-plugins")]
        #[ctor::ctor]
        fn register() {
            use $crate::plugins::PluginLoader;

            // NOTE(review): `loader` is local to this function and dropped at
            // the end of it, so this registration is not visible to any other
            // `PluginLoader` instance. The result is deliberately discarded.
            let loader = PluginLoader::new();
            let _ = loader.register_static_plugin($name, register_plugin);
        }
    };
}