1use serde::{Deserialize, Serialize};
8use sklears_core::error::{Result as SklResult, SklearsError};
9use std::collections::{BTreeMap, HashMap};
10
11use super::workflow_definitions::{DataType, ParameterValue, StepType};
12
/// Catalogue of component definitions, keyed by component name.
///
/// `ComponentRegistry::new` (and therefore `Default`) pre-populates the
/// built-in components; more can be added with `register_component`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ComponentRegistry {
    /// Registered components. `register_component` inserts each definition
    /// under its own `name`; direct mutation of this public map bypasses the
    /// duplicate-name check.
    pub components: HashMap<String, ComponentDefinition>,
}
19
/// Full description of a registrable workflow component: identity, parameter
/// schema, input/output ports, performance profile and implementation details.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ComponentDefinition {
    /// Component name; used as the registry key and matched by name searches.
    pub name: String,
    /// Kind of workflow step this component implements.
    pub component_type: StepType,
    /// Human-readable description; also matched by `search_components`.
    pub description: String,
    /// Category used by `get_components_by_category`.
    pub category: ComponentCategory,
    /// Parameter schemas keyed by parameter name (sorted, since it is a `BTreeMap`).
    pub parameters: BTreeMap<String, ParameterSchema>,
    /// Input ports the component consumes.
    pub inputs: Vec<PortDefinition>,
    /// Output ports the component produces.
    pub outputs: Vec<PortDefinition>,
    /// Version string (the built-in components use "1.0.0").
    pub version: String,
    /// Whether the component is marked deprecated.
    pub deprecated: bool,
    /// Descriptive performance characteristics.
    pub performance: PerformanceCharacteristics,
    /// Descriptive implementation metadata (language, dependencies, license, ...).
    pub implementation: ImplementationDetails,
}
46
/// Schema for a single component parameter.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ParameterSchema {
    /// Declared type; a supplied value must have a matching `ParameterValue` variant.
    pub param_type: DataType,
    /// Default value, if any. `validate_parameters` does not fill it in.
    pub default: Option<ParameterValue>,
    /// Human-readable description of the parameter.
    pub description: String,
    /// Optional extra rule applied after the type check.
    pub validation: Option<ValidationRule>,
    /// If `true`, `validate_parameters` rejects input that omits this parameter.
    pub required: bool,
    /// Optional hints for rendering the parameter in a UI.
    pub ui_hints: Option<UIHints>,
}
63
/// A validation rule attached to a parameter schema.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ValidationRule {
    /// The kind of check to perform.
    pub rule_type: ValidationRuleType,
    /// Free-form string parameters for the rule. The registry's built-in
    /// validation does not read them.
    pub parameters: BTreeMap<String, String>,
    /// Identifier of a custom validator (presumably resolved elsewhere; the
    /// registry's built-in validation does not read it).
    pub custom_validator: Option<String>,
}
74
/// Kinds of validation rule.
///
/// `ComponentRegistry` enforces `Range`, `Length` and `Enum`. It does not
/// enforce `Pattern`, `CrossParameter` or `Custom`; values subject to those
/// rules are accepted unchanged.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ValidationRuleType {
    /// Inclusive numeric bounds; `None` leaves that side unbounded. Applied to
    /// `Float` and `Int` values.
    Range { min: Option<f64>, max: Option<f64> },
    /// Inclusive length bounds for string or array values.
    Length {
        /// Minimum allowed length, if any.
        min: Option<usize>,
        /// Maximum allowed length, if any.
        max: Option<usize>,
    },
    /// Pattern the value should match (presumably a regex; not enforced by the registry).
    Pattern(String),
    /// Allowed string values.
    Enum(Vec<String>),
    /// Constraint involving another parameter (not enforced by the registry).
    CrossParameter(String),
    /// Custom rule identifier (not enforced by the registry).
    Custom(String),
}
94
/// An input or output port of a component.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PortDefinition {
    /// Port name (the built-ins use e.g. "X", "y", "model").
    pub name: String,
    /// Data type carried by the port.
    pub data_type: DataType,
    /// Whether the port may be left unconnected.
    pub optional: bool,
    /// Human-readable description of the port.
    pub description: String,
    /// Optional free-form shape description, e.g. "[n_samples, n_features]".
    pub shape_constraints: Option<String>,
}
109
110#[derive(Debug, Clone, Serialize, Deserialize)]
112pub enum ComponentCategory {
113 DataIO,
115 Preprocessing,
117 FeatureEngineering,
119 ModelTraining,
121 ModelEvaluation,
123 Visualization,
125 Utilities,
127 Custom(String),
129}
130
/// Descriptive performance profile of a component. The values are metadata
/// only; nothing in this module measures or enforces them.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PerformanceCharacteristics {
    /// Time complexity as a free-form Big-O string, e.g. "O(n*m)".
    pub time_complexity: String,
    /// Space complexity as a free-form Big-O string, e.g. "O(m)".
    pub space_complexity: String,
    /// Whether the implementation can run in parallel.
    pub parallel_capable: bool,
    /// Whether the implementation can use GPU acceleration.
    pub gpu_accelerated: bool,
    /// Memory usage estimate.
    pub memory_usage: MemoryUsage,
    /// Scalability notes.
    pub scalability: ScalabilityInfo,
}
147
/// Rough memory-usage model for a component (descriptive metadata only).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MemoryUsage {
    /// Fixed overhead in megabytes.
    pub base_overhead_mb: f64,
    /// Growth factor with input size (unit not defined here — presumably MB per
    /// unit of data; confirm with consumers).
    pub scaling_factor: f64,
    /// Peak-to-steady-state memory ratio (presumably; e.g. 1.2 for `StandardScaler`).
    pub peak_multiplier: f64,
}
158
/// Scalability notes for a component (descriptive metadata only).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ScalabilityInfo {
    /// Largest supported data size, or `None` if no limit is stated. The unit
    /// is not defined here (NOTE(review): samples vs. elements — confirm).
    pub max_data_size: Option<usize>,
    /// How cost grows with input size.
    pub scaling_behavior: ScalingBehavior,
    /// Known bottlenecks, e.g. "Matrix inversion".
    pub bottlenecks: Vec<String>,
}
169
/// Qualitative growth of cost with input size.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ScalingBehavior {
    /// Linear growth.
    Linear,
    /// Logarithmic growth.
    Logarithmic,
    /// Polynomial growth; the payload is the degree (`LinearRegression` uses 2.0).
    Polynomial(f64),
    /// Exponential growth.
    Exponential,
    /// Constant cost.
    Constant,
}
184
/// Implementation metadata for a component (descriptive only).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ImplementationDetails {
    /// Implementation language, e.g. "Rust".
    pub language: String,
    /// Dependency names, e.g. "ndarray".
    pub dependencies: Vec<String>,
    /// Supported platforms, e.g. "Linux", "macOS", "Windows".
    pub platforms: Vec<String>,
    /// License identifier, e.g. "MIT".
    pub license: String,
    /// Optional source reference (presumably a URL or path; not interpreted here).
    pub source: Option<String>,
}
199
/// Presentation hints for rendering a parameter in a UI.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct UIHints {
    /// Widget used to edit the parameter.
    pub widget_type: WidgetType,
    /// Position among sibling parameters (presumably ascending; not interpreted here).
    pub display_order: Option<i32>,
    /// Optional group label, e.g. "Scaling Options".
    pub group: Option<String>,
    /// Optional help/tooltip text.
    pub help_text: Option<String>,
    /// Optional placeholder text for input widgets.
    pub placeholder: Option<String>,
}
214
/// UI widget kinds for parameter editing.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum WidgetType {
    /// Free-text input.
    TextInput,
    /// Numeric input.
    NumberInput,
    /// Boolean checkbox (used by the built-in boolean parameters).
    Checkbox,
    /// Drop-down list with the given options.
    Dropdown(Vec<String>),
    /// Slider with range and step.
    Slider { min: f64, max: f64, step: f64 },
    /// File chooser.
    FilePicker,
    /// Colour chooser.
    ColorPicker,
    /// Custom widget, identified by name.
    Custom(String),
}
235
236impl ComponentRegistry {
237 #[must_use]
239 pub fn new() -> Self {
240 let mut registry = Self {
241 components: HashMap::new(),
242 };
243
244 registry.register_default_components();
246 registry
247 }
248
249 pub fn register_component(&mut self, component: ComponentDefinition) -> SklResult<()> {
251 if self.components.contains_key(&component.name) {
252 return Err(SklearsError::InvalidInput(format!(
253 "Component '{}' already registered",
254 component.name
255 )));
256 }
257
258 self.components.insert(component.name.clone(), component);
259 Ok(())
260 }
261
262 #[must_use]
264 pub fn get_component(&self, name: &str) -> Option<&ComponentDefinition> {
265 self.components.get(name)
266 }
267
268 #[must_use]
270 pub fn has_component(&self, name: &str) -> bool {
271 self.components.contains_key(name)
272 }
273
274 #[must_use]
276 pub fn list_components(&self) -> Vec<&str> {
277 self.components
278 .keys()
279 .map(std::string::String::as_str)
280 .collect()
281 }
282
283 #[must_use]
285 pub fn get_components_by_category(
286 &self,
287 category: &ComponentCategory,
288 ) -> Vec<&ComponentDefinition> {
289 self.components
290 .values()
291 .filter(|comp| {
292 std::mem::discriminant(&comp.category) == std::mem::discriminant(category)
293 })
294 .collect()
295 }
296
297 #[must_use]
299 pub fn search_components(&self, query: &str) -> Vec<&ComponentDefinition> {
300 let query_lower = query.to_lowercase();
301 self.components
302 .values()
303 .filter(|comp| {
304 comp.name.to_lowercase().contains(&query_lower)
305 || comp.description.to_lowercase().contains(&query_lower)
306 })
307 .collect()
308 }
309
310 pub fn validate_parameters(
312 &self,
313 component_name: &str,
314 parameters: &BTreeMap<String, ParameterValue>,
315 ) -> SklResult<()> {
316 let component = self.get_component(component_name).ok_or_else(|| {
317 SklearsError::InvalidInput(format!("Component '{component_name}' not found"))
318 })?;
319
320 for (param_name, param_schema) in &component.parameters {
322 if param_schema.required && !parameters.contains_key(param_name) {
323 return Err(SklearsError::InvalidInput(format!(
324 "Required parameter '{param_name}' missing for component '{component_name}'"
325 )));
326 }
327 }
328
329 for (param_name, param_value) in parameters {
331 if let Some(param_schema) = component.parameters.get(param_name) {
332 self.validate_parameter_value(param_schema, param_value)?;
333 } else {
334 return Err(SklearsError::InvalidInput(format!(
335 "Unknown parameter '{param_name}' for component '{component_name}'"
336 )));
337 }
338 }
339
340 Ok(())
341 }
342
343 fn validate_parameter_value(
345 &self,
346 schema: &ParameterSchema,
347 value: &ParameterValue,
348 ) -> SklResult<()> {
349 let types_match = match (&schema.param_type, value) {
351 (DataType::Float32 | DataType::Float64, ParameterValue::Float(_)) => true,
352 (DataType::Int32 | DataType::Int64, ParameterValue::Int(_)) => true,
353 (DataType::Boolean, ParameterValue::Bool(_)) => true,
354 (DataType::String, ParameterValue::String(_)) => true,
355 (DataType::Array(_), ParameterValue::Array(_)) => true,
356 _ => false,
357 };
358
359 if !types_match {
360 return Err(SklearsError::InvalidInput(format!(
361 "Parameter type mismatch: expected {:?}, got {:?}",
362 schema.param_type, value
363 )));
364 }
365
366 if let Some(validation) = &schema.validation {
368 self.apply_validation_rule(validation, value)?;
369 }
370
371 Ok(())
372 }
373
374 fn apply_validation_rule(
376 &self,
377 rule: &ValidationRule,
378 value: &ParameterValue,
379 ) -> SklResult<()> {
380 match &rule.rule_type {
381 ValidationRuleType::Range { min, max } => {
382 if let ParameterValue::Float(val) = value {
383 if let Some(min_val) = min {
384 if *val < *min_val {
385 return Err(SklearsError::InvalidInput(format!(
386 "Value {val} is below minimum {min_val}"
387 )));
388 }
389 }
390 if let Some(max_val) = max {
391 if *val > *max_val {
392 return Err(SklearsError::InvalidInput(format!(
393 "Value {val} is above maximum {max_val}"
394 )));
395 }
396 }
397 } else if let ParameterValue::Int(val) = value {
398 if let Some(min_val) = min {
399 if (*val as f64) < *min_val {
400 return Err(SklearsError::InvalidInput(format!(
401 "Value {val} is below minimum {min_val}"
402 )));
403 }
404 }
405 if let Some(max_val) = max {
406 if (*val as f64) > *max_val {
407 return Err(SklearsError::InvalidInput(format!(
408 "Value {val} is above maximum {max_val}"
409 )));
410 }
411 }
412 }
413 }
414 ValidationRuleType::Length { min, max } => {
415 let length = match value {
416 ParameterValue::String(s) => s.len(),
417 ParameterValue::Array(arr) => arr.len(),
418 _ => return Ok(()), };
420
421 if let Some(min_len) = min {
422 if length < *min_len {
423 return Err(SklearsError::InvalidInput(format!(
424 "Length {length} is below minimum {min_len}"
425 )));
426 }
427 }
428 if let Some(max_len) = max {
429 if length > *max_len {
430 return Err(SklearsError::InvalidInput(format!(
431 "Length {length} is above maximum {max_len}"
432 )));
433 }
434 }
435 }
436 ValidationRuleType::Enum(allowed_values) => {
437 if let ParameterValue::String(val) = value {
438 if !allowed_values.contains(val) {
439 return Err(SklearsError::InvalidInput(format!(
440 "Value '{val}' is not in allowed values: {allowed_values:?}"
441 )));
442 }
443 }
444 }
445 _ => {
446 }
448 }
449
450 Ok(())
451 }
452
453 fn register_default_components(&mut self) {
455 let standard_scaler = ComponentDefinition {
457 name: "StandardScaler".to_string(),
458 component_type: StepType::Transformer,
459 description: "Standardize features by removing the mean and scaling to unit variance"
460 .to_string(),
461 category: ComponentCategory::Preprocessing,
462 parameters: {
463 let mut params = BTreeMap::new();
464 params.insert(
465 "with_mean".to_string(),
466 ParameterSchema {
467 param_type: DataType::Boolean,
468 default: Some(ParameterValue::Bool(true)),
469 description: "Center the data before scaling".to_string(),
470 validation: None,
471 required: false,
472 ui_hints: Some(UIHints {
473 widget_type: WidgetType::Checkbox,
474 display_order: Some(1),
475 group: Some("Scaling Options".to_string()),
476 help_text: Some(
477 "Whether to center the data before scaling".to_string(),
478 ),
479 placeholder: None,
480 }),
481 },
482 );
483 params.insert(
484 "with_std".to_string(),
485 ParameterSchema {
486 param_type: DataType::Boolean,
487 default: Some(ParameterValue::Bool(true)),
488 description: "Scale the data to unit variance".to_string(),
489 validation: None,
490 required: false,
491 ui_hints: Some(UIHints {
492 widget_type: WidgetType::Checkbox,
493 display_order: Some(2),
494 group: Some("Scaling Options".to_string()),
495 help_text: Some(
496 "Whether to scale the data to unit variance".to_string(),
497 ),
498 placeholder: None,
499 }),
500 },
501 );
502 params
503 },
504 inputs: vec![PortDefinition {
505 name: "X".to_string(),
506 data_type: DataType::Matrix(Box::new(DataType::Float64)),
507 optional: false,
508 description: "Input feature matrix".to_string(),
509 shape_constraints: Some("[n_samples, n_features]".to_string()),
510 }],
511 outputs: vec![PortDefinition {
512 name: "X_scaled".to_string(),
513 data_type: DataType::Matrix(Box::new(DataType::Float64)),
514 optional: false,
515 description: "Scaled feature matrix".to_string(),
516 shape_constraints: Some("[n_samples, n_features]".to_string()),
517 }],
518 version: "1.0.0".to_string(),
519 deprecated: false,
520 performance: PerformanceCharacteristics {
521 time_complexity: "O(n*m)".to_string(),
522 space_complexity: "O(m)".to_string(),
523 parallel_capable: true,
524 gpu_accelerated: false,
525 memory_usage: MemoryUsage {
526 base_overhead_mb: 1.0,
527 scaling_factor: 0.1,
528 peak_multiplier: 1.2,
529 },
530 scalability: ScalabilityInfo {
531 max_data_size: None,
532 scaling_behavior: ScalingBehavior::Linear,
533 bottlenecks: vec!["Memory bandwidth".to_string()],
534 },
535 },
536 implementation: ImplementationDetails {
537 language: "Rust".to_string(),
538 dependencies: vec!["ndarray".to_string(), "sklears-core".to_string()],
539 platforms: vec![
540 "Linux".to_string(),
541 "macOS".to_string(),
542 "Windows".to_string(),
543 ],
544 license: "MIT".to_string(),
545 source: None,
546 },
547 };
548
549 let linear_regression = ComponentDefinition {
551 name: "LinearRegression".to_string(),
552 component_type: StepType::Trainer,
553 description: "Ordinary least squares Linear Regression".to_string(),
554 category: ComponentCategory::ModelTraining,
555 parameters: {
556 let mut params = BTreeMap::new();
557 params.insert(
558 "fit_intercept".to_string(),
559 ParameterSchema {
560 param_type: DataType::Boolean,
561 default: Some(ParameterValue::Bool(true)),
562 description: "Whether to fit an intercept term".to_string(),
563 validation: None,
564 required: false,
565 ui_hints: Some(UIHints {
566 widget_type: WidgetType::Checkbox,
567 display_order: Some(1),
568 group: None,
569 help_text: Some(
570 "Whether to calculate the intercept for this model".to_string(),
571 ),
572 placeholder: None,
573 }),
574 },
575 );
576 params
577 },
578 inputs: vec![
579 PortDefinition {
580 name: "X".to_string(),
581 data_type: DataType::Matrix(Box::new(DataType::Float64)),
582 optional: false,
583 description: "Training data".to_string(),
584 shape_constraints: Some("[n_samples, n_features]".to_string()),
585 },
586 PortDefinition {
587 name: "y".to_string(),
588 data_type: DataType::Array(Box::new(DataType::Float64)),
589 optional: false,
590 description: "Target values".to_string(),
591 shape_constraints: Some("[n_samples]".to_string()),
592 },
593 ],
594 outputs: vec![PortDefinition {
595 name: "model".to_string(),
596 data_type: DataType::Custom("LinearRegressionModel".to_string()),
597 optional: false,
598 description: "Trained linear regression model".to_string(),
599 shape_constraints: None,
600 }],
601 version: "1.0.0".to_string(),
602 deprecated: false,
603 performance: PerformanceCharacteristics {
604 time_complexity: "O(n*m^2)".to_string(),
605 space_complexity: "O(m^2)".to_string(),
606 parallel_capable: true,
607 gpu_accelerated: true,
608 memory_usage: MemoryUsage {
609 base_overhead_mb: 2.0,
610 scaling_factor: 0.2,
611 peak_multiplier: 1.5,
612 },
613 scalability: ScalabilityInfo {
614 max_data_size: Some(1_000_000),
615 scaling_behavior: ScalingBehavior::Polynomial(2.0),
616 bottlenecks: vec!["Matrix inversion".to_string()],
617 },
618 },
619 implementation: ImplementationDetails {
620 language: "Rust".to_string(),
621 dependencies: vec!["ndarray".to_string(), "ndarray-linalg".to_string()],
622 platforms: vec![
623 "Linux".to_string(),
624 "macOS".to_string(),
625 "Windows".to_string(),
626 ],
627 license: "MIT".to_string(),
628 source: None,
629 },
630 };
631
632 let _ = self.register_component(standard_scaler);
634 let _ = self.register_component(linear_regression);
635 }
636
637 #[must_use]
639 pub fn get_component_summary(&self, name: &str) -> Option<ComponentSummary> {
640 self.get_component(name).map(|comp| ComponentSummary {
641 name: comp.name.clone(),
642 component_type: comp.component_type.clone(),
643 description: comp.description.clone(),
644 category: comp.category.clone(),
645 version: comp.version.clone(),
646 deprecated: comp.deprecated,
647 parameter_count: comp.parameters.len(),
648 input_count: comp.inputs.len(),
649 output_count: comp.outputs.len(),
650 })
651 }
652
653 #[must_use]
655 pub fn get_all_summaries(&self) -> Vec<ComponentSummary> {
656 self.components
657 .keys()
658 .filter_map(|name| self.get_component_summary(name))
659 .collect()
660 }
661}
662
/// Lightweight view of a `ComponentDefinition`, as returned by
/// `ComponentRegistry::get_component_summary`. It replaces the parameter and
/// port lists with their counts.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ComponentSummary {
    /// Component name.
    pub name: String,
    /// Kind of workflow step.
    pub component_type: StepType,
    /// Human-readable description.
    pub description: String,
    /// Component category.
    pub category: ComponentCategory,
    /// Version string.
    pub version: String,
    /// Whether the component is deprecated.
    pub deprecated: bool,
    /// Number of declared parameters.
    pub parameter_count: usize,
    /// Number of input ports.
    pub input_count: usize,
    /// Number of output ports.
    pub output_count: usize,
}
685
686impl Default for ComponentRegistry {
687 fn default() -> Self {
688 Self::new()
689 }
690}
691
/// Configuration for locating components outside the built-in registry.
/// Not used by `ComponentRegistry` in this file.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ComponentDiscovery {
    /// Registry identifiers to query (presumably URLs or names — confirm with users).
    pub registries: Vec<String>,
    /// Paths to search for components (format not defined here).
    pub search_paths: Vec<String>,
    /// Discovery behaviour settings.
    pub config: DiscoveryConfig,
}
702
/// Settings for component discovery.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DiscoveryConfig {
    /// Whether discovery should run automatically.
    pub auto_discovery: bool,
    /// Discovery timeout, in seconds (per the field name; not enforced here).
    pub timeout_sec: u64,
    /// Whether discovery results should be cached.
    pub cache_results: bool,
}
713
/// Descriptive catalogue metadata for a component.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ComponentMetadata {
    /// Stable identifier.
    pub id: String,
    /// Name shown to users.
    pub display_name: String,
    /// Optional icon reference (format not defined here).
    pub icon: Option<String>,
    /// Optional link to documentation.
    pub documentation_url: Option<String>,
    /// Usage examples (format not defined here).
    pub examples: Vec<String>,
    /// Search keywords.
    pub keywords: Vec<String>,
    /// Optional maintainer.
    pub maintainer: Option<String>,
    /// Creation timestamp as a string (NOTE(review): format not defined here — confirm).
    pub created_at: String,
    /// Last-update timestamp as a string (same caveat as `created_at`).
    pub updated_at: String,
}
736
/// Type-level signature of a component: its inputs, outputs and parameters.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ComponentSignature {
    /// Input signatures.
    pub inputs: Vec<TypeSignature>,
    /// Output signatures.
    pub outputs: Vec<TypeSignature>,
    /// Parameter signatures.
    pub parameters: Vec<ParameterSignature>,
}
747
/// Signature of a single input or output.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TypeSignature {
    /// Port name.
    pub name: String,
    /// Data type.
    pub data_type: DataType,
    /// Optional free-form shape description.
    pub shape: Option<String>,
    /// Free-form constraint descriptions (not interpreted in this file).
    pub constraints: Vec<String>,
}
760
/// Signature of a single parameter.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ParameterSignature {
    /// Parameter name.
    pub name: String,
    /// Parameter type.
    pub param_type: DataType,
    /// Whether the parameter must be supplied.
    pub required: bool,
    /// Free-form constraint descriptions (not interpreted in this file).
    pub constraints: Vec<String>,
}
773
/// Component kinds.
///
/// NOTE(review): `ComponentDefinition` uses `StepType`, not this enum, and
/// nothing else in this file references it.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum ComponentType {
    /// Loads data.
    DataLoader,
    /// Transforms data.
    Transformer,
    /// Trains a model.
    Trainer,
    /// Produces predictions.
    Predictor,
    /// Evaluates results.
    Evaluator,
    /// Visualises data or results.
    Visualizer,
    /// Miscellaneous utility.
    Utility,
    /// User-defined kind, identified by name.
    Custom(String),
}
794
/// A set of validation rules together with the context they run in.
/// Not used by `ComponentRegistry` in this file.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ComponentValidator {
    /// Rules to apply.
    pub rules: Vec<ValidationRule>,
    /// Optional custom validator identifier (not interpreted in this file).
    pub custom_validator: Option<String>,
    /// Context available during validation.
    pub context: ValidationContext,
}
805
/// Context information available while validating a component.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ValidationContext {
    /// Identifier of the workflow being validated, if any.
    pub workflow_id: Option<String>,
    /// Names of components available in this context.
    pub available_components: Vec<String>,
    /// Global parameter values keyed by name.
    pub global_params: BTreeMap<String, ParameterValue>,
}
816
/// Component version as a raw string plus its parsed parts.
///
/// NOTE(review): the parts look like semantic-versioning fields. Nothing here
/// keeps `version` consistent with `major`/`minor`/`patch`; callers must do so.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ComponentVersion {
    /// Full version string.
    pub version: String,
    /// Major version number.
    pub major: u32,
    /// Minor version number.
    pub minor: u32,
    /// Patch version number.
    pub patch: u32,
    /// Optional pre-release tag.
    pub pre_release: Option<String>,
    /// Optional build metadata.
    pub build: Option<String>,
}
833
/// Errors related to component registries.
///
/// NOTE(review): the `ComponentRegistry` methods in this file report errors
/// as `SklearsError::InvalidInput` and never return this type.
#[derive(Debug, Clone, Serialize, Deserialize, thiserror::Error)]
pub enum RegistryError {
    /// No component with the given name.
    #[error("Component not found: {0}")]
    ComponentNotFound(String),
    /// A component with the given name is already present.
    #[error("Component already exists: {0}")]
    ComponentExists(String),
    /// The component definition is malformed.
    #[error("Invalid component definition: {0}")]
    InvalidComponent(String),
    /// Incompatible component versions.
    #[error("Version conflict: {0}")]
    VersionConflict(String),
    /// A dependency could not be satisfied.
    #[error("Dependency error: {0}")]
    DependencyError(String),
    /// Validation failed.
    #[error("Validation error: {0}")]
    ValidationError(String),
    /// I/O failure. Carried as a message string, which keeps the type `Clone`
    /// and serialisable.
    #[error("IO error: {0}")]
    IoError(String),
    /// Network failure, carried as a message string.
    #[error("Network error: {0}")]
    NetworkError(String),
}
862
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_component_registry_creation() {
        let registry = ComponentRegistry::new();
        assert!(registry.has_component("StandardScaler"));
        assert!(registry.has_component("LinearRegression"));
        assert!(!registry.has_component("NonExistentComponent"));
    }

    #[test]
    fn test_get_component() {
        let registry = ComponentRegistry::new();
        let component = registry.get_component("StandardScaler");
        assert!(component.is_some());

        let comp = component.unwrap();
        assert_eq!(comp.name, "StandardScaler");
        assert_eq!(comp.component_type, StepType::Transformer);
    }

    #[test]
    fn test_validate_parameters() {
        let registry = ComponentRegistry::new();

        let mut params = BTreeMap::new();
        params.insert("with_mean".to_string(), ParameterValue::Bool(true));
        params.insert("with_std".to_string(), ParameterValue::Bool(false));

        let result = registry.validate_parameters("StandardScaler", &params);
        assert!(result.is_ok());

        // Parameters not declared by the component must be rejected.
        params.insert("invalid_param".to_string(), ParameterValue::Bool(true));
        let result = registry.validate_parameters("StandardScaler", &params);
        assert!(result.is_err());
    }

    #[test]
    fn test_validate_parameters_unknown_component() {
        let registry = ComponentRegistry::new();
        let params: BTreeMap<String, ParameterValue> = BTreeMap::new();
        assert!(registry
            .validate_parameters("NonExistentComponent", &params)
            .is_err());
    }

    #[test]
    fn test_validate_parameters_type_mismatch() {
        let registry = ComponentRegistry::new();
        let mut params = BTreeMap::new();
        // `with_mean` is declared Boolean, so a string value must fail the type check.
        params.insert(
            "with_mean".to_string(),
            ParameterValue::String("yes".to_string()),
        );
        assert!(registry
            .validate_parameters("StandardScaler", &params)
            .is_err());
    }

    #[test]
    fn test_search_components() {
        let registry = ComponentRegistry::new();
        let results = registry.search_components("scale");
        assert!(!results.is_empty());
        assert!(results.iter().any(|comp| comp.name == "StandardScaler"));
    }

    #[test]
    fn test_get_components_by_category() {
        let registry = ComponentRegistry::new();
        let preprocessing_components =
            registry.get_components_by_category(&ComponentCategory::Preprocessing);
        assert!(!preprocessing_components.is_empty());
        assert!(preprocessing_components
            .iter()
            .any(|comp| comp.name == "StandardScaler"));
    }

    #[test]
    fn test_component_summary() {
        let registry = ComponentRegistry::new();
        let summary = registry.get_component_summary("LinearRegression");
        assert!(summary.is_some());

        let sum = summary.unwrap();
        assert_eq!(sum.name, "LinearRegression");
        assert_eq!(sum.component_type, StepType::Trainer);
        assert!(!sum.deprecated);
    }
}