1use scirs2_core::ndarray::{Array1, Array2, ArrayView1, ArrayView2};
6use serde::{Deserialize, Serialize};
7use sklears_core::{
8 error::{Result as SklResult, SklearsError},
9 traits::Estimator,
10 types::Float,
11};
12use std::collections::{HashMap, HashSet};
13use std::fmt::Debug;
14use std::path::PathBuf;
15use std::sync::{Arc, RwLock};
16
17use super::functions::{ComponentFactory, Plugin, PluginComponent};
18
19#[derive(Debug, Clone)]
21pub struct ComponentConfig {
22 pub component_type: String,
24 pub parameters: HashMap<String, ConfigValue>,
26 pub metadata: HashMap<String, String>,
28}
29#[derive(Debug, Clone, Serialize, Deserialize)]
31pub struct ParameterSchema {
32 pub name: String,
34 pub parameter_type: ParameterType,
36 pub description: String,
38 pub default_value: Option<ConfigValue>,
40}
41#[derive(Debug)]
43pub struct ExampleTransformerPlugin {
44 pub metadata: PluginMetadata,
45}
46impl ExampleTransformerPlugin {
47 pub fn new() -> Self {
48 Self {
49 metadata: PluginMetadata {
50 name: "Example Transformer Plugin".to_string(),
51 version: "1.0.0".to_string(),
52 description: "Example transformer plugin for demonstration".to_string(),
53 author: "Sklears Team".to_string(),
54 license: "MIT".to_string(),
55 min_api_version: "1.0.0".to_string(),
56 dependencies: vec![],
57 capabilities: vec!["transformer".to_string()],
58 tags: vec!["example".to_string(), "transformer".to_string()],
59 documentation_url: None,
60 source_url: None,
61 },
62 }
63 }
64 pub fn with_metadata(metadata: PluginMetadata) -> Self {
65 Self { metadata }
66 }
67}
68pub struct PluginLoader {
70 config: PluginConfig,
72 loaded_libraries: HashMap<String, PluginLibrary>,
74}
75impl PluginLoader {
76 #[must_use]
78 pub fn new(config: PluginConfig) -> Self {
79 Self {
80 config,
81 loaded_libraries: HashMap::new(),
82 }
83 }
84 pub fn load_plugins(&mut self, registry: &PluginRegistry) -> SklResult<()> {
86 let plugin_dirs = self.config.plugin_dirs.clone();
87 for plugin_dir in &plugin_dirs {
88 self.load_plugins_from_dir(plugin_dir, registry)?;
89 }
90 Ok(())
91 }
92 fn load_plugins_from_dir(&mut self, dir: &PathBuf, registry: &PluginRegistry) -> SklResult<()> {
94 println!("Loading plugins from directory: {dir:?}");
95 self.load_example_plugins(registry)?;
96 Ok(())
97 }
98 pub fn load_example_plugins(&mut self, registry: &PluginRegistry) -> SklResult<()> {
100 let transformer_plugin = Box::new(ExampleTransformerPlugin::new());
101 let transformer_factory = Box::new(ExampleTransformerFactory::new());
102 registry.register_plugin(
103 "example_transformer",
104 transformer_plugin,
105 transformer_factory,
106 )?;
107 let estimator_plugin = Box::new(ExampleEstimatorPlugin::new());
108 let estimator_factory = Box::new(ExampleEstimatorFactory::new());
109 registry.register_plugin("example_estimator", estimator_plugin, estimator_factory)?;
110 Ok(())
111 }
112}
113#[derive(Debug, Clone, Serialize, Deserialize)]
115#[serde(untagged)]
116pub enum ConfigValue {
117 String(String),
119 Integer(i64),
121 Float(f64),
123 Boolean(bool),
125 Array(Vec<ConfigValue>),
127 Object(HashMap<String, ConfigValue>),
129}
/// Component factory registered alongside the example estimator plugin.
#[derive(Debug)]
pub struct ExampleEstimatorFactory;

impl ExampleEstimatorFactory {
    /// Returns the (stateless) factory.
    pub fn new() -> Self {
        ExampleEstimatorFactory
    }
}
/// Runtime information associated with a single component instance.
#[derive(Clone, Debug)]
pub struct ComponentContext {
    /// Identifier of the component.
    pub component_id: String,
    /// Identifier of the owning pipeline, if any.
    pub pipeline_id: Option<String>,
    /// String-valued execution parameters.
    pub execution_params: HashMap<String, String>,
    /// Optional logger identifier. NOTE(review): stored as a plain `String`;
    /// confirm its meaning with callers.
    pub logger: Option<String>,
}
/// Kind of functionality a plugin can provide.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub enum PluginCapability {
    /// Provides transformers.
    Transformer,
    /// Provides estimators.
    Estimator,
    /// Provides preprocessors.
    Preprocessor,
    /// Provides feature selectors.
    FeatureSelector,
    /// Provides ensemble methods.
    Ensemble,
    /// Provides metrics.
    Metric,
    /// Provides data loaders.
    DataLoader,
    /// Provides visualizers.
    Visualizer,
    /// Any other capability, identified by name.
    Custom(String),
}
172pub struct PluginRegistry {
174 plugins: RwLock<HashMap<String, Box<dyn Plugin>>>,
176 metadata: RwLock<HashMap<String, PluginMetadata>>,
178 factories: RwLock<HashMap<String, Box<dyn ComponentFactory>>>,
180 dependencies: RwLock<HashMap<String, Vec<String>>>,
182 config: PluginConfig,
184}
185impl PluginRegistry {
186 #[must_use]
187 pub fn new(config: PluginConfig) -> Self {
188 Self {
189 plugins: RwLock::new(HashMap::new()),
190 metadata: RwLock::new(HashMap::new()),
191 factories: RwLock::new(HashMap::new()),
192 dependencies: RwLock::new(HashMap::new()),
193 config,
194 }
195 }
196 pub fn register_plugin(
198 &self,
199 name: &str,
200 plugin: Box<dyn Plugin>,
201 factory: Box<dyn ComponentFactory>,
202 ) -> SklResult<()> {
203 let metadata = plugin.metadata().clone();
204 self.validate_plugin(&metadata)?;
205 self.check_dependencies(&metadata)?;
206 {
207 let mut plugins = self.plugins.write().map_err(|_| {
208 SklearsError::InvalidOperation(
209 "Failed to acquire write lock for plugins".to_string(),
210 )
211 })?;
212 plugins.insert(name.to_string(), plugin);
213 }
214 {
215 let mut meta = self.metadata.write().map_err(|_| {
216 SklearsError::InvalidOperation(
217 "Failed to acquire write lock for metadata".to_string(),
218 )
219 })?;
220 meta.insert(name.to_string(), metadata.clone());
221 }
222 {
223 let mut factories = self.factories.write().map_err(|_| {
224 SklearsError::InvalidOperation(
225 "Failed to acquire write lock for factories".to_string(),
226 )
227 })?;
228 factories.insert(name.to_string(), factory);
229 }
230 {
231 let mut deps = self.dependencies.write().map_err(|_| {
232 SklearsError::InvalidOperation(
233 "Failed to acquire write lock for dependencies".to_string(),
234 )
235 })?;
236 deps.insert(name.to_string(), metadata.dependencies);
237 }
238 Ok(())
239 }
240 pub fn unregister_plugin(&self, name: &str) -> SklResult<()> {
242 let dependents = self.get_dependents(name)?;
243 if !dependents.is_empty() {
244 return Err(SklearsError::InvalidOperation(format!(
245 "Cannot unregister plugin '{name}' - it has dependents: {dependents:?}"
246 )));
247 }
248 if let Ok(mut plugins) = self.plugins.write() {
249 if let Some(mut plugin) = plugins.remove(name) {
250 plugin.shutdown()?;
251 }
252 }
253 if let Ok(mut metadata) = self.metadata.write() {
254 metadata.remove(name);
255 }
256 if let Ok(mut factories) = self.factories.write() {
257 factories.remove(name);
258 }
259 if let Ok(mut dependencies) = self.dependencies.write() {
260 dependencies.remove(name);
261 }
262 Ok(())
263 }
264 pub fn create_component(
266 &self,
267 plugin_name: &str,
268 component_type: &str,
269 config: &ComponentConfig,
270 ) -> SklResult<Box<dyn PluginComponent>> {
271 let factory = {
272 let factories = self.factories.read().map_err(|_| {
273 SklearsError::InvalidOperation(
274 "Failed to acquire read lock for factories".to_string(),
275 )
276 })?;
277 factories
278 .get(plugin_name)
279 .ok_or_else(|| {
280 SklearsError::InvalidInput(format!("Plugin '{plugin_name}' not found"))
281 })?
282 .create(component_type, config)?
283 };
284 Ok(factory)
285 }
286 pub fn list_plugins(&self) -> SklResult<Vec<String>> {
288 let plugins = self.plugins.read().map_err(|_| {
289 SklearsError::InvalidOperation("Failed to acquire read lock for plugins".to_string())
290 })?;
291 Ok(plugins.keys().cloned().collect())
292 }
293 pub fn get_plugin_metadata(&self, name: &str) -> SklResult<PluginMetadata> {
295 let metadata = self.metadata.read().map_err(|_| {
296 SklearsError::InvalidOperation("Failed to acquire read lock for metadata".to_string())
297 })?;
298 metadata
299 .get(name)
300 .cloned()
301 .ok_or_else(|| SklearsError::InvalidInput(format!("Plugin '{name}' not found")))
302 }
303 pub fn list_component_types(&self, plugin_name: &str) -> SklResult<Vec<String>> {
305 let factories = self.factories.read().map_err(|_| {
306 SklearsError::InvalidOperation("Failed to acquire read lock for factories".to_string())
307 })?;
308 let factory = factories.get(plugin_name).ok_or_else(|| {
309 SklearsError::InvalidInput(format!("Plugin '{plugin_name}' not found"))
310 })?;
311 Ok(factory.available_types())
312 }
313 pub fn get_component_schema(
315 &self,
316 plugin_name: &str,
317 component_type: &str,
318 ) -> SklResult<Option<ComponentSchema>> {
319 let factories = self.factories.read().map_err(|_| {
320 SklearsError::InvalidOperation("Failed to acquire read lock for factories".to_string())
321 })?;
322 let factory = factories.get(plugin_name).ok_or_else(|| {
323 SklearsError::InvalidInput(format!("Plugin '{plugin_name}' not found"))
324 })?;
325 Ok(factory.get_schema(component_type))
326 }
327 fn validate_plugin(&self, metadata: &PluginMetadata) -> SklResult<()> {
329 if !self.is_api_version_compatible(&metadata.min_api_version) {
330 return Err(SklearsError::InvalidInput(format!(
331 "Plugin requires API version {} but current version is incompatible",
332 metadata.min_api_version
333 )));
334 }
335 Ok(())
336 }
337 fn is_api_version_compatible(&self, required_version: &str) -> bool {
339 const CURRENT_API_VERSION: &str = "1.0.0";
340 required_version <= CURRENT_API_VERSION
341 }
342 fn check_dependencies(&self, metadata: &PluginMetadata) -> SklResult<()> {
344 let plugins = self.plugins.read().map_err(|_| {
345 SklearsError::InvalidOperation("Failed to acquire read lock for plugins".to_string())
346 })?;
347 for dependency in &metadata.dependencies {
348 if !plugins.contains_key(dependency) {
349 return Err(SklearsError::InvalidInput(format!(
350 "Missing dependency: {dependency}"
351 )));
352 }
353 }
354 Ok(())
355 }
356 fn get_dependents(&self, plugin_name: &str) -> SklResult<Vec<String>> {
358 let dependencies = self.dependencies.read().map_err(|_| {
359 SklearsError::InvalidOperation(
360 "Failed to acquire read lock for dependencies".to_string(),
361 )
362 })?;
363 let dependents: Vec<String> = dependencies
364 .iter()
365 .filter(|(_, deps)| deps.contains(&plugin_name.to_string()))
366 .map(|(name, _)| name.clone())
367 .collect();
368 Ok(dependents)
369 }
370 pub fn initialize_all(&self) -> SklResult<()> {
372 let plugin_names = self.list_plugins()?;
373 for name in plugin_names {
374 self.initialize_plugin(&name)?;
375 }
376 Ok(())
377 }
378 fn initialize_plugin(&self, name: &str) -> SklResult<()> {
380 let context = PluginContext {
381 registry_id: "main".to_string(),
382 working_dir: std::env::current_dir().unwrap_or_default(),
383 config: HashMap::new(),
384 available_apis: HashSet::new(),
385 };
386 let mut plugins = self.plugins.write().map_err(|_| {
387 SklearsError::InvalidOperation("Failed to acquire write lock for plugins".to_string())
388 })?;
389 if let Some(plugin) = plugins.get_mut(name) {
390 plugin.initialize(&context)?;
391 }
392 Ok(())
393 }
394 pub fn shutdown_all(&self) -> SklResult<()> {
396 let mut plugins = self.plugins.write().map_err(|_| {
397 SklearsError::InvalidOperation("Failed to acquire write lock for plugins".to_string())
398 })?;
399 for (_, plugin) in plugins.iter_mut() {
400 let _ = plugin.shutdown();
401 }
402 Ok(())
403 }
404}
405#[derive(Debug, Clone, Serialize, Deserialize)]
407pub struct ComponentSchema {
408 pub name: String,
410 pub required_parameters: Vec<ParameterSchema>,
412 pub optional_parameters: Vec<ParameterSchema>,
414 pub constraints: Vec<ParameterConstraint>,
416}
/// Bookkeeping record for a plugin library tracked by `PluginLoader`.
#[derive(Debug)]
struct PluginLibrary {
    /// Filesystem location of the library.
    path: PathBuf,
    /// Library handle. NOTE(review): a placeholder `String` rather than a real
    /// dynamic-library handle — confirm intent.
    handle: String,
    /// Names of the plugins the library provides.
    plugins: Vec<String>,
}
/// Settings controlling plugin discovery and execution.
#[derive(Clone, Debug)]
pub struct PluginConfig {
    /// Directories scanned when loading plugins.
    pub plugin_dirs: Vec<PathBuf>,
    /// Whether plugins should be loaded automatically (confirm trigger with callers).
    pub auto_load: bool,
    /// Whether plugins should run sandboxed.
    pub sandbox: bool,
    /// Execution-time limit for plugins.
    pub max_execution_time: std::time::Duration,
    /// Whether plugins should be validated.
    pub validate_plugins: bool,
}
441#[derive(Debug, Clone)]
443pub struct ExampleRegressor {
444 pub config: ComponentConfig,
445 pub learning_rate: f64,
446 pub fitted: bool,
447 pub coefficients: Option<Array1<f64>>,
448}
449impl ExampleRegressor {
450 pub fn new(config: ComponentConfig) -> Self {
451 let learning_rate = config
452 .parameters
453 .get("learning_rate")
454 .and_then(|v| match v {
455 ConfigValue::Float(f) => Some(*f),
456 _ => None,
457 })
458 .unwrap_or(0.01);
459 Self {
460 config,
461 learning_rate,
462 fitted: false,
463 coefficients: None,
464 }
465 }
466}
/// Component factory registered alongside the example transformer plugin.
#[derive(Debug)]
pub struct ExampleTransformerFactory;

impl ExampleTransformerFactory {
    /// Returns the (stateless) factory.
    pub fn new() -> Self {
        ExampleTransformerFactory
    }
}
475#[derive(Debug, Clone, Serialize, Deserialize)]
477pub enum ParameterType {
478 String {
480 min_length: Option<usize>,
481 max_length: Option<usize>,
482 },
483 Integer {
485 min_value: Option<i64>,
486 max_value: Option<i64>,
487 },
488 Float {
490 min_value: Option<f64>,
491 max_value: Option<f64>,
492 },
493 Boolean,
495 Enum { values: Vec<String> },
497 Array {
499 item_type: Box<ParameterType>,
500 min_items: Option<usize>,
501 max_items: Option<usize>,
502 },
503 Object { schema: ComponentSchema },
505}
506#[derive(Debug, Clone)]
508pub struct ExampleScaler {
509 pub config: ComponentConfig,
510 pub scale_factor: f64,
511 pub fitted: bool,
512}
513impl ExampleScaler {
514 pub fn new(config: ComponentConfig) -> Self {
515 let scale_factor = config
516 .parameters
517 .get("scale_factor")
518 .and_then(|v| match v {
519 ConfigValue::Float(f) => Some(*f),
520 _ => None,
521 })
522 .unwrap_or(1.0);
523 Self {
524 config,
525 scale_factor,
526 fitted: false,
527 }
528 }
529}
/// Environment handed to a plugin's `initialize` by the registry.
#[derive(Clone, Debug)]
pub struct PluginContext {
    /// Identifier of the registry performing the initialization.
    pub registry_id: String,
    /// Working directory made available to the plugin.
    pub working_dir: PathBuf,
    /// String key/value configuration for the plugin.
    pub config: HashMap<String, String>,
    /// Names of the host APIs available to the plugin.
    pub available_apis: HashSet<String>,
}
542#[derive(Debug)]
544pub struct ExampleEstimatorPlugin {
545 pub metadata: PluginMetadata,
546}
547impl ExampleEstimatorPlugin {
548 pub fn new() -> Self {
549 Self {
550 metadata: PluginMetadata {
551 name: "Example Estimator Plugin".to_string(),
552 version: "1.0.0".to_string(),
553 description: "Example estimator plugin for demonstration".to_string(),
554 author: "Sklears Team".to_string(),
555 license: "MIT".to_string(),
556 min_api_version: "1.0.0".to_string(),
557 dependencies: vec![],
558 capabilities: vec!["estimator".to_string()],
559 tags: vec!["example".to_string(), "estimator".to_string()],
560 documentation_url: None,
561 source_url: None,
562 },
563 }
564 }
565}
566#[derive(Debug, Clone, Serialize, Deserialize)]
568pub struct ParameterConstraint {
569 pub name: String,
571 pub expression: String,
573 pub description: String,
575}
576#[derive(Debug, Clone, Serialize, Deserialize)]
578pub struct PluginMetadata {
579 pub name: String,
581 pub version: String,
583 pub description: String,
585 pub author: String,
587 pub license: String,
589 pub min_api_version: String,
591 pub dependencies: Vec<String>,
593 pub capabilities: Vec<String>,
595 pub tags: Vec<String>,
597 pub documentation_url: Option<String>,
599 pub source_url: Option<String>,
601}