Skip to main content

sklears_compose/plugin_architecture/
types.rs

1//! Auto-generated module
2//!
3//! 🤖 Generated with [SplitRS](https://github.com/cool-japan/splitrs)
4
5use scirs2_core::ndarray::{Array1, Array2, ArrayView1, ArrayView2};
6use serde::{Deserialize, Serialize};
7use sklears_core::{
8    error::{Result as SklResult, SklearsError},
9    traits::Estimator,
10    types::Float,
11};
12use std::collections::{HashMap, HashSet};
13use std::fmt::Debug;
14use std::path::PathBuf;
15use std::sync::{Arc, RwLock};
16
17use super::functions::{ComponentFactory, Plugin, PluginComponent};
18
19/// Component configuration
20#[derive(Debug, Clone)]
21pub struct ComponentConfig {
22    /// Component type identifier
23    pub component_type: String,
24    /// Component parameters
25    pub parameters: HashMap<String, ConfigValue>,
26    /// Component metadata
27    pub metadata: HashMap<String, String>,
28}
29/// Parameter schema definition
30#[derive(Debug, Clone, Serialize, Deserialize)]
31pub struct ParameterSchema {
32    /// Parameter name
33    pub name: String,
34    /// Parameter type
35    pub parameter_type: ParameterType,
36    /// Parameter description
37    pub description: String,
38    /// Default value
39    pub default_value: Option<ConfigValue>,
40}
41/// Example transformer plugin implementation
42#[derive(Debug)]
43pub struct ExampleTransformerPlugin {
44    pub metadata: PluginMetadata,
45}
46impl ExampleTransformerPlugin {
47    pub fn new() -> Self {
48        Self {
49            metadata: PluginMetadata {
50                name: "Example Transformer Plugin".to_string(),
51                version: "1.0.0".to_string(),
52                description: "Example transformer plugin for demonstration".to_string(),
53                author: "Sklears Team".to_string(),
54                license: "MIT".to_string(),
55                min_api_version: "1.0.0".to_string(),
56                dependencies: vec![],
57                capabilities: vec!["transformer".to_string()],
58                tags: vec!["example".to_string(), "transformer".to_string()],
59                documentation_url: None,
60                source_url: None,
61            },
62        }
63    }
64    pub fn with_metadata(metadata: PluginMetadata) -> Self {
65        Self { metadata }
66    }
67}
68/// Plugin loading and management system
69pub struct PluginLoader {
70    /// Plugin configuration
71    config: PluginConfig,
72    /// Loaded plugin libraries
73    loaded_libraries: HashMap<String, PluginLibrary>,
74}
75impl PluginLoader {
76    /// Create a new plugin loader
77    #[must_use]
78    pub fn new(config: PluginConfig) -> Self {
79        Self {
80            config,
81            loaded_libraries: HashMap::new(),
82        }
83    }
84    /// Load plugins from configured directories
85    pub fn load_plugins(&mut self, registry: &PluginRegistry) -> SklResult<()> {
86        let plugin_dirs = self.config.plugin_dirs.clone();
87        for plugin_dir in &plugin_dirs {
88            self.load_plugins_from_dir(plugin_dir, registry)?;
89        }
90        Ok(())
91    }
92    /// Load plugins from a specific directory
93    fn load_plugins_from_dir(&mut self, dir: &PathBuf, registry: &PluginRegistry) -> SklResult<()> {
94        println!("Loading plugins from directory: {dir:?}");
95        self.load_example_plugins(registry)?;
96        Ok(())
97    }
98    /// Load example plugins for demonstration
99    pub fn load_example_plugins(&mut self, registry: &PluginRegistry) -> SklResult<()> {
100        let transformer_plugin = Box::new(ExampleTransformerPlugin::new());
101        let transformer_factory = Box::new(ExampleTransformerFactory::new());
102        registry.register_plugin(
103            "example_transformer",
104            transformer_plugin,
105            transformer_factory,
106        )?;
107        let estimator_plugin = Box::new(ExampleEstimatorPlugin::new());
108        let estimator_factory = Box::new(ExampleEstimatorFactory::new());
109        registry.register_plugin("example_estimator", estimator_plugin, estimator_factory)?;
110        Ok(())
111    }
112}
113/// Configuration value types
114#[derive(Debug, Clone, Serialize, Deserialize)]
115#[serde(untagged)]
116pub enum ConfigValue {
117    /// String value
118    String(String),
119    /// Integer value
120    Integer(i64),
121    /// Float value
122    Float(f64),
123    /// Boolean value
124    Boolean(bool),
125    /// Array of values
126    Array(Vec<ConfigValue>),
127    /// Object (nested configuration)
128    Object(HashMap<String, ConfigValue>),
129}
/// Stateless component factory paired with the example estimator plugin.
#[derive(Debug)]
pub struct ExampleEstimatorFactory;
impl ExampleEstimatorFactory {
    /// Create the factory.
    pub fn new() -> Self {
        ExampleEstimatorFactory
    }
}
/// Context available to a component while it executes.
#[derive(Clone, Debug)]
pub struct ComponentContext {
    /// Identifier of the component.
    pub component_id: String,
    /// Identifier of the enclosing pipeline, when there is one.
    pub pipeline_id: Option<String>,
    /// String-valued execution parameters.
    pub execution_params: HashMap<String, String>,
    /// Optional logger handle.
    pub logger: Option<String>,
}
/// Kinds of component a plugin can provide.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub enum PluginCapability {
    /// Provides transformers.
    Transformer,
    /// Provides estimators.
    Estimator,
    /// Provides preprocessors.
    Preprocessor,
    /// Provides feature selectors.
    FeatureSelector,
    /// Provides ensemble methods.
    Ensemble,
    /// Provides custom metrics.
    Metric,
    /// Provides data loaders.
    DataLoader,
    /// Provides visualizers.
    Visualizer,
    /// A capability identified by a free-form name.
    Custom(String),
}
172/// Plugin registry for managing custom components
173pub struct PluginRegistry {
174    /// Registered plugins
175    plugins: RwLock<HashMap<String, Box<dyn Plugin>>>,
176    /// Plugin metadata
177    metadata: RwLock<HashMap<String, PluginMetadata>>,
178    /// Component factories
179    factories: RwLock<HashMap<String, Box<dyn ComponentFactory>>>,
180    /// Dependency graph
181    dependencies: RwLock<HashMap<String, Vec<String>>>,
182    /// Plugin loading configuration
183    config: PluginConfig,
184}
185impl PluginRegistry {
186    #[must_use]
187    pub fn new(config: PluginConfig) -> Self {
188        Self {
189            plugins: RwLock::new(HashMap::new()),
190            metadata: RwLock::new(HashMap::new()),
191            factories: RwLock::new(HashMap::new()),
192            dependencies: RwLock::new(HashMap::new()),
193            config,
194        }
195    }
196    /// Register a plugin
197    pub fn register_plugin(
198        &self,
199        name: &str,
200        plugin: Box<dyn Plugin>,
201        factory: Box<dyn ComponentFactory>,
202    ) -> SklResult<()> {
203        let metadata = plugin.metadata().clone();
204        self.validate_plugin(&metadata)?;
205        self.check_dependencies(&metadata)?;
206        {
207            let mut plugins = self.plugins.write().map_err(|_| {
208                SklearsError::InvalidOperation(
209                    "Failed to acquire write lock for plugins".to_string(),
210                )
211            })?;
212            plugins.insert(name.to_string(), plugin);
213        }
214        {
215            let mut meta = self.metadata.write().map_err(|_| {
216                SklearsError::InvalidOperation(
217                    "Failed to acquire write lock for metadata".to_string(),
218                )
219            })?;
220            meta.insert(name.to_string(), metadata.clone());
221        }
222        {
223            let mut factories = self.factories.write().map_err(|_| {
224                SklearsError::InvalidOperation(
225                    "Failed to acquire write lock for factories".to_string(),
226                )
227            })?;
228            factories.insert(name.to_string(), factory);
229        }
230        {
231            let mut deps = self.dependencies.write().map_err(|_| {
232                SklearsError::InvalidOperation(
233                    "Failed to acquire write lock for dependencies".to_string(),
234                )
235            })?;
236            deps.insert(name.to_string(), metadata.dependencies);
237        }
238        Ok(())
239    }
240    /// Unregister a plugin
241    pub fn unregister_plugin(&self, name: &str) -> SklResult<()> {
242        let dependents = self.get_dependents(name)?;
243        if !dependents.is_empty() {
244            return Err(SklearsError::InvalidOperation(format!(
245                "Cannot unregister plugin '{name}' - it has dependents: {dependents:?}"
246            )));
247        }
248        if let Ok(mut plugins) = self.plugins.write() {
249            if let Some(mut plugin) = plugins.remove(name) {
250                plugin.shutdown()?;
251            }
252        }
253        if let Ok(mut metadata) = self.metadata.write() {
254            metadata.remove(name);
255        }
256        if let Ok(mut factories) = self.factories.write() {
257            factories.remove(name);
258        }
259        if let Ok(mut dependencies) = self.dependencies.write() {
260            dependencies.remove(name);
261        }
262        Ok(())
263    }
264    /// Create a component from a plugin
265    pub fn create_component(
266        &self,
267        plugin_name: &str,
268        component_type: &str,
269        config: &ComponentConfig,
270    ) -> SklResult<Box<dyn PluginComponent>> {
271        let factory = {
272            let factories = self.factories.read().map_err(|_| {
273                SklearsError::InvalidOperation(
274                    "Failed to acquire read lock for factories".to_string(),
275                )
276            })?;
277            factories
278                .get(plugin_name)
279                .ok_or_else(|| {
280                    SklearsError::InvalidInput(format!("Plugin '{plugin_name}' not found"))
281                })?
282                .create(component_type, config)?
283        };
284        Ok(factory)
285    }
286    /// List all registered plugins
287    pub fn list_plugins(&self) -> SklResult<Vec<String>> {
288        let plugins = self.plugins.read().map_err(|_| {
289            SklearsError::InvalidOperation("Failed to acquire read lock for plugins".to_string())
290        })?;
291        Ok(plugins.keys().cloned().collect())
292    }
293    /// Get plugin metadata
294    pub fn get_plugin_metadata(&self, name: &str) -> SklResult<PluginMetadata> {
295        let metadata = self.metadata.read().map_err(|_| {
296            SklearsError::InvalidOperation("Failed to acquire read lock for metadata".to_string())
297        })?;
298        metadata
299            .get(name)
300            .cloned()
301            .ok_or_else(|| SklearsError::InvalidInput(format!("Plugin '{name}' not found")))
302    }
303    /// List available component types for a plugin
304    pub fn list_component_types(&self, plugin_name: &str) -> SklResult<Vec<String>> {
305        let factories = self.factories.read().map_err(|_| {
306            SklearsError::InvalidOperation("Failed to acquire read lock for factories".to_string())
307        })?;
308        let factory = factories.get(plugin_name).ok_or_else(|| {
309            SklearsError::InvalidInput(format!("Plugin '{plugin_name}' not found"))
310        })?;
311        Ok(factory.available_types())
312    }
313    /// Get component schema
314    pub fn get_component_schema(
315        &self,
316        plugin_name: &str,
317        component_type: &str,
318    ) -> SklResult<Option<ComponentSchema>> {
319        let factories = self.factories.read().map_err(|_| {
320            SklearsError::InvalidOperation("Failed to acquire read lock for factories".to_string())
321        })?;
322        let factory = factories.get(plugin_name).ok_or_else(|| {
323            SklearsError::InvalidInput(format!("Plugin '{plugin_name}' not found"))
324        })?;
325        Ok(factory.get_schema(component_type))
326    }
327    /// Validate plugin compatibility
328    fn validate_plugin(&self, metadata: &PluginMetadata) -> SklResult<()> {
329        if !self.is_api_version_compatible(&metadata.min_api_version) {
330            return Err(SklearsError::InvalidInput(format!(
331                "Plugin requires API version {} but current version is incompatible",
332                metadata.min_api_version
333            )));
334        }
335        Ok(())
336    }
337    /// Check if API version is compatible
338    fn is_api_version_compatible(&self, required_version: &str) -> bool {
339        const CURRENT_API_VERSION: &str = "1.0.0";
340        required_version <= CURRENT_API_VERSION
341    }
342    /// Check plugin dependencies
343    fn check_dependencies(&self, metadata: &PluginMetadata) -> SklResult<()> {
344        let plugins = self.plugins.read().map_err(|_| {
345            SklearsError::InvalidOperation("Failed to acquire read lock for plugins".to_string())
346        })?;
347        for dependency in &metadata.dependencies {
348            if !plugins.contains_key(dependency) {
349                return Err(SklearsError::InvalidInput(format!(
350                    "Missing dependency: {dependency}"
351                )));
352            }
353        }
354        Ok(())
355    }
356    /// Get plugins that depend on the given plugin
357    fn get_dependents(&self, plugin_name: &str) -> SklResult<Vec<String>> {
358        let dependencies = self.dependencies.read().map_err(|_| {
359            SklearsError::InvalidOperation(
360                "Failed to acquire read lock for dependencies".to_string(),
361            )
362        })?;
363        let dependents: Vec<String> = dependencies
364            .iter()
365            .filter(|(_, deps)| deps.contains(&plugin_name.to_string()))
366            .map(|(name, _)| name.clone())
367            .collect();
368        Ok(dependents)
369    }
370    /// Initialize all plugins
371    pub fn initialize_all(&self) -> SklResult<()> {
372        let plugin_names = self.list_plugins()?;
373        for name in plugin_names {
374            self.initialize_plugin(&name)?;
375        }
376        Ok(())
377    }
378    /// Initialize a specific plugin
379    fn initialize_plugin(&self, name: &str) -> SklResult<()> {
380        let context = PluginContext {
381            registry_id: "main".to_string(),
382            working_dir: std::env::current_dir().unwrap_or_default(),
383            config: HashMap::new(),
384            available_apis: HashSet::new(),
385        };
386        let mut plugins = self.plugins.write().map_err(|_| {
387            SklearsError::InvalidOperation("Failed to acquire write lock for plugins".to_string())
388        })?;
389        if let Some(plugin) = plugins.get_mut(name) {
390            plugin.initialize(&context)?;
391        }
392        Ok(())
393    }
394    /// Shutdown all plugins
395    pub fn shutdown_all(&self) -> SklResult<()> {
396        let mut plugins = self.plugins.write().map_err(|_| {
397            SklearsError::InvalidOperation("Failed to acquire write lock for plugins".to_string())
398        })?;
399        for (_, plugin) in plugins.iter_mut() {
400            let _ = plugin.shutdown();
401        }
402        Ok(())
403    }
404}
405/// Component schema for validation
406#[derive(Debug, Clone, Serialize, Deserialize)]
407pub struct ComponentSchema {
408    /// Schema name
409    pub name: String,
410    /// Required parameters
411    pub required_parameters: Vec<ParameterSchema>,
412    /// Optional parameters
413    pub optional_parameters: Vec<ParameterSchema>,
414    /// Parameter constraints
415    pub constraints: Vec<ParameterConstraint>,
416}
/// Bookkeeping for a plugin library that has been loaded.
#[derive(Debug)]
struct PluginLibrary {
    /// Filesystem location of the library.
    path: PathBuf,
    /// Placeholder for the library handle (a real implementation would use `libloading`).
    handle: String,
    /// Names of the plugins the library exports.
    plugins: Vec<String>,
}
/// Settings controlling plugin discovery and execution.
#[derive(Clone, Debug)]
pub struct PluginConfig {
    /// Directories searched for plugins.
    pub plugin_dirs: Vec<PathBuf>,
    /// Whether plugins are loaded automatically at startup.
    pub auto_load: bool,
    /// Whether plugin sandboxing is enabled.
    pub sandbox: bool,
    /// Upper bound on plugin execution time.
    pub max_execution_time: std::time::Duration,
    /// Whether plugin validation is enabled.
    pub validate_plugins: bool,
}
441/// Example regressor component
442#[derive(Debug, Clone)]
443pub struct ExampleRegressor {
444    pub config: ComponentConfig,
445    pub learning_rate: f64,
446    pub fitted: bool,
447    pub coefficients: Option<Array1<f64>>,
448}
449impl ExampleRegressor {
450    pub fn new(config: ComponentConfig) -> Self {
451        let learning_rate = config
452            .parameters
453            .get("learning_rate")
454            .and_then(|v| match v {
455                ConfigValue::Float(f) => Some(*f),
456                _ => None,
457            })
458            .unwrap_or(0.01);
459        Self {
460            config,
461            learning_rate,
462            fitted: false,
463            coefficients: None,
464        }
465    }
466}
/// Stateless component factory paired with the example transformer plugin.
#[derive(Debug)]
pub struct ExampleTransformerFactory;
impl ExampleTransformerFactory {
    /// Create the factory.
    pub fn new() -> Self {
        ExampleTransformerFactory
    }
}
475/// Parameter types
476#[derive(Debug, Clone, Serialize, Deserialize)]
477pub enum ParameterType {
478    /// String parameter
479    String {
480        min_length: Option<usize>,
481        max_length: Option<usize>,
482    },
483    /// Integer parameter
484    Integer {
485        min_value: Option<i64>,
486        max_value: Option<i64>,
487    },
488    /// Float parameter
489    Float {
490        min_value: Option<f64>,
491        max_value: Option<f64>,
492    },
493    /// Boolean parameter
494    Boolean,
495    /// Enum parameter
496    Enum { values: Vec<String> },
497    /// Array parameter
498    Array {
499        item_type: Box<ParameterType>,
500        min_items: Option<usize>,
501        max_items: Option<usize>,
502    },
503    /// Object parameter
504    Object { schema: ComponentSchema },
505}
506/// Example scaler component
507#[derive(Debug, Clone)]
508pub struct ExampleScaler {
509    pub config: ComponentConfig,
510    pub scale_factor: f64,
511    pub fitted: bool,
512}
513impl ExampleScaler {
514    pub fn new(config: ComponentConfig) -> Self {
515        let scale_factor = config
516            .parameters
517            .get("scale_factor")
518            .and_then(|v| match v {
519                ConfigValue::Float(f) => Some(*f),
520                _ => None,
521            })
522            .unwrap_or(1.0);
523        Self {
524            config,
525            scale_factor,
526            fitted: false,
527        }
528    }
529}
/// Context handed to a plugin when it is initialized.
#[derive(Clone, Debug)]
pub struct PluginContext {
    /// Identifier of the registry doing the initialization.
    pub registry_id: String,
    /// Working directory for the plugin.
    pub working_dir: PathBuf,
    /// String-valued configuration parameters.
    pub config: HashMap<String, String>,
    /// Names of the APIs available to the plugin.
    pub available_apis: HashSet<String>,
}
542/// Example estimator plugin
543#[derive(Debug)]
544pub struct ExampleEstimatorPlugin {
545    pub metadata: PluginMetadata,
546}
547impl ExampleEstimatorPlugin {
548    pub fn new() -> Self {
549        Self {
550            metadata: PluginMetadata {
551                name: "Example Estimator Plugin".to_string(),
552                version: "1.0.0".to_string(),
553                description: "Example estimator plugin for demonstration".to_string(),
554                author: "Sklears Team".to_string(),
555                license: "MIT".to_string(),
556                min_api_version: "1.0.0".to_string(),
557                dependencies: vec![],
558                capabilities: vec!["estimator".to_string()],
559                tags: vec!["example".to_string(), "estimator".to_string()],
560                documentation_url: None,
561                source_url: None,
562            },
563        }
564    }
565}
566/// Parameter constraints
567#[derive(Debug, Clone, Serialize, Deserialize)]
568pub struct ParameterConstraint {
569    /// Constraint name
570    pub name: String,
571    /// Constraint expression
572    pub expression: String,
573    /// Constraint description
574    pub description: String,
575}
576/// Plugin metadata information
577#[derive(Debug, Clone, Serialize, Deserialize)]
578pub struct PluginMetadata {
579    /// Plugin name
580    pub name: String,
581    /// Plugin version
582    pub version: String,
583    /// Plugin description
584    pub description: String,
585    /// Plugin author
586    pub author: String,
587    /// Plugin license
588    pub license: String,
589    /// Minimum API version required
590    pub min_api_version: String,
591    /// Plugin dependencies
592    pub dependencies: Vec<String>,
593    /// Plugin capabilities
594    pub capabilities: Vec<String>,
595    /// Plugin tags
596    pub tags: Vec<String>,
597    /// Plugin documentation URL
598    pub documentation_url: Option<String>,
599    /// Plugin source code URL
600    pub source_url: Option<String>,
601}