sklears_compose/workflow_language/
component_registry.rs

1//! Component Registry for Pipeline Components
2//!
3//! This module provides component registration and discovery capabilities for the
4//! workflow system, including component metadata management, parameter schemas,
5//! validation rules, and component lifecycle management.
6
7use serde::{Deserialize, Serialize};
8use sklears_core::error::{Result as SklResult, SklearsError};
9use std::collections::{BTreeMap, HashMap};
10
11use super::workflow_definitions::{DataType, ParameterValue, StepType};
12
13/// Component registry for available pipeline components
14#[derive(Debug, Clone, Serialize, Deserialize)]
15pub struct ComponentRegistry {
16    /// Registered components by name
17    pub components: HashMap<String, ComponentDefinition>,
18}
19
20/// Component definition
21#[derive(Debug, Clone, Serialize, Deserialize)]
22pub struct ComponentDefinition {
23    /// Component name
24    pub name: String,
25    /// Component type
26    pub component_type: StepType,
27    /// Component description
28    pub description: String,
29    /// Component category
30    pub category: ComponentCategory,
31    /// Input parameters schema
32    pub parameters: BTreeMap<String, ParameterSchema>,
33    /// Input ports
34    pub inputs: Vec<PortDefinition>,
35    /// Output ports
36    pub outputs: Vec<PortDefinition>,
37    /// Component version
38    pub version: String,
39    /// Whether component is deprecated
40    pub deprecated: bool,
41    /// Performance characteristics
42    pub performance: PerformanceCharacteristics,
43    /// Implementation details
44    pub implementation: ImplementationDetails,
45}
46
47/// Parameter schema
48#[derive(Debug, Clone, Serialize, Deserialize)]
49pub struct ParameterSchema {
50    /// Parameter data type
51    pub param_type: DataType,
52    /// Default value
53    pub default: Option<ParameterValue>,
54    /// Parameter description
55    pub description: String,
56    /// Validation rules
57    pub validation: Option<ValidationRule>,
58    /// Whether parameter is required
59    pub required: bool,
60    /// Parameter hints for UI
61    pub ui_hints: Option<UIHints>,
62}
63
64/// Validation rule
65#[derive(Debug, Clone, Serialize, Deserialize)]
66pub struct ValidationRule {
67    /// Rule type
68    pub rule_type: ValidationRuleType,
69    /// Rule parameters
70    pub parameters: BTreeMap<String, String>,
71    /// Custom validation function
72    pub custom_validator: Option<String>,
73}
74
75/// Types of validation rules
76#[derive(Debug, Clone, Serialize, Deserialize)]
77pub enum ValidationRuleType {
78    /// Range validation for numeric values
79    Range { min: Option<f64>, max: Option<f64> },
80    /// Length validation for strings/arrays
81    Length {
82        min: Option<usize>,
83        max: Option<usize>,
84    },
85    /// Pattern validation for strings
86    Pattern(String),
87    /// Enum validation (allowed values)
88    Enum(Vec<String>),
89    /// Cross-parameter validation
90    CrossParameter(String),
91    /// Custom validation function
92    Custom(String),
93}
94
95/// Port definition for inputs/outputs
96#[derive(Debug, Clone, Serialize, Deserialize)]
97pub struct PortDefinition {
98    /// Port name
99    pub name: String,
100    /// Data type
101    pub data_type: DataType,
102    /// Whether port is optional
103    pub optional: bool,
104    /// Port description
105    pub description: String,
106    /// Shape constraints
107    pub shape_constraints: Option<String>,
108}
109
110/// Component categories
111#[derive(Debug, Clone, Serialize, Deserialize)]
112pub enum ComponentCategory {
113    /// Data input/output
114    DataIO,
115    /// Preprocessing
116    Preprocessing,
117    /// Feature engineering
118    FeatureEngineering,
119    /// Model training
120    ModelTraining,
121    /// Model evaluation
122    ModelEvaluation,
123    /// Visualization
124    Visualization,
125    /// Utilities
126    Utilities,
127    /// Custom category
128    Custom(String),
129}
130
131/// Performance characteristics
132#[derive(Debug, Clone, Serialize, Deserialize)]
133pub struct PerformanceCharacteristics {
134    /// Time complexity (Big O notation)
135    pub time_complexity: String,
136    /// Space complexity (Big O notation)
137    pub space_complexity: String,
138    /// Whether component supports parallel execution
139    pub parallel_capable: bool,
140    /// Whether component supports GPU acceleration
141    pub gpu_accelerated: bool,
142    /// Typical memory usage
143    pub memory_usage: MemoryUsage,
144    /// Scalability characteristics
145    pub scalability: ScalabilityInfo,
146}
147
148/// Memory usage information
149#[derive(Debug, Clone, Serialize, Deserialize)]
150pub struct MemoryUsage {
151    /// Base memory overhead
152    pub base_overhead_mb: f64,
153    /// Memory scaling factor with data size
154    pub scaling_factor: f64,
155    /// Peak memory multiplier
156    pub peak_multiplier: f64,
157}
158
159/// Scalability information
160#[derive(Debug, Clone, Serialize, Deserialize)]
161pub struct ScalabilityInfo {
162    /// Maximum recommended data size
163    pub max_data_size: Option<usize>,
164    /// Scaling behavior
165    pub scaling_behavior: ScalingBehavior,
166    /// Bottleneck description
167    pub bottlenecks: Vec<String>,
168}
169
170/// Scaling behavior
171#[derive(Debug, Clone, Serialize, Deserialize)]
172pub enum ScalingBehavior {
173    /// Linear scaling with data size
174    Linear,
175    /// Logarithmic scaling
176    Logarithmic,
177    /// Polynomial scaling
178    Polynomial(f64),
179    /// Exponential scaling
180    Exponential,
181    /// Constant (doesn't scale with data)
182    Constant,
183}
184
185/// Implementation details
186#[derive(Debug, Clone, Serialize, Deserialize)]
187pub struct ImplementationDetails {
188    /// Implementation language
189    pub language: String,
190    /// Required dependencies
191    pub dependencies: Vec<String>,
192    /// Supported platforms
193    pub platforms: Vec<String>,
194    /// License information
195    pub license: String,
196    /// Source location
197    pub source: Option<String>,
198}
199
200/// UI hints for parameter display
201#[derive(Debug, Clone, Serialize, Deserialize)]
202pub struct UIHints {
203    /// Widget type for parameter input
204    pub widget_type: WidgetType,
205    /// Display order
206    pub display_order: Option<i32>,
207    /// Grouping information
208    pub group: Option<String>,
209    /// Help text
210    pub help_text: Option<String>,
211    /// Placeholder text
212    pub placeholder: Option<String>,
213}
214
215/// Widget types for UI
216#[derive(Debug, Clone, Serialize, Deserialize)]
217pub enum WidgetType {
218    /// Text input
219    TextInput,
220    /// Number input
221    NumberInput,
222    /// Checkbox
223    Checkbox,
224    /// Dropdown/select
225    Dropdown(Vec<String>),
226    /// Slider
227    Slider { min: f64, max: f64, step: f64 },
228    /// File picker
229    FilePicker,
230    /// Color picker
231    ColorPicker,
232    /// Custom widget
233    Custom(String),
234}
235
236impl ComponentRegistry {
237    /// Create a new component registry with default components
238    #[must_use]
239    pub fn new() -> Self {
240        let mut registry = Self {
241            components: HashMap::new(),
242        };
243
244        // Register default components
245        registry.register_default_components();
246        registry
247    }
248
249    /// Register a component
250    pub fn register_component(&mut self, component: ComponentDefinition) -> SklResult<()> {
251        if self.components.contains_key(&component.name) {
252            return Err(SklearsError::InvalidInput(format!(
253                "Component '{}' already registered",
254                component.name
255            )));
256        }
257
258        self.components.insert(component.name.clone(), component);
259        Ok(())
260    }
261
262    /// Get a component definition
263    #[must_use]
264    pub fn get_component(&self, name: &str) -> Option<&ComponentDefinition> {
265        self.components.get(name)
266    }
267
268    /// Check if a component exists
269    #[must_use]
270    pub fn has_component(&self, name: &str) -> bool {
271        self.components.contains_key(name)
272    }
273
274    /// List all available components
275    #[must_use]
276    pub fn list_components(&self) -> Vec<&str> {
277        self.components
278            .keys()
279            .map(std::string::String::as_str)
280            .collect()
281    }
282
283    /// Get components by category
284    #[must_use]
285    pub fn get_components_by_category(
286        &self,
287        category: &ComponentCategory,
288    ) -> Vec<&ComponentDefinition> {
289        self.components
290            .values()
291            .filter(|comp| {
292                std::mem::discriminant(&comp.category) == std::mem::discriminant(category)
293            })
294            .collect()
295    }
296
297    /// Search components by name or description
298    #[must_use]
299    pub fn search_components(&self, query: &str) -> Vec<&ComponentDefinition> {
300        let query_lower = query.to_lowercase();
301        self.components
302            .values()
303            .filter(|comp| {
304                comp.name.to_lowercase().contains(&query_lower)
305                    || comp.description.to_lowercase().contains(&query_lower)
306            })
307            .collect()
308    }
309
310    /// Validate component parameters
311    pub fn validate_parameters(
312        &self,
313        component_name: &str,
314        parameters: &BTreeMap<String, ParameterValue>,
315    ) -> SklResult<()> {
316        let component = self.get_component(component_name).ok_or_else(|| {
317            SklearsError::InvalidInput(format!("Component '{component_name}' not found"))
318        })?;
319
320        // Check required parameters
321        for (param_name, param_schema) in &component.parameters {
322            if param_schema.required && !parameters.contains_key(param_name) {
323                return Err(SklearsError::InvalidInput(format!(
324                    "Required parameter '{param_name}' missing for component '{component_name}'"
325                )));
326            }
327        }
328
329        // Validate parameter values
330        for (param_name, param_value) in parameters {
331            if let Some(param_schema) = component.parameters.get(param_name) {
332                self.validate_parameter_value(param_schema, param_value)?;
333            } else {
334                return Err(SklearsError::InvalidInput(format!(
335                    "Unknown parameter '{param_name}' for component '{component_name}'"
336                )));
337            }
338        }
339
340        Ok(())
341    }
342
343    /// Validate a single parameter value
344    fn validate_parameter_value(
345        &self,
346        schema: &ParameterSchema,
347        value: &ParameterValue,
348    ) -> SklResult<()> {
349        // Type validation
350        let types_match = match (&schema.param_type, value) {
351            (DataType::Float32 | DataType::Float64, ParameterValue::Float(_)) => true,
352            (DataType::Int32 | DataType::Int64, ParameterValue::Int(_)) => true,
353            (DataType::Boolean, ParameterValue::Bool(_)) => true,
354            (DataType::String, ParameterValue::String(_)) => true,
355            (DataType::Array(_), ParameterValue::Array(_)) => true,
356            _ => false,
357        };
358
359        if !types_match {
360            return Err(SklearsError::InvalidInput(format!(
361                "Parameter type mismatch: expected {:?}, got {:?}",
362                schema.param_type, value
363            )));
364        }
365
366        // Validation rules
367        if let Some(validation) = &schema.validation {
368            self.apply_validation_rule(validation, value)?;
369        }
370
371        Ok(())
372    }
373
374    /// Apply validation rule to a parameter value
375    fn apply_validation_rule(
376        &self,
377        rule: &ValidationRule,
378        value: &ParameterValue,
379    ) -> SklResult<()> {
380        match &rule.rule_type {
381            ValidationRuleType::Range { min, max } => {
382                if let ParameterValue::Float(val) = value {
383                    if let Some(min_val) = min {
384                        if *val < *min_val {
385                            return Err(SklearsError::InvalidInput(format!(
386                                "Value {val} is below minimum {min_val}"
387                            )));
388                        }
389                    }
390                    if let Some(max_val) = max {
391                        if *val > *max_val {
392                            return Err(SklearsError::InvalidInput(format!(
393                                "Value {val} is above maximum {max_val}"
394                            )));
395                        }
396                    }
397                } else if let ParameterValue::Int(val) = value {
398                    if let Some(min_val) = min {
399                        if (*val as f64) < *min_val {
400                            return Err(SklearsError::InvalidInput(format!(
401                                "Value {val} is below minimum {min_val}"
402                            )));
403                        }
404                    }
405                    if let Some(max_val) = max {
406                        if (*val as f64) > *max_val {
407                            return Err(SklearsError::InvalidInput(format!(
408                                "Value {val} is above maximum {max_val}"
409                            )));
410                        }
411                    }
412                }
413            }
414            ValidationRuleType::Length { min, max } => {
415                let length = match value {
416                    ParameterValue::String(s) => s.len(),
417                    ParameterValue::Array(arr) => arr.len(),
418                    _ => return Ok(()), // Skip length validation for non-string/array types
419                };
420
421                if let Some(min_len) = min {
422                    if length < *min_len {
423                        return Err(SklearsError::InvalidInput(format!(
424                            "Length {length} is below minimum {min_len}"
425                        )));
426                    }
427                }
428                if let Some(max_len) = max {
429                    if length > *max_len {
430                        return Err(SklearsError::InvalidInput(format!(
431                            "Length {length} is above maximum {max_len}"
432                        )));
433                    }
434                }
435            }
436            ValidationRuleType::Enum(allowed_values) => {
437                if let ParameterValue::String(val) = value {
438                    if !allowed_values.contains(val) {
439                        return Err(SklearsError::InvalidInput(format!(
440                            "Value '{val}' is not in allowed values: {allowed_values:?}"
441                        )));
442                    }
443                }
444            }
445            _ => {
446                // Other validation rules not implemented in this example
447            }
448        }
449
450        Ok(())
451    }
452
453    /// Register default components
454    fn register_default_components(&mut self) {
455        // StandardScaler
456        let standard_scaler = ComponentDefinition {
457            name: "StandardScaler".to_string(),
458            component_type: StepType::Transformer,
459            description: "Standardize features by removing the mean and scaling to unit variance"
460                .to_string(),
461            category: ComponentCategory::Preprocessing,
462            parameters: {
463                let mut params = BTreeMap::new();
464                params.insert(
465                    "with_mean".to_string(),
466                    ParameterSchema {
467                        param_type: DataType::Boolean,
468                        default: Some(ParameterValue::Bool(true)),
469                        description: "Center the data before scaling".to_string(),
470                        validation: None,
471                        required: false,
472                        ui_hints: Some(UIHints {
473                            widget_type: WidgetType::Checkbox,
474                            display_order: Some(1),
475                            group: Some("Scaling Options".to_string()),
476                            help_text: Some(
477                                "Whether to center the data before scaling".to_string(),
478                            ),
479                            placeholder: None,
480                        }),
481                    },
482                );
483                params.insert(
484                    "with_std".to_string(),
485                    ParameterSchema {
486                        param_type: DataType::Boolean,
487                        default: Some(ParameterValue::Bool(true)),
488                        description: "Scale the data to unit variance".to_string(),
489                        validation: None,
490                        required: false,
491                        ui_hints: Some(UIHints {
492                            widget_type: WidgetType::Checkbox,
493                            display_order: Some(2),
494                            group: Some("Scaling Options".to_string()),
495                            help_text: Some(
496                                "Whether to scale the data to unit variance".to_string(),
497                            ),
498                            placeholder: None,
499                        }),
500                    },
501                );
502                params
503            },
504            inputs: vec![PortDefinition {
505                name: "X".to_string(),
506                data_type: DataType::Matrix(Box::new(DataType::Float64)),
507                optional: false,
508                description: "Input feature matrix".to_string(),
509                shape_constraints: Some("[n_samples, n_features]".to_string()),
510            }],
511            outputs: vec![PortDefinition {
512                name: "X_scaled".to_string(),
513                data_type: DataType::Matrix(Box::new(DataType::Float64)),
514                optional: false,
515                description: "Scaled feature matrix".to_string(),
516                shape_constraints: Some("[n_samples, n_features]".to_string()),
517            }],
518            version: "1.0.0".to_string(),
519            deprecated: false,
520            performance: PerformanceCharacteristics {
521                time_complexity: "O(n*m)".to_string(),
522                space_complexity: "O(m)".to_string(),
523                parallel_capable: true,
524                gpu_accelerated: false,
525                memory_usage: MemoryUsage {
526                    base_overhead_mb: 1.0,
527                    scaling_factor: 0.1,
528                    peak_multiplier: 1.2,
529                },
530                scalability: ScalabilityInfo {
531                    max_data_size: None,
532                    scaling_behavior: ScalingBehavior::Linear,
533                    bottlenecks: vec!["Memory bandwidth".to_string()],
534                },
535            },
536            implementation: ImplementationDetails {
537                language: "Rust".to_string(),
538                dependencies: vec!["ndarray".to_string(), "sklears-core".to_string()],
539                platforms: vec![
540                    "Linux".to_string(),
541                    "macOS".to_string(),
542                    "Windows".to_string(),
543                ],
544                license: "MIT".to_string(),
545                source: None,
546            },
547        };
548
549        // LinearRegression
550        let linear_regression = ComponentDefinition {
551            name: "LinearRegression".to_string(),
552            component_type: StepType::Trainer,
553            description: "Ordinary least squares Linear Regression".to_string(),
554            category: ComponentCategory::ModelTraining,
555            parameters: {
556                let mut params = BTreeMap::new();
557                params.insert(
558                    "fit_intercept".to_string(),
559                    ParameterSchema {
560                        param_type: DataType::Boolean,
561                        default: Some(ParameterValue::Bool(true)),
562                        description: "Whether to fit an intercept term".to_string(),
563                        validation: None,
564                        required: false,
565                        ui_hints: Some(UIHints {
566                            widget_type: WidgetType::Checkbox,
567                            display_order: Some(1),
568                            group: None,
569                            help_text: Some(
570                                "Whether to calculate the intercept for this model".to_string(),
571                            ),
572                            placeholder: None,
573                        }),
574                    },
575                );
576                params
577            },
578            inputs: vec![
579                PortDefinition {
580                    name: "X".to_string(),
581                    data_type: DataType::Matrix(Box::new(DataType::Float64)),
582                    optional: false,
583                    description: "Training data".to_string(),
584                    shape_constraints: Some("[n_samples, n_features]".to_string()),
585                },
586                PortDefinition {
587                    name: "y".to_string(),
588                    data_type: DataType::Array(Box::new(DataType::Float64)),
589                    optional: false,
590                    description: "Target values".to_string(),
591                    shape_constraints: Some("[n_samples]".to_string()),
592                },
593            ],
594            outputs: vec![PortDefinition {
595                name: "model".to_string(),
596                data_type: DataType::Custom("LinearRegressionModel".to_string()),
597                optional: false,
598                description: "Trained linear regression model".to_string(),
599                shape_constraints: None,
600            }],
601            version: "1.0.0".to_string(),
602            deprecated: false,
603            performance: PerformanceCharacteristics {
604                time_complexity: "O(n*m^2)".to_string(),
605                space_complexity: "O(m^2)".to_string(),
606                parallel_capable: true,
607                gpu_accelerated: true,
608                memory_usage: MemoryUsage {
609                    base_overhead_mb: 2.0,
610                    scaling_factor: 0.2,
611                    peak_multiplier: 1.5,
612                },
613                scalability: ScalabilityInfo {
614                    max_data_size: Some(1_000_000),
615                    scaling_behavior: ScalingBehavior::Polynomial(2.0),
616                    bottlenecks: vec!["Matrix inversion".to_string()],
617                },
618            },
619            implementation: ImplementationDetails {
620                language: "Rust".to_string(),
621                dependencies: vec!["ndarray".to_string(), "ndarray-linalg".to_string()],
622                platforms: vec![
623                    "Linux".to_string(),
624                    "macOS".to_string(),
625                    "Windows".to_string(),
626                ],
627                license: "MIT".to_string(),
628                source: None,
629            },
630        };
631
632        // Register components
633        let _ = self.register_component(standard_scaler);
634        let _ = self.register_component(linear_regression);
635    }
636
637    /// Get component metadata summary
638    #[must_use]
639    pub fn get_component_summary(&self, name: &str) -> Option<ComponentSummary> {
640        self.get_component(name).map(|comp| ComponentSummary {
641            name: comp.name.clone(),
642            component_type: comp.component_type.clone(),
643            description: comp.description.clone(),
644            category: comp.category.clone(),
645            version: comp.version.clone(),
646            deprecated: comp.deprecated,
647            parameter_count: comp.parameters.len(),
648            input_count: comp.inputs.len(),
649            output_count: comp.outputs.len(),
650        })
651    }
652
653    /// Get all component summaries
654    #[must_use]
655    pub fn get_all_summaries(&self) -> Vec<ComponentSummary> {
656        self.components
657            .keys()
658            .filter_map(|name| self.get_component_summary(name))
659            .collect()
660    }
661}
662
663/// Component summary for quick overview
664#[derive(Debug, Clone, Serialize, Deserialize)]
665pub struct ComponentSummary {
666    /// Component name
667    pub name: String,
668    /// Component type
669    pub component_type: StepType,
670    /// Description
671    pub description: String,
672    /// Category
673    pub category: ComponentCategory,
674    /// Version
675    pub version: String,
676    /// Whether deprecated
677    pub deprecated: bool,
678    /// Number of parameters
679    pub parameter_count: usize,
680    /// Number of inputs
681    pub input_count: usize,
682    /// Number of outputs
683    pub output_count: usize,
684}
685
686impl Default for ComponentRegistry {
687    fn default() -> Self {
688        Self::new()
689    }
690}
691
692/// Component discovery service for finding available components
693#[derive(Debug, Clone, Serialize, Deserialize)]
694pub struct ComponentDiscovery {
695    /// Available component registries
696    pub registries: Vec<String>,
697    /// Search paths for components
698    pub search_paths: Vec<String>,
699    /// Discovery configuration
700    pub config: DiscoveryConfig,
701}
702
703/// Discovery configuration
704#[derive(Debug, Clone, Serialize, Deserialize)]
705pub struct DiscoveryConfig {
706    /// Enable automatic discovery
707    pub auto_discovery: bool,
708    /// Discovery timeout in seconds
709    pub timeout_sec: u64,
710    /// Cache discovery results
711    pub cache_results: bool,
712}
713
714/// Component metadata for extended information
715#[derive(Debug, Clone, Serialize, Deserialize)]
716pub struct ComponentMetadata {
717    /// Component identifier
718    pub id: String,
719    /// Display name
720    pub display_name: String,
721    /// Icon or image reference
722    pub icon: Option<String>,
723    /// Documentation URL
724    pub documentation_url: Option<String>,
725    /// Example usage
726    pub examples: Vec<String>,
727    /// Keywords for search
728    pub keywords: Vec<String>,
729    /// Maintainer information
730    pub maintainer: Option<String>,
731    /// Creation timestamp
732    pub created_at: String,
733    /// Last updated timestamp
734    pub updated_at: String,
735}
736
737/// Component signature for type checking
738#[derive(Debug, Clone, Serialize, Deserialize)]
739pub struct ComponentSignature {
740    /// Input signature
741    pub inputs: Vec<TypeSignature>,
742    /// Output signature
743    pub outputs: Vec<TypeSignature>,
744    /// Parameter signature
745    pub parameters: Vec<ParameterSignature>,
746}
747
748/// Type signature for inputs/outputs
749#[derive(Debug, Clone, Serialize, Deserialize)]
750pub struct TypeSignature {
751    /// Type name
752    pub name: String,
753    /// Data type
754    pub data_type: DataType,
755    /// Shape information
756    pub shape: Option<String>,
757    /// Type constraints
758    pub constraints: Vec<String>,
759}
760
761/// Parameter signature for type checking
762#[derive(Debug, Clone, Serialize, Deserialize)]
763pub struct ParameterSignature {
764    /// Parameter name
765    pub name: String,
766    /// Parameter type
767    pub param_type: DataType,
768    /// Whether required
769    pub required: bool,
770    /// Type constraints
771    pub constraints: Vec<String>,
772}
773
774/// Component type classification
775#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
776pub enum ComponentType {
777    /// Data loading component
778    DataLoader,
779    /// Data transformation component
780    Transformer,
781    /// Model training component
782    Trainer,
783    /// Model inference component
784    Predictor,
785    /// Evaluation component
786    Evaluator,
787    /// Visualization component
788    Visualizer,
789    /// Utility component
790    Utility,
791    /// Custom component type
792    Custom(String),
793}
794
795/// Component validator for validation logic
796#[derive(Debug, Clone, Serialize, Deserialize)]
797pub struct ComponentValidator {
798    /// Validation rules
799    pub rules: Vec<ValidationRule>,
800    /// Custom validation function
801    pub custom_validator: Option<String>,
802    /// Validation context
803    pub context: ValidationContext,
804}
805
806/// Validation context
807#[derive(Debug, Clone, Serialize, Deserialize)]
808pub struct ValidationContext {
809    /// Current workflow context
810    pub workflow_id: Option<String>,
811    /// Available component instances
812    pub available_components: Vec<String>,
813    /// Global parameters
814    pub global_params: BTreeMap<String, ParameterValue>,
815}
816
817/// Component version information
818#[derive(Debug, Clone, Serialize, Deserialize)]
819pub struct ComponentVersion {
820    /// Version string
821    pub version: String,
822    /// Major version number
823    pub major: u32,
824    /// Minor version number
825    pub minor: u32,
826    /// Patch version number
827    pub patch: u32,
828    /// Pre-release identifier
829    pub pre_release: Option<String>,
830    /// Build metadata
831    pub build: Option<String>,
832}
833
834/// Registry error types
835#[derive(Debug, Clone, Serialize, Deserialize, thiserror::Error)]
836pub enum RegistryError {
837    /// Component not found
838    #[error("Component not found: {0}")]
839    ComponentNotFound(String),
840    /// Component already exists
841    #[error("Component already exists: {0}")]
842    ComponentExists(String),
843    /// Invalid component definition
844    #[error("Invalid component definition: {0}")]
845    InvalidComponent(String),
846    /// Version conflict
847    #[error("Version conflict: {0}")]
848    VersionConflict(String),
849    /// Dependency error
850    #[error("Dependency error: {0}")]
851    DependencyError(String),
852    /// Validation error
853    #[error("Validation error: {0}")]
854    ValidationError(String),
855    /// IO error
856    #[error("IO error: {0}")]
857    IoError(String),
858    /// Network error
859    #[error("Network error: {0}")]
860    NetworkError(String),
861}
862
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;

    /// A fresh registry knows the built-ins and nothing else that was asked for.
    #[test]
    fn test_component_registry_creation() {
        let registry = ComponentRegistry::new();

        let expectations = [
            ("StandardScaler", true),
            ("LinearRegression", true),
            ("NonExistentComponent", false),
        ];
        for (name, expected) in expectations.iter() {
            assert_eq!(registry.has_component(name), *expected);
        }
    }

    /// Looking up a built-in returns its full definition.
    #[test]
    fn test_get_component() {
        let registry = ComponentRegistry::new();

        let comp = registry
            .get_component("StandardScaler")
            .expect("StandardScaler is registered by default");

        assert_eq!(comp.name, "StandardScaler");
        assert_eq!(comp.component_type, StepType::Transformer);
    }

    /// Declared parameters validate; an undeclared one is rejected.
    #[test]
    fn test_validate_parameters() {
        let registry = ComponentRegistry::new();

        let mut params: BTreeMap<String, ParameterValue> = BTreeMap::new();
        params.insert("with_mean".to_string(), ParameterValue::Bool(true));
        params.insert("with_std".to_string(), ParameterValue::Bool(false));
        assert!(registry
            .validate_parameters("StandardScaler", &params)
            .is_ok());

        // An undeclared parameter must make validation fail.
        params.insert("invalid_param".to_string(), ParameterValue::Bool(true));
        assert!(registry
            .validate_parameters("StandardScaler", &params)
            .is_err());
    }

    /// Searching is case-insensitive and matches on substrings.
    #[test]
    fn test_search_components() {
        let registry = ComponentRegistry::new();

        let hits = registry.search_components("scale");

        assert!(!hits.is_empty());
        assert!(hits.iter().any(|comp| comp.name == "StandardScaler"));
    }

    /// Filtering by category returns the components filed under it.
    #[test]
    fn test_get_components_by_category() {
        let registry = ComponentRegistry::new();

        let wanted = ComponentCategory::Preprocessing;
        let preprocessing = registry.get_components_by_category(&wanted);

        assert!(!preprocessing.is_empty());
        assert!(preprocessing
            .iter()
            .any(|comp| comp.name == "StandardScaler"));
    }

    /// A summary mirrors the identifying fields of the definition.
    #[test]
    fn test_component_summary() {
        let registry = ComponentRegistry::new();

        let summary = registry
            .get_component_summary("LinearRegression")
            .expect("LinearRegression is registered by default");

        assert_eq!(summary.name, "LinearRegression");
        assert_eq!(summary.component_type, StepType::Trainer);
        assert!(!summary.deprecated);
    }
}