trustformers_core/plugins/traits.rs
1//! Core plugin traits and interfaces.
2
3use crate::errors::Result;
4use crate::tensor::Tensor;
5use serde::{Deserialize, Serialize};
6use std::any::Any;
7use std::collections::HashMap;
8
9/// Core trait for plugin implementations.
10///
11/// The `Plugin` trait provides metadata, configuration, and lifecycle management
12/// capabilities for dynamic plugin systems. Plugins can provide Layer functionality
13/// through the `forward` method.
14///
15/// # Requirements
16///
17/// Plugins must be:
18/// - Thread-safe (`Send + Sync`)
19/// - Cloneable for multi-instance usage
20///
21/// # Example
22///
23/// ```no_run
24/// use trustformers_core::plugins::Plugin;
25/// use trustformers_core::tensor::Tensor;
26/// use trustformers_core::errors::Result;
27/// use std::collections::HashMap;
28///
29/// #[derive(Debug, Clone)]
30/// struct CustomAttentionPlugin {
31/// hidden_size: usize,
32/// num_heads: usize,
33/// config: HashMap<String, serde_json::Value>,
34/// }
35///
36/// impl Plugin for CustomAttentionPlugin {
37/// fn name(&self) -> &str {
38/// "custom_attention"
39/// }
40///
41/// fn version(&self) -> &str {
42/// "1.0.0"
43/// }
44///
45/// fn description(&self) -> &str {
46/// "A custom attention mechanism with optimized kernels"
47/// }
48///
49/// fn configure(&mut self, config: HashMap<String, serde_json::Value>) -> Result<()> {
50/// self.config = config;
51/// Ok(())
52/// }
53///
54/// fn get_config(&self) -> &HashMap<String, serde_json::Value> {
55/// &self.config
56/// }
57///
58/// fn forward(&self, input: Tensor) -> Result<Tensor> {
59/// // Custom attention implementation
60/// Ok(input) // Simplified
61/// }
62///
63/// fn as_any(&self) -> &dyn std::any::Any {
64/// self
65/// }
66/// }
67/// ```
68pub trait Plugin: Send + Sync + ClonePlugin + std::fmt::Debug {
69 /// Returns the plugin's unique name.
70 ///
71 /// The name should be unique across all plugins and follow a consistent
72 /// naming convention (e.g., "vendor.component_name" or "custom_attention").
73 ///
74 /// # Returns
75 ///
76 /// A string slice containing the plugin name.
77 fn name(&self) -> &str;
78
79 /// Returns the plugin's version string.
80 ///
81 /// Should follow semantic versioning (e.g., "1.2.3").
82 ///
83 /// # Returns
84 ///
85 /// A string slice containing the version.
86 fn version(&self) -> &str;
87
88 /// Returns a human-readable description of the plugin.
89 ///
90 /// # Returns
91 ///
92 /// A string slice describing the plugin's functionality.
93 fn description(&self) -> &str;
94
95 /// Configures the plugin with the provided parameters.
96 ///
97 /// # Arguments
98 ///
99 /// * `config` - A map of configuration parameters
100 ///
101 /// # Returns
102 ///
103 /// Returns `Ok(())` on successful configuration, or an error if
104 /// the configuration is invalid.
105 ///
106 /// # Errors
107 ///
108 /// May return errors for:
109 /// - Invalid configuration parameters
110 /// - Missing required parameters
111 /// - Parameter validation failures
112 fn configure(&mut self, config: HashMap<String, serde_json::Value>) -> Result<()>;
113
114 /// Returns the current plugin configuration.
115 ///
116 /// # Returns
117 ///
118 /// A reference to the plugin's configuration map.
119 fn get_config(&self) -> &HashMap<String, serde_json::Value>;
120
121 /// Validates the current configuration.
122 ///
123 /// This method should check that all configuration parameters are valid
124 /// and compatible with each other.
125 ///
126 /// # Returns
127 ///
128 /// Returns `Ok(())` if the configuration is valid, or an error describing
129 /// the validation failure.
130 fn validate_config(&self) -> Result<()> {
131 Ok(())
132 }
133
134 /// Initializes the plugin for use.
135 ///
136 /// This method is called after configuration and before the plugin
137 /// is used for computation. It should set up any internal state,
138 /// allocate resources, or perform other initialization tasks.
139 ///
140 /// # Returns
141 ///
142 /// Returns `Ok(())` on successful initialization, or an error if
143 /// initialization fails.
144 ///
145 /// # Errors
146 ///
147 /// May return errors for:
148 /// - Resource allocation failures
149 /// - Invalid configuration
150 /// - Hardware compatibility issues
151 fn initialize(&mut self) -> Result<()> {
152 self.validate_config()
153 }
154
155 /// Cleans up plugin resources.
156 ///
157 /// This method is called when the plugin is no longer needed.
158 /// It should release any allocated resources, close connections,
159 /// or perform other cleanup tasks.
160 ///
161 /// # Returns
162 ///
163 /// Returns `Ok(())` on successful cleanup, or an error if
164 /// cleanup fails.
165 fn cleanup(&mut self) -> Result<()> {
166 Ok(())
167 }
168
169 /// Returns the plugin as an `Any` trait object for downcasting.
170 ///
171 /// This enables type-safe downcasting to concrete plugin types
172 /// when needed for advanced functionality.
173 ///
174 /// # Returns
175 ///
176 /// A reference to self as an `Any` trait object.
177 fn as_any(&self) -> &dyn Any;
178
179 /// Returns plugin dependencies.
180 ///
181 /// Lists other plugins or system components that this plugin
182 /// requires to function correctly.
183 ///
184 /// # Returns
185 ///
186 /// A vector of dependency specifications (e.g., "plugin_name >= 1.0.0").
187 fn dependencies(&self) -> Vec<String> {
188 Vec::new()
189 }
190
191 /// Returns plugin capabilities.
192 ///
193 /// Describes what features or operations this plugin supports.
194 /// This can be used for plugin discovery and compatibility checking.
195 ///
196 /// # Returns
197 ///
198 /// A vector of capability strings.
199 fn capabilities(&self) -> Vec<String> {
200 Vec::new()
201 }
202
203 /// Returns plugin tags for categorization.
204 ///
205 /// Tags help organize and discover plugins by functionality
206 /// (e.g., "attention", "optimization", "quantization").
207 ///
208 /// # Returns
209 ///
210 /// A vector of tag strings.
211 fn tags(&self) -> Vec<String> {
212 Vec::new()
213 }
214
215 /// Performs the forward computation of this plugin.
216 ///
217 /// This method provides the core computational functionality of the plugin,
218 /// accepting and returning tensors.
219 ///
220 /// # Arguments
221 ///
222 /// * `input` - The input tensor to process
223 ///
224 /// # Returns
225 ///
226 /// Returns `Ok(output)` containing the plugin's output tensor, or an error
227 /// if computation fails.
228 ///
229 /// # Errors
230 ///
231 /// May return errors for:
232 /// - Invalid input dimensions
233 /// - Numerical errors during computation
234 /// - Resource allocation failures
235 fn forward(&self, input: Tensor) -> Result<Tensor>;
236}
237
238/// Helper trait for cloning plugin trait objects.
239///
240/// This trait provides a way to clone boxed plugin instances,
241/// which is needed for plugin registry management.
242pub trait ClonePlugin {
243 /// Creates a clone of the plugin.
244 ///
245 /// # Returns
246 ///
247 /// A boxed clone of the plugin.
248 fn clone_plugin(&self) -> Box<dyn Plugin>;
249}
250
251impl<T> ClonePlugin for T
252where
253 T: Plugin + Clone + 'static,
254{
255 fn clone_plugin(&self) -> Box<dyn Plugin> {
256 Box::new(self.clone())
257 }
258}
259
260impl Clone for Box<dyn Plugin> {
261 fn clone(&self) -> Self {
262 self.clone_plugin()
263 }
264}
265
266/// Plugin lifecycle events.
267///
268/// These events are fired during different stages of a plugin's lifecycle,
269/// allowing for monitoring, logging, and custom handling.
270#[derive(Debug, Clone, Serialize, Deserialize)]
271pub enum PluginEvent {
272 /// Plugin is being loaded.
273 Loading { name: String, version: String },
274 /// Plugin has been successfully loaded.
275 Loaded { name: String, version: String },
276 /// Plugin configuration is being updated.
277 Configuring {
278 name: String,
279 config: HashMap<String, serde_json::Value>,
280 },
281 /// Plugin is being initialized.
282 Initializing { name: String },
283 /// Plugin has been successfully initialized.
284 Initialized { name: String },
285 /// Plugin is being unloaded.
286 Unloading { name: String },
287 /// Plugin has been unloaded.
288 Unloaded { name: String },
289 /// Plugin encountered an error.
290 Error { name: String, error: String },
291}
292
293/// Trait for handling plugin lifecycle events.
294///
295/// Implement this trait to receive notifications about plugin
296/// lifecycle events for monitoring, logging, or custom handling.
297pub trait PluginEventHandler: Send + Sync {
298 /// Handles a plugin lifecycle event.
299 ///
300 /// # Arguments
301 ///
302 /// * `event` - The lifecycle event that occurred
303 ///
304 /// # Returns
305 ///
306 /// Returns `Ok(())` on successful handling, or an error if
307 /// handling fails.
308 fn handle_event(&self, event: &PluginEvent) -> Result<()>;
309}
310
311/// Plugin execution context.
312///
313/// Provides runtime information and utilities to plugins during execution.
314/// This includes access to shared resources, configuration, and monitoring.
315#[derive(Debug)]
316pub struct PluginContext {
317 /// Plugin name
318 pub name: String,
319 /// Runtime configuration
320 pub config: HashMap<String, serde_json::Value>,
321 /// Shared resources
322 pub resources: HashMap<String, Box<dyn Any + Send + Sync>>,
323 /// Performance metrics
324 pub metrics: HashMap<String, f64>,
325}
326
327impl PluginContext {
328 /// Creates a new plugin context.
329 ///
330 /// # Arguments
331 ///
332 /// * `name` - The plugin name
333 /// * `config` - Initial configuration
334 ///
335 /// # Returns
336 ///
337 /// A new plugin context instance.
338 pub fn new(name: String, config: HashMap<String, serde_json::Value>) -> Self {
339 Self {
340 name,
341 config,
342 resources: HashMap::new(),
343 metrics: HashMap::new(),
344 }
345 }
346
347 /// Adds a shared resource to the context.
348 ///
349 /// # Arguments
350 ///
351 /// * `key` - The resource key
352 /// * `resource` - The resource to add
353 pub fn add_resource<T: Any + Send + Sync>(&mut self, key: String, resource: T) {
354 self.resources.insert(key, Box::new(resource));
355 }
356
357 /// Gets a shared resource from the context.
358 ///
359 /// # Arguments
360 ///
361 /// * `key` - The resource key
362 ///
363 /// # Returns
364 ///
365 /// An optional reference to the resource.
366 pub fn get_resource<T: Any + Send + Sync>(&self, key: &str) -> Option<&T> {
367 self.resources.get(key).and_then(|r| r.downcast_ref::<T>())
368 }
369
370 /// Updates a performance metric.
371 ///
372 /// # Arguments
373 ///
374 /// * `key` - The metric name
375 /// * `value` - The metric value
376 pub fn update_metric(&mut self, key: String, value: f64) {
377 self.metrics.insert(key, value);
378 }
379
380 /// Gets a performance metric.
381 ///
382 /// # Arguments
383 ///
384 /// * `key` - The metric name
385 ///
386 /// # Returns
387 ///
388 /// The metric value if it exists.
389 pub fn get_metric(&self, key: &str) -> Option<f64> {
390 self.metrics.get(key).copied()
391 }
392}