Skip to main content

trustformers_core/plugins/
traits.rs

//! Core plugin traits and interfaces.

3use crate::errors::Result;
4use crate::tensor::Tensor;
5use serde::{Deserialize, Serialize};
6use std::any::Any;
7use std::collections::HashMap;
8
9/// Core trait for plugin implementations.
10///
11/// The `Plugin` trait provides metadata, configuration, and lifecycle management
12/// capabilities for dynamic plugin systems. Plugins can provide Layer functionality
13/// through the `forward` method.
14///
15/// # Requirements
16///
17/// Plugins must be:
18/// - Thread-safe (`Send + Sync`)
19/// - Cloneable for multi-instance usage
20///
21/// # Example
22///
23/// ```no_run
24/// use trustformers_core::plugins::Plugin;
25/// use trustformers_core::tensor::Tensor;
26/// use trustformers_core::errors::Result;
27/// use std::collections::HashMap;
28///
29/// #[derive(Debug, Clone)]
30/// struct CustomAttentionPlugin {
31///     hidden_size: usize,
32///     num_heads: usize,
33///     config: HashMap<String, serde_json::Value>,
34/// }
35///
36/// impl Plugin for CustomAttentionPlugin {
37///     fn name(&self) -> &str {
38///         "custom_attention"
39///     }
40///
41///     fn version(&self) -> &str {
42///         "1.0.0"
43///     }
44///
45///     fn description(&self) -> &str {
46///         "A custom attention mechanism with optimized kernels"
47///     }
48///
49///     fn configure(&mut self, config: HashMap<String, serde_json::Value>) -> Result<()> {
50///         self.config = config;
51///         Ok(())
52///     }
53///
54///     fn get_config(&self) -> &HashMap<String, serde_json::Value> {
55///         &self.config
56///     }
57///
58///     fn forward(&self, input: Tensor) -> Result<Tensor> {
59///         // Custom attention implementation
60///         Ok(input) // Simplified
61///     }
62///
63///     fn as_any(&self) -> &dyn std::any::Any {
64///         self
65///     }
66/// }
67/// ```
68pub trait Plugin: Send + Sync + ClonePlugin + std::fmt::Debug {
69    /// Returns the plugin's unique name.
70    ///
71    /// The name should be unique across all plugins and follow a consistent
72    /// naming convention (e.g., "vendor.component_name" or "custom_attention").
73    ///
74    /// # Returns
75    ///
76    /// A string slice containing the plugin name.
77    fn name(&self) -> &str;
78
79    /// Returns the plugin's version string.
80    ///
81    /// Should follow semantic versioning (e.g., "1.2.3").
82    ///
83    /// # Returns
84    ///
85    /// A string slice containing the version.
86    fn version(&self) -> &str;
87
88    /// Returns a human-readable description of the plugin.
89    ///
90    /// # Returns
91    ///
92    /// A string slice describing the plugin's functionality.
93    fn description(&self) -> &str;
94
95    /// Configures the plugin with the provided parameters.
96    ///
97    /// # Arguments
98    ///
99    /// * `config` - A map of configuration parameters
100    ///
101    /// # Returns
102    ///
103    /// Returns `Ok(())` on successful configuration, or an error if
104    /// the configuration is invalid.
105    ///
106    /// # Errors
107    ///
108    /// May return errors for:
109    /// - Invalid configuration parameters
110    /// - Missing required parameters
111    /// - Parameter validation failures
112    fn configure(&mut self, config: HashMap<String, serde_json::Value>) -> Result<()>;
113
114    /// Returns the current plugin configuration.
115    ///
116    /// # Returns
117    ///
118    /// A reference to the plugin's configuration map.
119    fn get_config(&self) -> &HashMap<String, serde_json::Value>;
120
121    /// Validates the current configuration.
122    ///
123    /// This method should check that all configuration parameters are valid
124    /// and compatible with each other.
125    ///
126    /// # Returns
127    ///
128    /// Returns `Ok(())` if the configuration is valid, or an error describing
129    /// the validation failure.
130    fn validate_config(&self) -> Result<()> {
131        Ok(())
132    }
133
134    /// Initializes the plugin for use.
135    ///
136    /// This method is called after configuration and before the plugin
137    /// is used for computation. It should set up any internal state,
138    /// allocate resources, or perform other initialization tasks.
139    ///
140    /// # Returns
141    ///
142    /// Returns `Ok(())` on successful initialization, or an error if
143    /// initialization fails.
144    ///
145    /// # Errors
146    ///
147    /// May return errors for:
148    /// - Resource allocation failures
149    /// - Invalid configuration
150    /// - Hardware compatibility issues
151    fn initialize(&mut self) -> Result<()> {
152        self.validate_config()
153    }
154
155    /// Cleans up plugin resources.
156    ///
157    /// This method is called when the plugin is no longer needed.
158    /// It should release any allocated resources, close connections,
159    /// or perform other cleanup tasks.
160    ///
161    /// # Returns
162    ///
163    /// Returns `Ok(())` on successful cleanup, or an error if
164    /// cleanup fails.
165    fn cleanup(&mut self) -> Result<()> {
166        Ok(())
167    }
168
169    /// Returns the plugin as an `Any` trait object for downcasting.
170    ///
171    /// This enables type-safe downcasting to concrete plugin types
172    /// when needed for advanced functionality.
173    ///
174    /// # Returns
175    ///
176    /// A reference to self as an `Any` trait object.
177    fn as_any(&self) -> &dyn Any;
178
179    /// Returns plugin dependencies.
180    ///
181    /// Lists other plugins or system components that this plugin
182    /// requires to function correctly.
183    ///
184    /// # Returns
185    ///
186    /// A vector of dependency specifications (e.g., "plugin_name >= 1.0.0").
187    fn dependencies(&self) -> Vec<String> {
188        Vec::new()
189    }
190
191    /// Returns plugin capabilities.
192    ///
193    /// Describes what features or operations this plugin supports.
194    /// This can be used for plugin discovery and compatibility checking.
195    ///
196    /// # Returns
197    ///
198    /// A vector of capability strings.
199    fn capabilities(&self) -> Vec<String> {
200        Vec::new()
201    }
202
203    /// Returns plugin tags for categorization.
204    ///
205    /// Tags help organize and discover plugins by functionality
206    /// (e.g., "attention", "optimization", "quantization").
207    ///
208    /// # Returns
209    ///
210    /// A vector of tag strings.
211    fn tags(&self) -> Vec<String> {
212        Vec::new()
213    }
214
215    /// Performs the forward computation of this plugin.
216    ///
217    /// This method provides the core computational functionality of the plugin,
218    /// accepting and returning tensors.
219    ///
220    /// # Arguments
221    ///
222    /// * `input` - The input tensor to process
223    ///
224    /// # Returns
225    ///
226    /// Returns `Ok(output)` containing the plugin's output tensor, or an error
227    /// if computation fails.
228    ///
229    /// # Errors
230    ///
231    /// May return errors for:
232    /// - Invalid input dimensions
233    /// - Numerical errors during computation
234    /// - Resource allocation failures
235    fn forward(&self, input: Tensor) -> Result<Tensor>;
236}
237
238/// Helper trait for cloning plugin trait objects.
239///
240/// This trait provides a way to clone boxed plugin instances,
241/// which is needed for plugin registry management.
242pub trait ClonePlugin {
243    /// Creates a clone of the plugin.
244    ///
245    /// # Returns
246    ///
247    /// A boxed clone of the plugin.
248    fn clone_plugin(&self) -> Box<dyn Plugin>;
249}
250
251impl<T> ClonePlugin for T
252where
253    T: Plugin + Clone + 'static,
254{
255    fn clone_plugin(&self) -> Box<dyn Plugin> {
256        Box::new(self.clone())
257    }
258}
259
260impl Clone for Box<dyn Plugin> {
261    fn clone(&self) -> Self {
262        self.clone_plugin()
263    }
264}
265
266/// Plugin lifecycle events.
267///
268/// These events are fired during different stages of a plugin's lifecycle,
269/// allowing for monitoring, logging, and custom handling.
270#[derive(Debug, Clone, Serialize, Deserialize)]
271pub enum PluginEvent {
272    /// Plugin is being loaded.
273    Loading { name: String, version: String },
274    /// Plugin has been successfully loaded.
275    Loaded { name: String, version: String },
276    /// Plugin configuration is being updated.
277    Configuring {
278        name: String,
279        config: HashMap<String, serde_json::Value>,
280    },
281    /// Plugin is being initialized.
282    Initializing { name: String },
283    /// Plugin has been successfully initialized.
284    Initialized { name: String },
285    /// Plugin is being unloaded.
286    Unloading { name: String },
287    /// Plugin has been unloaded.
288    Unloaded { name: String },
289    /// Plugin encountered an error.
290    Error { name: String, error: String },
291}
292
293/// Trait for handling plugin lifecycle events.
294///
295/// Implement this trait to receive notifications about plugin
296/// lifecycle events for monitoring, logging, or custom handling.
297pub trait PluginEventHandler: Send + Sync {
298    /// Handles a plugin lifecycle event.
299    ///
300    /// # Arguments
301    ///
302    /// * `event` - The lifecycle event that occurred
303    ///
304    /// # Returns
305    ///
306    /// Returns `Ok(())` on successful handling, or an error if
307    /// handling fails.
308    fn handle_event(&self, event: &PluginEvent) -> Result<()>;
309}
310
311/// Plugin execution context.
312///
313/// Provides runtime information and utilities to plugins during execution.
314/// This includes access to shared resources, configuration, and monitoring.
315#[derive(Debug)]
316pub struct PluginContext {
317    /// Plugin name
318    pub name: String,
319    /// Runtime configuration
320    pub config: HashMap<String, serde_json::Value>,
321    /// Shared resources
322    pub resources: HashMap<String, Box<dyn Any + Send + Sync>>,
323    /// Performance metrics
324    pub metrics: HashMap<String, f64>,
325}
326
327impl PluginContext {
328    /// Creates a new plugin context.
329    ///
330    /// # Arguments
331    ///
332    /// * `name` - The plugin name
333    /// * `config` - Initial configuration
334    ///
335    /// # Returns
336    ///
337    /// A new plugin context instance.
338    pub fn new(name: String, config: HashMap<String, serde_json::Value>) -> Self {
339        Self {
340            name,
341            config,
342            resources: HashMap::new(),
343            metrics: HashMap::new(),
344        }
345    }
346
347    /// Adds a shared resource to the context.
348    ///
349    /// # Arguments
350    ///
351    /// * `key` - The resource key
352    /// * `resource` - The resource to add
353    pub fn add_resource<T: Any + Send + Sync>(&mut self, key: String, resource: T) {
354        self.resources.insert(key, Box::new(resource));
355    }
356
357    /// Gets a shared resource from the context.
358    ///
359    /// # Arguments
360    ///
361    /// * `key` - The resource key
362    ///
363    /// # Returns
364    ///
365    /// An optional reference to the resource.
366    pub fn get_resource<T: Any + Send + Sync>(&self, key: &str) -> Option<&T> {
367        self.resources.get(key).and_then(|r| r.downcast_ref::<T>())
368    }
369
370    /// Updates a performance metric.
371    ///
372    /// # Arguments
373    ///
374    /// * `key` - The metric name
375    /// * `value` - The metric value
376    pub fn update_metric(&mut self, key: String, value: f64) {
377        self.metrics.insert(key, value);
378    }
379
380    /// Gets a performance metric.
381    ///
382    /// # Arguments
383    ///
384    /// * `key` - The metric name
385    ///
386    /// # Returns
387    ///
388    /// The metric value if it exists.
389    pub fn get_metric(&self, key: &str) -> Option<f64> {
390        self.metrics.get(key).copied()
391    }
392}