trustformers_core/plugins/loader.rs
1//! Plugin loading infrastructure.
2
3use crate::errors::{Result, TrustformersError};
4use crate::plugins::{Plugin, PluginInfo};
5use std::collections::HashMap;
6use std::path::Path;
7use std::sync::{Arc, Mutex};
8
9/// Plugin loader for dynamic loading and instantiation.
10///
11/// The `PluginLoader` handles the runtime loading of plugin libraries,
12/// symbol resolution, and plugin instantiation. It supports various
13/// plugin formats and provides caching for performance.
14///
15/// # Supported Formats
16///
17/// - Dynamic libraries (.so, .dll, .dylib)
18/// - WebAssembly modules (.wasm)
19/// - Static plugins (compiled-in)
20///
21/// # Example
22///
23/// ```no_run
24/// use trustformers_core::plugins::{PluginLoader, PluginInfo};
25/// use std::path::Path;
26///
27/// let loader = PluginLoader::new();
28///
29/// // Load plugin info from metadata
30/// let info = loader.load_plugin_info(Path::new("plugins/custom_attention.so")).unwrap();
31///
32/// // Load the actual plugin
33/// let plugin = loader.load_plugin(&info).unwrap();
34/// ```
35#[derive(Debug)]
36pub struct PluginLoader {
37 /// Cache of loaded libraries
38 library_cache: Arc<Mutex<HashMap<String, LibraryHandle>>>,
39 /// Static plugin registry
40 static_plugins: Arc<Mutex<HashMap<String, StaticPluginFactory>>>,
41 /// Cache hit counter
42 cache_hits: Arc<Mutex<u64>>,
43 /// Cache miss counter
44 cache_misses: Arc<Mutex<u64>>,
45 /// Loader configuration
46 #[allow(dead_code)]
47 config: LoaderConfig,
48}
49
50impl PluginLoader {
51 /// Creates a new plugin loader.
52 ///
53 /// # Returns
54 ///
55 /// A new loader instance with default configuration.
56 pub fn new() -> Self {
57 Self {
58 library_cache: Arc::new(Mutex::new(HashMap::new())),
59 static_plugins: Arc::new(Mutex::new(HashMap::new())),
60 cache_hits: Arc::new(Mutex::new(0)),
61 cache_misses: Arc::new(Mutex::new(0)),
62 config: LoaderConfig::default(),
63 }
64 }
65
66 /// Creates a plugin loader with custom configuration.
67 ///
68 /// # Arguments
69 ///
70 /// * `config` - Loader configuration
71 ///
72 /// # Returns
73 ///
74 /// A new loader instance.
75 pub fn with_config(config: LoaderConfig) -> Self {
76 Self {
77 library_cache: Arc::new(Mutex::new(HashMap::new())),
78 static_plugins: Arc::new(Mutex::new(HashMap::new())),
79 cache_hits: Arc::new(Mutex::new(0)),
80 cache_misses: Arc::new(Mutex::new(0)),
81 config,
82 }
83 }
84
85 /// Loads plugin information from a file or metadata.
86 ///
87 /// # Arguments
88 ///
89 /// * `path` - Path to the plugin file or metadata
90 ///
91 /// # Returns
92 ///
93 /// Plugin information if successfully loaded.
94 ///
95 /// # Errors
96 ///
97 /// - File not found
98 /// - Invalid plugin format
99 /// - Metadata parsing errors
100 pub fn load_plugin_info<P: AsRef<Path>>(&self, path: P) -> Result<PluginInfo> {
101 let path = path.as_ref();
102
103 // Check for companion metadata file
104 let metadata_path = path.with_extension("json");
105 if metadata_path.exists() {
106 return self.load_metadata_file(&metadata_path);
107 }
108
109 // Try to load embedded metadata from the plugin file
110 self.load_embedded_metadata(path)
111 }
112
113 /// Loads a plugin instance from plugin information.
114 ///
115 /// # Arguments
116 ///
117 /// * `info` - Plugin information containing loading details
118 ///
119 /// # Returns
120 ///
121 /// A boxed plugin instance ready for use.
122 ///
123 /// # Errors
124 ///
125 /// - Plugin file not found
126 /// - Symbol resolution failures
127 /// - Plugin initialization errors
128 pub fn load_plugin(&self, info: &PluginInfo) -> Result<Box<dyn Plugin>> {
129 // Check if it's a static plugin first
130 if let Ok(static_plugins) = self.static_plugins.lock() {
131 if let Some(factory) = static_plugins.get(info.name()) {
132 return factory();
133 }
134 }
135
136 // Load as dynamic library
137 self.load_dynamic_plugin(info)
138 }
139
140 /// Registers a static plugin factory.
141 ///
142 /// Static plugins are compiled into the binary and don't require
143 /// dynamic loading. This method registers a factory function
144 /// that can create instances of the plugin.
145 ///
146 /// # Arguments
147 ///
148 /// * `name` - Plugin name
149 /// * `factory` - Factory function for creating plugin instances
150 ///
151 /// # Returns
152 ///
153 /// `Ok(())` on successful registration.
154 pub fn register_static_plugin(&self, name: &str, factory: StaticPluginFactory) -> Result<()> {
155 let mut static_plugins = self
156 .static_plugins
157 .lock()
158 .map_err(|_| TrustformersError::lock_error("Failed to acquire lock".to_string()))?;
159
160 static_plugins.insert(name.to_string(), factory);
161 Ok(())
162 }
163
164 /// Unloads a plugin library from the cache.
165 ///
166 /// # Arguments
167 ///
168 /// * `name` - Plugin name to unload
169 ///
170 /// # Returns
171 ///
172 /// `Ok(())` on success.
173 pub fn unload_library(&self, name: &str) -> Result<()> {
174 let mut cache = self
175 .library_cache
176 .lock()
177 .map_err(|_| TrustformersError::lock_error("Failed to acquire lock".to_string()))?;
178
179 cache.remove(name);
180 Ok(())
181 }
182
183 /// Clears all cached libraries.
184 pub fn clear_cache(&self) -> Result<()> {
185 let mut cache = self
186 .library_cache
187 .lock()
188 .map_err(|_| TrustformersError::lock_error("Failed to acquire lock".to_string()))?;
189
190 cache.clear();
191 Ok(())
192 }
193
194 /// Gets loader statistics.
195 ///
196 /// # Returns
197 ///
198 /// Loader statistics including cache information.
199 pub fn stats(&self) -> Result<LoaderStats> {
200 let cache = self
201 .library_cache
202 .lock()
203 .map_err(|_| TrustformersError::lock_error("Failed to acquire lock".to_string()))?;
204 let static_plugins = self
205 .static_plugins
206 .lock()
207 .map_err(|_| TrustformersError::lock_error("Failed to acquire lock".to_string()))?;
208
209 let cache_hits = self
210 .cache_hits
211 .lock()
212 .map_err(|_| TrustformersError::lock_error("Failed to acquire lock".to_string()))?;
213 let cache_misses = self
214 .cache_misses
215 .lock()
216 .map_err(|_| TrustformersError::lock_error("Failed to acquire lock".to_string()))?;
217
218 Ok(LoaderStats {
219 cached_libraries: cache.len(),
220 static_plugins: static_plugins.len(),
221 cache_hits: *cache_hits,
222 cache_misses: *cache_misses,
223 })
224 }
225
226 /// Loads metadata from a JSON file.
227 fn load_metadata_file<P: AsRef<Path>>(&self, path: P) -> Result<PluginInfo> {
228 let content = std::fs::read_to_string(path)
229 .map_err(|e| TrustformersError::io_error(format!("Failed to read metadata: {}", e)))?;
230
231 serde_json::from_str(&content)
232 .map_err(|e| TrustformersError::serialization_error(format!("Invalid metadata: {}", e)))
233 }
234
235 /// Loads embedded metadata from a plugin file.
236 fn load_embedded_metadata<P: AsRef<Path>>(&self, path: P) -> Result<PluginInfo> {
237 // This is a simplified implementation
238 // In a real implementation, you would read metadata from the plugin file
239 // For now, we'll create basic info from the filename
240 let path = path.as_ref();
241 let name = path.file_stem().and_then(|s| s.to_str()).ok_or_else(|| {
242 TrustformersError::plugin_error("Invalid plugin filename".to_string())
243 })?;
244
245 Ok(PluginInfo::new(
246 name,
247 "1.0.0",
248 "Dynamically loaded plugin",
249 &[],
250 ))
251 }
252
253 /// Loads a plugin as a dynamic library.
254 fn load_dynamic_plugin(&self, info: &PluginInfo) -> Result<Box<dyn Plugin>> {
255 // Check cache first
256 {
257 let cache = self
258 .library_cache
259 .lock()
260 .map_err(|_| TrustformersError::lock_error("Failed to acquire lock".to_string()))?;
261
262 if let Some(handle) = cache.get(info.name()) {
263 // Increment cache hit counter
264 if let Ok(mut hits) = self.cache_hits.lock() {
265 *hits += 1;
266 }
267 return handle.create_plugin();
268 }
269 }
270
271 // Cache miss - increment counter
272 if let Ok(mut misses) = self.cache_misses.lock() {
273 *misses += 1;
274 }
275
276 // Load the library
277 let handle = LibraryHandle::load(info)?;
278 let plugin = handle.create_plugin()?;
279
280 // Cache the handle
281 {
282 let mut cache = self
283 .library_cache
284 .lock()
285 .map_err(|_| TrustformersError::lock_error("Failed to acquire lock".to_string()))?;
286 cache.insert(info.name().to_string(), handle);
287 }
288
289 Ok(plugin)
290 }
291}
292
293impl Default for PluginLoader {
294 fn default() -> Self {
295 Self::new()
296 }
297}
298
299/// Type alias for static plugin factory functions.
300pub type StaticPluginFactory = fn() -> Result<Box<dyn Plugin>>;
301
/// Record kept for a plugin library that has been loaded.
///
/// The handle stores the identifying details of the library and is the
/// object asked to create plugin instances from it.
#[derive(Debug)]
struct LibraryHandle {
    /// Name of the plugin the library belongs to.
    // Stored for diagnostics only; nothing reads it back yet.
    #[allow(dead_code)]
    name: String,
    /// Entry point recorded from the plugin information.
    _entry_point: String,
}
314
315impl LibraryHandle {
316 /// Loads a plugin library.
317 ///
318 /// # Arguments
319 ///
320 /// * `info` - Plugin information
321 ///
322 /// # Returns
323 ///
324 /// A library handle if loading succeeds.
325 fn load(info: &PluginInfo) -> Result<Self> {
326 // This is a simplified implementation
327 // In a real implementation, you would use libloading or similar
328 // to actually load the dynamic library
329
330 Ok(Self {
331 name: info.name().to_string(),
332 _entry_point: info.entry_point().to_string(),
333 })
334 }
335
336 /// Creates a plugin instance from this library.
337 ///
338 /// # Returns
339 ///
340 /// A boxed plugin instance.
341 fn create_plugin(&self) -> Result<Box<dyn Plugin>> {
342 // This is a simplified implementation
343 // In a real implementation, you would resolve the plugin factory symbol
344 // and call it to create the plugin instance
345
346 Err(TrustformersError::plugin_error(
347 "Dynamic plugin loading not implemented in this example".to_string(),
348 ))
349 }
350}
351
/// Settings a [`PluginLoader`] is constructed with.
#[derive(Debug, Clone)]
pub struct LoaderConfig {
    /// Whether loaded libraries should be cached.
    pub cache_enabled: bool,
    /// Upper bound on the number of cached libraries.
    pub max_cached_libraries: usize,
    /// Time limit for loading a plugin, in seconds.
    pub load_timeout_secs: u64,
    /// Whether plugins should be loaded lazily.
    pub lazy_loading: bool,
    /// Prefix of the symbol name used for plugin factories.
    pub symbol_prefix: String,
}

/// Defaults: caching on, at most 50 cached libraries, a 30 second load
/// timeout, lazy loading on and `create_plugin` as the symbol prefix.
impl Default for LoaderConfig {
    fn default() -> Self {
        let symbol_prefix = String::from("create_plugin");

        LoaderConfig {
            cache_enabled: true,
            max_cached_libraries: 50,
            load_timeout_secs: 30,
            lazy_loading: true,
            symbol_prefix,
        }
    }
}
378
/// Snapshot of a plugin loader's bookkeeping.
#[derive(Debug, Clone)]
pub struct LoaderStats {
    /// How many libraries are currently cached.
    pub cached_libraries: usize,
    /// How many static plugin factories are registered.
    pub static_plugins: usize,
    /// How many loads were served from the library cache.
    pub cache_hits: u64,
    /// How many loads found no cached library.
    pub cache_misses: u64,
}
391
/// Reasons a plugin can fail to load.
///
/// Every variant carries a free-form string that is appended to the
/// variant's fixed label when the error is displayed.
#[derive(Debug, Clone)]
pub enum LoadError {
    /// The library file could not be found; carries the path.
    LibraryNotFound(String),
    /// A symbol is missing from the library; carries the symbol name.
    SymbolNotFound(String),
    /// The plugin failed to initialize; carries the failure message.
    InitializationFailed(String),
    /// The plugin has an invalid format; carries a description.
    InvalidFormat(String),
    /// The plugin version is incompatible; carries a description.
    VersionMismatch(String),
    /// A dependency is not satisfied; carries the dependency.
    DependencyNotSatisfied(String),
}

impl std::fmt::Display for LoadError {
    /// Writes the error as `<label>: <detail>`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Pick the fixed label and borrow the payload, then format once.
        let (label, detail) = match self {
            Self::LibraryNotFound(path) => ("Library not found", path),
            Self::SymbolNotFound(symbol) => ("Symbol not found", symbol),
            Self::InitializationFailed(msg) => ("Initialization failed", msg),
            Self::InvalidFormat(msg) => ("Invalid format", msg),
            Self::VersionMismatch(msg) => ("Version mismatch", msg),
            Self::DependencyNotSatisfied(dep) => ("Dependency not satisfied", dep),
        };

        write!(f, "{}: {}", label, detail)
    }
}

impl std::error::Error for LoadError {}
425
/// Generates the boilerplate for exposing a type as a static plugin.
///
/// The expansion contains two items:
///
/// - `register_plugin`, a public factory that builds the plugin through the
///   type's `default()` and returns it boxed;
/// - `register`, a constructor function compiled only when the
///   `static-plugins` feature is enabled, which passes `register_plugin` to
///   `PluginLoader::register_static_plugin` under the given name.
///
/// NOTE(review): `register` registers the factory on a `PluginLoader::new()`
/// that it creates itself and that goes out of scope when `register`
/// returns, so the registration does not reach any other loader. The result
/// of the registration call is also discarded.
///
/// # Example
///
/// ```no_run
/// use trustformers_core::register_static_plugin;
/// use trustformers_core::plugins::Plugin;
/// use trustformers_core::tensor::Tensor;
/// use trustformers_core::errors::Result;
/// use std::collections::HashMap;
///
/// #[derive(Debug, Clone, Default)]
/// struct MyPlugin {
///     config: HashMap<String, serde_json::Value>,
/// }
/// impl Plugin for MyPlugin {
///     fn name(&self) -> &str { "my_plugin" }
///     fn version(&self) -> &str { "1.0.0" }
///     fn description(&self) -> &str { "My custom plugin" }
///     fn configure(&mut self, config: HashMap<String, serde_json::Value>) -> Result<()> {
///         self.config = config; Ok(())
///     }
///     fn get_config(&self) -> &HashMap<String, serde_json::Value> { &self.config }
///     fn as_any(&self) -> &dyn std::any::Any { self }
///     fn forward(&self, input: Tensor) -> Result<Tensor> { Ok(input) }
/// }
///
/// register_static_plugin!(MyPlugin, "my_plugin");
/// ```
#[macro_export]
macro_rules! register_static_plugin {
    ($plugin_type:ty, $name:expr) => {
        // Factory handed to the loader: a default-constructed, boxed plugin.
        pub fn register_plugin() -> $crate::errors::Result<Box<dyn $crate::plugins::Plugin>> {
            let plugin = <$plugin_type>::default();
            Ok(Box::new(plugin))
        }

        // Constructor hook; only present with the `static-plugins` feature.
        #[cfg(feature = "static-plugins")]
        #[ctor::ctor]
        fn register() {
            let loader = $crate::plugins::PluginLoader::new();
            let _ = loader.register_static_plugin($name, register_plugin);
        }
    };
}
474}