sklears_feature_selection/
plugin.rs

1//! Modular Plugin Architecture for Feature Selection
2//!
3//! This module provides a flexible plugin system that allows users to register
4//! custom feature selection methods, metrics, and transformations. The architecture
//! uses trait objects, dynamic dispatch, and `Any`-based downcasting (Rust has no runtime
//! reflection) for maximum extensibility.
6
7use scirs2_core::ndarray::{Array1, Array2, ArrayView1, ArrayView2};
8use sklears_core::error::{Result as SklResult, SklearsError};
9use std::any::Any;
10use std::collections::HashMap;
11use std::sync::{Arc, RwLock};
12
/// Shorthand for `sklears_core`'s result type; every fallible API in this module returns it.
type Result<T> = SklResult<T>;
14
/// Core trait for feature selection plugins.
///
/// A plugin is registered in a [`PluginRegistry`] under [`FeatureSelectionPlugin::name`].
/// The registry hands out independent copies via [`FeatureSelectionPlugin::clone_plugin`],
/// so each consumer fits its own instance. Implementations must be `Send + Sync`; the
/// registry stores them behind `RwLock`s.
pub trait FeatureSelectionPlugin: Send + Sync {
    /// Get the plugin name. `PluginRegistry` uses it as the registration key.
    fn name(&self) -> &str;

    /// Get the plugin version string.
    fn version(&self) -> &str;

    /// Get a human-readable description of what the plugin does.
    fn description(&self) -> &str;

    /// Get plugin metadata (capabilities, complexity estimates, authorship).
    fn metadata(&self) -> PluginMetadata;

    /// Fit the selector on training data `X` (rows are samples, columns are features)
    /// and target `y`. Unsupervised selectors may ignore `y`.
    fn fit(&mut self, X: ArrayView2<f64>, y: ArrayView1<f64>) -> Result<()>;

    /// Reduce `X` to the selected feature columns.
    ///
    /// The built-in implementations return an error when called before `fit`.
    fn transform(&self, X: ArrayView2<f64>) -> Result<Array2<f64>>;

    /// Get the indices of the selected features (columns of the data passed to `fit`).
    fn selected_features(&self) -> Result<Vec<usize>>;

    /// Get the per-feature scores or importances computed during `fit`.
    fn feature_scores(&self) -> Result<Array1<f64>>;

    /// Check if the plugin has been fitted.
    fn is_fitted(&self) -> bool;

    /// Return the plugin as `&dyn Any`, so callers can downcast to the concrete plugin
    /// type. Implementations typically just return `self`.
    fn as_any(&self) -> &dyn Any;

    /// Clone the plugin into a new box. `PluginRegistry::get_plugin` uses this, because
    /// `Box<dyn FeatureSelectionPlugin>` itself cannot be cloned.
    fn clone_plugin(&self) -> Box<dyn FeatureSelectionPlugin>;
}
50
51/// Trait for custom scoring functions
52pub trait ScoringFunction: Send + Sync {
53    /// Get the scoring function name
54    fn name(&self) -> &str;
55
56    /// Compute score for a single feature
57    fn score(&self, feature: ArrayView1<f64>, target: ArrayView1<f64>) -> Result<f64>;
58
59    /// Compute scores for all features in parallel
60    fn score_features(&self, X: ArrayView2<f64>, y: ArrayView1<f64>) -> Result<Array1<f64>> {
61        let mut scores = Array1::zeros(X.ncols());
62        for (i, score) in scores.iter_mut().enumerate() {
63            *score = self.score(X.column(i), y)?;
64        }
65        Ok(scores)
66    }
67
68    /// Clone the scoring function
69    fn clone_scoring(&self) -> Box<dyn ScoringFunction>;
70}
71
/// Trait for custom transformation functions applied to the whole feature matrix.
///
/// Transformations are registered in a [`PluginRegistry`] and can be used as
/// `PipelineStep::Transformation` steps of a [`PluginPipeline`].
pub trait TransformationFunction: Send + Sync {
    /// Get the transformation name. `PluginRegistry` uses it as the registration key.
    fn name(&self) -> &str;

    /// Apply the transformation to `X` and return the transformed matrix.
    fn transform(&self, X: ArrayView2<f64>) -> Result<Array2<f64>>;

    /// Number of output columns for `input_features` input columns, or `None` if it
    /// cannot be known in advance.
    fn output_features(&self, input_features: usize) -> Option<usize>;

    /// Clone the transformation into a new box. `PluginRegistry` hands out such clones.
    fn clone_transform(&self) -> Box<dyn TransformationFunction>;
}
86
87/// Plugin metadata
88#[derive(Debug, Clone)]
89pub struct PluginMetadata {
90    pub author: String,
91    pub license: String,
92    pub categories: Vec<String>,
93    pub tags: Vec<String>,
94    pub min_samples: Option<usize>,
95    pub max_features: Option<usize>,
96    pub supports_sparse: bool,
97    pub supports_multiclass: bool,
98    pub supports_regression: bool,
99    pub computational_complexity: ComputationalComplexity,
100    pub memory_complexity: MemoryComplexity,
101}
102
103impl Default for PluginMetadata {
104    fn default() -> Self {
105        Self {
106            author: String::new(),
107            license: "MIT".to_string(),
108            categories: Vec::new(),
109            tags: Vec::new(),
110            min_samples: None,
111            max_features: None,
112            supports_sparse: false,
113            supports_multiclass: true,
114            supports_regression: true,
115            computational_complexity: ComputationalComplexity::default(),
116            memory_complexity: MemoryComplexity::default(),
117        }
118    }
119}
120
/// Rough asymptotic time-complexity class that a plugin reports in its [`PluginMetadata`].
///
/// Defaults to [`ComputationalComplexity::Linear`].
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub enum ComputationalComplexity {
    /// Constant time, O(1).
    Constant,
    /// Linear time, O(n).
    #[default]
    Linear,
    /// Quadratic time, O(n^2).
    Quadratic,
    /// Cubic time, O(n^3).
    Cubic,
    /// Exponential time.
    Exponential,
    /// Any other class, described as free text (e.g. `"O(n log n)"`).
    Custom(String),
}
142
/// Rough asymptotic memory-complexity class that a plugin reports in its [`PluginMetadata`].
///
/// Defaults to [`MemoryComplexity::Linear`].
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub enum MemoryComplexity {
    /// Constant memory, O(1).
    Constant,
    /// Linear memory, O(n).
    #[default]
    Linear,
    /// Quadratic memory, O(n^2).
    Quadratic,
    /// Any other class, described as free text.
    Custom(String),
}
160
161/// Plugin registry for managing feature selection plugins
162pub struct PluginRegistry {
163    plugins: RwLock<HashMap<String, Box<dyn FeatureSelectionPlugin>>>,
164    scoring_functions: RwLock<HashMap<String, Box<dyn ScoringFunction>>>,
165    transformations: RwLock<HashMap<String, Box<dyn TransformationFunction>>>,
166    middleware: RwLock<Vec<Box<dyn PluginMiddleware>>>,
167}
168
169impl Default for PluginRegistry {
170    fn default() -> Self {
171        Self::new()
172    }
173}
174
175impl PluginRegistry {
176    /// Create a new plugin registry
177    pub fn new() -> Self {
178        Self {
179            plugins: RwLock::new(HashMap::new()),
180            scoring_functions: RwLock::new(HashMap::new()),
181            transformations: RwLock::new(HashMap::new()),
182            middleware: RwLock::new(Vec::new()),
183        }
184    }
185
186    /// Register a feature selection plugin
187    pub fn register_plugin(&self, plugin: Box<dyn FeatureSelectionPlugin>) -> Result<()> {
188        let name = plugin.name().to_string();
189        let mut plugins = self
190            .plugins
191            .write()
192            .map_err(|_| SklearsError::FitError("Failed to acquire write lock".to_string()))?;
193
194        if plugins.contains_key(&name) {
195            return Err(SklearsError::InvalidInput(format!(
196                "Plugin '{}' is already registered",
197                name
198            )));
199        }
200
201        plugins.insert(name, plugin);
202        Ok(())
203    }
204
205    /// Register a custom scoring function
206    pub fn register_scoring_function(&self, function: Box<dyn ScoringFunction>) -> Result<()> {
207        let name = function.name().to_string();
208        let mut functions = self
209            .scoring_functions
210            .write()
211            .map_err(|_| SklearsError::FitError("Failed to acquire write lock".to_string()))?;
212
213        functions.insert(name, function);
214        Ok(())
215    }
216
217    /// Register a custom transformation function
218    pub fn register_transformation(
219        &self,
220        transformation: Box<dyn TransformationFunction>,
221    ) -> Result<()> {
222        let name = transformation.name().to_string();
223        let mut transformations = self
224            .transformations
225            .write()
226            .map_err(|_| SklearsError::FitError("Failed to acquire write lock".to_string()))?;
227
228        transformations.insert(name, transformation);
229        Ok(())
230    }
231
232    /// Register middleware
233    pub fn register_middleware(&self, middleware: Box<dyn PluginMiddleware>) -> Result<()> {
234        let mut middleware_vec = self
235            .middleware
236            .write()
237            .map_err(|_| SklearsError::FitError("Failed to acquire write lock".to_string()))?;
238
239        middleware_vec.push(middleware);
240        Ok(())
241    }
242
243    /// Get a plugin by name
244    pub fn get_plugin(&self, name: &str) -> Result<Box<dyn FeatureSelectionPlugin>> {
245        let plugins = self
246            .plugins
247            .read()
248            .map_err(|_| SklearsError::FitError("Failed to acquire read lock".to_string()))?;
249
250        plugins
251            .get(name)
252            .map(|plugin| plugin.clone_plugin())
253            .ok_or_else(|| SklearsError::InvalidInput(format!("Plugin '{}' not found", name)))
254    }
255
256    /// Get a scoring function by name
257    pub fn get_scoring_function(&self, name: &str) -> Result<Box<dyn ScoringFunction>> {
258        let functions = self
259            .scoring_functions
260            .read()
261            .map_err(|_| SklearsError::FitError("Failed to acquire read lock".to_string()))?;
262
263        functions
264            .get(name)
265            .map(|func| func.clone_scoring())
266            .ok_or_else(|| {
267                SklearsError::InvalidInput(format!("Scoring function '{}' not found", name))
268            })
269    }
270
271    /// Get a transformation by name
272    pub fn get_transformation(&self, name: &str) -> Result<Box<dyn TransformationFunction>> {
273        let transformations = self
274            .transformations
275            .read()
276            .map_err(|_| SklearsError::FitError("Failed to acquire read lock".to_string()))?;
277
278        transformations
279            .get(name)
280            .map(|transform| transform.clone_transform())
281            .ok_or_else(|| {
282                SklearsError::InvalidInput(format!("Transformation '{}' not found", name))
283            })
284    }
285
286    /// List all registered plugins
287    pub fn list_plugins(&self) -> Result<Vec<String>> {
288        let plugins = self
289            .plugins
290            .read()
291            .map_err(|_| SklearsError::FitError("Failed to acquire read lock".to_string()))?;
292
293        Ok(plugins.keys().cloned().collect())
294    }
295
296    /// Get plugin metadata
297    pub fn get_plugin_metadata(&self, name: &str) -> Result<PluginMetadata> {
298        let plugins = self
299            .plugins
300            .read()
301            .map_err(|_| SklearsError::FitError("Failed to acquire read lock".to_string()))?;
302
303        plugins
304            .get(name)
305            .map(|plugin| plugin.metadata())
306            .ok_or_else(|| SklearsError::InvalidInput(format!("Plugin '{}' not found", name)))
307    }
308
309    /// Execute middleware before plugin operations
310    pub fn execute_before_middleware(
311        &self,
312        plugin_name: &str,
313        context: &PluginContext,
314    ) -> Result<()> {
315        let middleware = self
316            .middleware
317            .read()
318            .map_err(|_| SklearsError::FitError("Failed to acquire read lock".to_string()))?;
319
320        for mw in middleware.iter() {
321            mw.before_execution(plugin_name, context)?;
322        }
323
324        Ok(())
325    }
326
327    /// Execute middleware after plugin operations
328    pub fn execute_after_middleware(
329        &self,
330        plugin_name: &str,
331        context: &PluginContext,
332        result: &PluginResult,
333    ) -> Result<()> {
334        let middleware = self
335            .middleware
336            .read()
337            .map_err(|_| SklearsError::FitError("Failed to acquire read lock".to_string()))?;
338
339        for mw in middleware.iter() {
340            mw.after_execution(plugin_name, context, result)?;
341        }
342
343        Ok(())
344    }
345}
346
// Global plugin registry instance (disabled because `lazy_static` is not a dependency;
// `std::sync::OnceLock` could provide the same global without adding one).
348// lazy_static::lazy_static! {
349//     pub static ref GLOBAL_REGISTRY: PluginRegistry = PluginRegistry::new();
350// }
351
/// Plugin middleware trait for cross-cutting concerns (logging, timing, ...).
///
/// `PluginRegistry::execute_before_middleware` / `execute_after_middleware` call each
/// registered middleware in registration order. The first `Err` stops the remaining
/// middleware and is propagated. The registry holds its middleware read lock during
/// these calls.
pub trait PluginMiddleware: Send + Sync {
    /// Called before a plugin runs.
    fn before_execution(&self, plugin_name: &str, context: &PluginContext) -> Result<()>;

    /// Called after a plugin runs. `result` describes the outcome.
    fn after_execution(
        &self,
        plugin_name: &str,
        context: &PluginContext,
        result: &PluginResult,
    ) -> Result<()>;
}
365
/// Context passed to middleware describing the operation about to run.
#[derive(Debug, Clone)]
pub struct PluginContext {
    /// Operation name (`PluginPipeline` passes `"plugin_execution"`).
    pub operation: String,
    /// Shape `(rows, columns)` of the input data.
    pub data_shape: (usize, usize),
    /// String parameters for the operation (`PluginPipeline` passes the step's `config`).
    pub parameters: HashMap<String, String>,
    /// When the operation started.
    pub start_time: std::time::Instant,
}
374
/// Outcome of a plugin operation, passed to middleware after it runs.
#[derive(Debug, Clone)]
pub struct PluginResult {
    /// Whether the operation succeeded.
    pub success: bool,
    /// Wall-clock time the operation took.
    pub execution_time: std::time::Duration,
    /// Indices of the features the plugin selected.
    pub selected_features: Vec<usize>,
    /// Error description when the operation failed.
    pub error_message: Option<String>,
}
383
/// Composable plugin pipeline: an ordered list of steps resolved by name against a
/// shared [`PluginRegistry`] when the pipeline is executed.
pub struct PluginPipeline {
    steps: Vec<PipelineStep>,
    registry: Arc<PluginRegistry>,
}

/// A single step of a [`PluginPipeline`]. `name` is looked up in the pipeline's registry.
#[derive(Clone)]
pub enum PipelineStep {
    /// Fit a registered feature-selection plugin on the current data and keep only its
    /// selected features. `config` is forwarded to middleware as `PluginContext::parameters`.
    Plugin {
        name: String,
        config: HashMap<String, String>,
    },
    /// Apply a registered transformation to the current data.
    /// `config` is currently not used by `PluginPipeline::execute`.
    Transformation {
        name: String,
        config: HashMap<String, String>,
    },
    /// Score the current features with a registered scoring function. The scores are
    /// not retained, and `config` is currently not used by `PluginPipeline::execute`.
    Scoring {
        name: String,
        config: HashMap<String, String>,
    },
}

impl Default for PluginPipeline {
    /// Same as [`PluginPipeline::new`].
    fn default() -> Self {
        Self::new()
    }
}
416
417impl PluginPipeline {
418    /// Create a new plugin pipeline
419    pub fn new() -> Self {
420        Self {
421            steps: Vec::new(),
422            registry: Arc::new(PluginRegistry::new()),
423        }
424    }
425
426    /// Create a pipeline with a custom registry
427    pub fn with_registry(registry: Arc<PluginRegistry>) -> Self {
428        Self {
429            steps: Vec::new(),
430            registry,
431        }
432    }
433
434    /// Add a plugin step
435    pub fn add_plugin(mut self, name: String, config: HashMap<String, String>) -> Self {
436        self.steps.push(PipelineStep::Plugin { name, config });
437        self
438    }
439
440    /// Add a transformation step
441    pub fn add_transformation(mut self, name: String, config: HashMap<String, String>) -> Self {
442        self.steps
443            .push(PipelineStep::Transformation { name, config });
444        self
445    }
446
447    /// Add a scoring step
448    pub fn add_scoring(mut self, name: String, config: HashMap<String, String>) -> Self {
449        self.steps.push(PipelineStep::Scoring { name, config });
450        self
451    }
452
453    /// Execute the pipeline
454    pub fn execute(&self, X: ArrayView2<f64>, y: ArrayView1<f64>) -> Result<PipelineResult> {
455        let start_time = std::time::Instant::now();
456        let mut current_X = X.to_owned();
457        let mut step_results = Vec::new();
458
459        for (step_index, step) in self.steps.iter().enumerate() {
460            let step_start = std::time::Instant::now();
461
462            match step {
463                PipelineStep::Plugin { name, config } => {
464                    let context = PluginContext {
465                        operation: "plugin_execution".to_string(),
466                        data_shape: (current_X.nrows(), current_X.ncols()),
467                        parameters: config.clone(),
468                        start_time: step_start,
469                    };
470
471                    self.registry.execute_before_middleware(name, &context)?;
472
473                    let mut plugin = self.registry.get_plugin(name)?;
474                    plugin.fit(current_X.view(), y.view())?;
475                    current_X = plugin.transform(current_X.view())?;
476                    let selected_features = plugin.selected_features()?;
477
478                    let result = PluginResult {
479                        success: true,
480                        execution_time: step_start.elapsed(),
481                        selected_features: selected_features.clone(),
482                        error_message: None,
483                    };
484
485                    self.registry
486                        .execute_after_middleware(name, &context, &result)?;
487
488                    step_results.push(StepResult {
489                        step_index,
490                        step_type: "Plugin".to_string(),
491                        step_name: name.clone(),
492                        execution_time: step_start.elapsed(),
493                        input_features: context.data_shape.1,
494                        output_features: current_X.ncols(),
495                        selected_features,
496                    });
497                }
498                PipelineStep::Transformation { name, config: _ } => {
499                    let transformation = self.registry.get_transformation(name)?;
500                    let input_features = current_X.ncols();
501                    current_X = transformation.transform(current_X.view())?;
502
503                    step_results.push(StepResult {
504                        step_index,
505                        step_type: "Transformation".to_string(),
506                        step_name: name.clone(),
507                        execution_time: step_start.elapsed(),
508                        input_features,
509                        output_features: current_X.ncols(),
510                        selected_features: (0..current_X.ncols()).collect(),
511                    });
512                }
513                PipelineStep::Scoring { name, config: _ } => {
514                    let scoring_function = self.registry.get_scoring_function(name)?;
515                    let _scores = scoring_function.score_features(current_X.view(), y.view())?;
516
517                    step_results.push(StepResult {
518                        step_index,
519                        step_type: "Scoring".to_string(),
520                        step_name: name.clone(),
521                        execution_time: step_start.elapsed(),
522                        input_features: current_X.ncols(),
523                        output_features: current_X.ncols(),
524                        selected_features: (0..current_X.ncols()).collect(),
525                    });
526                }
527            }
528        }
529
530        Ok(PipelineResult {
531            final_data: current_X.clone(),
532            step_results,
533            total_execution_time: start_time.elapsed(),
534            original_features: X.ncols(),
535            final_features: current_X.ncols(),
536        })
537    }
538}
539
/// Result of executing a [`PluginPipeline`].
#[derive(Debug, Clone)]
pub struct PipelineResult {
    /// Data matrix after the last step.
    pub final_data: Array2<f64>,
    /// One entry per executed step, in execution order.
    pub step_results: Vec<StepResult>,
    /// Wall-clock time for the whole pipeline.
    pub total_execution_time: std::time::Duration,
    /// Number of columns in the input data.
    pub original_features: usize,
    /// Number of columns in `final_data`.
    pub final_features: usize,
}

/// Result of an individual pipeline step.
#[derive(Debug, Clone)]
pub struct StepResult {
    /// Position of the step in the pipeline.
    pub step_index: usize,
    /// `"Plugin"`, `"Transformation"` or `"Scoring"`.
    pub step_type: String,
    /// Registry name of the component the step ran.
    pub step_name: String,
    /// Wall-clock time for this step.
    pub execution_time: std::time::Duration,
    /// Number of columns entering the step.
    pub input_features: usize,
    /// Number of columns leaving the step.
    pub output_features: usize,
    /// For plugin steps, the plugin's selected column indices. For other steps, all
    /// output column indices `0..output_features`.
    pub selected_features: Vec<usize>,
}
561
562/// Built-in plugins
563pub mod builtin {
564    use super::*;
565
566    /// Variance threshold plugin
567    #[derive(Debug, Clone)]
568    pub struct VarianceThresholdPlugin {
569        threshold: f64,
570        feature_variances: Option<Array1<f64>>,
571        selected_indices: Option<Vec<usize>>,
572        fitted: bool,
573    }
574
575    impl VarianceThresholdPlugin {
576        pub fn new(threshold: f64) -> Self {
577            Self {
578                threshold,
579                feature_variances: None,
580                selected_indices: None,
581                fitted: false,
582            }
583        }
584    }
585
586    impl FeatureSelectionPlugin for VarianceThresholdPlugin {
587        fn name(&self) -> &str {
588            "variance_threshold"
589        }
590
591        fn version(&self) -> &str {
592            "1.0.0"
593        }
594
595        fn description(&self) -> &str {
596            "Removes features with variance below threshold"
597        }
598
599        fn metadata(&self) -> PluginMetadata {
600            PluginMetadata {
601                author: "Sklears Team".to_string(),
602                license: "MIT".to_string(),
603                categories: vec!["filter".to_string(), "univariate".to_string()],
604                tags: vec!["variance".to_string(), "threshold".to_string()],
605                min_samples: None,
606                max_features: None,
607                supports_sparse: true,
608                supports_multiclass: true,
609                supports_regression: true,
610                computational_complexity: ComputationalComplexity::Linear,
611                memory_complexity: MemoryComplexity::Linear,
612            }
613        }
614
615        fn fit(&mut self, X: ArrayView2<f64>, _y: ArrayView1<f64>) -> Result<()> {
616            let mut variances = Array1::zeros(X.ncols());
617            for (i, var) in variances.iter_mut().enumerate() {
618                *var = X.column(i).var(1.0);
619            }
620
621            let selected_indices: Vec<usize> = variances
622                .iter()
623                .enumerate()
624                .filter_map(|(i, &var)| if var > self.threshold { Some(i) } else { None })
625                .collect();
626
627            self.feature_variances = Some(variances);
628            self.selected_indices = Some(selected_indices);
629            self.fitted = true;
630
631            Ok(())
632        }
633
634        fn transform(&self, X: ArrayView2<f64>) -> Result<Array2<f64>> {
635            if !self.fitted {
636                return Err(SklearsError::FitError("Plugin not fitted".to_string()));
637            }
638
639            let selected_indices = self.selected_indices.as_ref().unwrap();
640            if selected_indices.is_empty() {
641                return Err(SklearsError::InvalidInput(
642                    "No features selected".to_string(),
643                ));
644            }
645
646            let mut result = Array2::zeros((X.nrows(), selected_indices.len()));
647            for (new_col, &old_col) in selected_indices.iter().enumerate() {
648                for row in 0..X.nrows() {
649                    result[[row, new_col]] = X[[row, old_col]];
650                }
651            }
652
653            Ok(result)
654        }
655
656        fn selected_features(&self) -> Result<Vec<usize>> {
657            self.selected_indices
658                .clone()
659                .ok_or_else(|| SklearsError::FitError("Plugin not fitted".to_string()))
660        }
661
662        fn feature_scores(&self) -> Result<Array1<f64>> {
663            self.feature_variances
664                .clone()
665                .ok_or_else(|| SklearsError::FitError("Plugin not fitted".to_string()))
666        }
667
668        fn is_fitted(&self) -> bool {
669            self.fitted
670        }
671
672        fn as_any(&self) -> &dyn Any {
673            self
674        }
675
676        fn clone_plugin(&self) -> Box<dyn FeatureSelectionPlugin> {
677            Box::new(self.clone())
678        }
679    }
680
681    /// Correlation-based scoring function
682    #[derive(Debug, Clone)]
683    pub struct CorrelationScoring;
684
685    impl ScoringFunction for CorrelationScoring {
686        fn name(&self) -> &str {
687            "correlation"
688        }
689
690        fn score(&self, feature: ArrayView1<f64>, target: ArrayView1<f64>) -> Result<f64> {
691            if feature.len() != target.len() {
692                return Err(SklearsError::InvalidInput(
693                    "Feature and target length mismatch".to_string(),
694                ));
695            }
696
697            let correlation = crate::performance::SIMDStats::correlation_auto(feature, target);
698            Ok(correlation.abs())
699        }
700
701        fn clone_scoring(&self) -> Box<dyn ScoringFunction> {
702            Box::new(self.clone())
703        }
704    }
705
706    /// Normalization transformation
707    #[derive(Debug, Clone)]
708    pub struct NormalizationTransform;
709
710    impl TransformationFunction for NormalizationTransform {
711        fn name(&self) -> &str {
712            "normalization"
713        }
714
715        fn transform(&self, X: ArrayView2<f64>) -> Result<Array2<f64>> {
716            let mut result = X.to_owned();
717
718            for col in 0..result.ncols() {
719                let column = result.column(col);
720                let mean = column.mean().unwrap_or(0.0);
721                let std = column.var(1.0).sqrt();
722
723                if std > 1e-10 {
724                    for row in 0..result.nrows() {
725                        result[[row, col]] = (result[[row, col]] - mean) / std;
726                    }
727                }
728            }
729
730            Ok(result)
731        }
732
733        fn output_features(&self, input_features: usize) -> Option<usize> {
734            Some(input_features)
735        }
736
737        fn clone_transform(&self) -> Box<dyn TransformationFunction> {
738            Box::new(self.clone())
739        }
740    }
741}
742
743/// Logging middleware
744#[derive(Debug, Clone)]
745pub struct LoggingMiddleware {
746    log_level: LogLevel,
747}
748
749#[derive(Debug, Clone)]
750pub enum LogLevel {
751    /// Debug
752    Debug,
753    /// Info
754    Info,
755    /// Warning
756    Warning,
757    /// Error
758    Error,
759}
760
761impl LoggingMiddleware {
762    pub fn new(log_level: LogLevel) -> Self {
763        Self { log_level }
764    }
765}
766
767impl PluginMiddleware for LoggingMiddleware {
768    fn before_execution(&self, plugin_name: &str, context: &PluginContext) -> Result<()> {
769        match self.log_level {
770            LogLevel::Debug | LogLevel::Info => {
771                println!(
772                    "Executing plugin '{}' with operation '{}'",
773                    plugin_name, context.operation
774                );
775                println!("  Data shape: {:?}", context.data_shape);
776            }
777            _ => {}
778        }
779        Ok(())
780    }
781
782    fn after_execution(
783        &self,
784        plugin_name: &str,
785        _context: &PluginContext,
786        result: &PluginResult,
787    ) -> Result<()> {
788        match self.log_level {
789            LogLevel::Debug | LogLevel::Info => {
790                println!(
791                    "Plugin '{}' completed in {:?}",
792                    plugin_name, result.execution_time
793                );
794                println!("  Selected {} features", result.selected_features.len());
795                if let Some(ref error) = result.error_message {
796                    println!("  Error: {}", error);
797                }
798            }
799            _ => {}
800        }
801        Ok(())
802    }
803}
804
805/// Performance monitoring middleware
806#[derive(Debug)]
807pub struct PerformanceMiddleware {
808    metrics: Arc<RwLock<HashMap<String, PerformanceMetrics>>>,
809}
810
811#[derive(Debug, Clone)]
812pub struct PerformanceMetrics {
813    pub total_executions: usize,
814    pub total_time: std::time::Duration,
815    pub average_time: std::time::Duration,
816    pub min_time: std::time::Duration,
817    pub max_time: std::time::Duration,
818}
819
820impl Default for PerformanceMiddleware {
821    fn default() -> Self {
822        Self::new()
823    }
824}
825
826impl PerformanceMiddleware {
827    pub fn new() -> Self {
828        Self {
829            metrics: Arc::new(RwLock::new(HashMap::new())),
830        }
831    }
832
833    pub fn get_metrics(&self) -> Result<HashMap<String, PerformanceMetrics>> {
834        let metrics = self
835            .metrics
836            .read()
837            .map_err(|_| SklearsError::FitError("Failed to acquire read lock".to_string()))?;
838        Ok(metrics.clone())
839    }
840}
841
842impl PluginMiddleware for PerformanceMiddleware {
843    fn before_execution(&self, _plugin_name: &str, _context: &PluginContext) -> Result<()> {
844        // Nothing to do before execution
845        Ok(())
846    }
847
848    fn after_execution(
849        &self,
850        plugin_name: &str,
851        _context: &PluginContext,
852        result: &PluginResult,
853    ) -> Result<()> {
854        let mut metrics = self
855            .metrics
856            .write()
857            .map_err(|_| SklearsError::FitError("Failed to acquire write lock".to_string()))?;
858
859        let entry = metrics
860            .entry(plugin_name.to_string())
861            .or_insert_with(|| PerformanceMetrics {
862                total_executions: 0,
863                total_time: std::time::Duration::from_secs(0),
864                average_time: std::time::Duration::from_secs(0),
865                min_time: std::time::Duration::from_secs(u64::MAX),
866                max_time: std::time::Duration::from_secs(0),
867            });
868
869        entry.total_executions += 1;
870        entry.total_time += result.execution_time;
871        entry.average_time = entry.total_time / entry.total_executions as u32;
872        entry.min_time = entry.min_time.min(result.execution_time);
873        entry.max_time = entry.max_time.max(result.execution_time);
874
875        Ok(())
876    }
877}
878
/// Helper macro for easy plugin registration.
///
/// `register_plugin!(registry, plugin)` expands to
/// `registry.register_plugin(Box::new(plugin))?;`. It therefore boxes the plugin for
/// you, and can only be used inside a function that can propagate the registry's
/// error with `?`.
#[macro_export]
macro_rules! register_plugin {
    ($registry:expr, $plugin:expr) => {
        $registry.register_plugin(Box::new($plugin))?;
    };
}
886
/// Helper macro for creating plugin pipelines.
///
/// Each step is written `plugin(name, config)`, `transform(name, config)` or
/// `scoring(name, config)`. These map to `add_plugin`, `add_transformation` and
/// `add_scoring` respectively, in the order given. Any other step keyword is a
/// compile-time error.
///
/// The pipeline is created with `PluginPipeline::new()`, so `PluginPipeline` must be
/// in scope at the call site.
#[macro_export]
macro_rules! plugin_pipeline {
    // Internal rules: dispatch a single step on its keyword at compile time.
    (@step $pipeline:expr, plugin, $name:expr, $config:expr) => {
        $pipeline.add_plugin($name.to_string(), $config)
    };
    (@step $pipeline:expr, transform, $name:expr, $config:expr) => {
        $pipeline.add_transformation($name.to_string(), $config)
    };
    (@step $pipeline:expr, scoring, $name:expr, $config:expr) => {
        $pipeline.add_scoring($name.to_string(), $config)
    };
    ($($step_type:ident($name:expr, $config:expr)),+ $(,)?) => {
        {
            let pipeline = PluginPipeline::new();
            $(
                let pipeline = $crate::plugin_pipeline!(@step pipeline, $step_type, $name, $config);
            )+
            pipeline
        }
    };
}
905
// Test matrices follow the ML convention of an upper-case `X`.
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::builtin::*;
    use super::*;
    use scirs2_core::ndarray::array;

    #[test]
    fn test_plugin_registry() -> Result<()> {
        let registry = PluginRegistry::new();

        registry.register_plugin(Box::new(VarianceThresholdPlugin::new(0.1)))?;
        registry.register_scoring_function(Box::new(CorrelationScoring))?;
        registry.register_transformation(Box::new(NormalizationTransform))?;

        let retrieved_plugin = registry.get_plugin("variance_threshold")?;
        assert_eq!(retrieved_plugin.name(), "variance_threshold");
        // A retrieved plugin is an independent, unfitted copy.
        assert!(!retrieved_plugin.is_fitted());

        let retrieved_scoring = registry.get_scoring_function("correlation")?;
        assert_eq!(retrieved_scoring.name(), "correlation");

        let retrieved_transform = registry.get_transformation("normalization")?;
        assert_eq!(retrieved_transform.name(), "normalization");

        assert_eq!(
            registry.list_plugins()?,
            vec!["variance_threshold".to_string()]
        );
        let metadata = registry.get_plugin_metadata("variance_threshold")?;
        assert_eq!(metadata.author, "Sklears Team");

        Ok(())
    }

    #[test]
    fn test_registry_rejects_duplicate_plugins() -> Result<()> {
        let registry = PluginRegistry::new();
        registry.register_plugin(Box::new(VarianceThresholdPlugin::new(0.1)))?;

        let duplicate = registry.register_plugin(Box::new(VarianceThresholdPlugin::new(0.5)));
        assert!(duplicate.is_err());
        assert_eq!(registry.list_plugins()?.len(), 1);

        Ok(())
    }

    #[test]
    fn test_registry_unknown_names_are_errors() {
        let registry = PluginRegistry::new();
        assert!(registry.get_plugin("missing").is_err());
        assert!(registry.get_scoring_function("missing").is_err());
        assert!(registry.get_transformation("missing").is_err());
        assert!(registry.get_plugin_metadata("missing").is_err());
    }

    #[test]
    fn test_plugin_execution() -> Result<()> {
        let X = array![
            [1.0, 2.0, 0.0],
            [2.0, 4.0, 0.0],
            [3.0, 6.0, 0.0],
            [4.0, 8.0, 0.0],
        ];
        let y = array![1.0, 2.0, 3.0, 4.0];

        let mut plugin = VarianceThresholdPlugin::new(0.1);
        plugin.fit(X.view(), y.view())?;
        assert!(plugin.is_fitted());

        // Column 2 is constant (variance 0) and must be dropped; the other two vary.
        let selected_features = plugin.selected_features()?;
        assert_eq!(selected_features, vec![0, 1]);

        let transformed = plugin.transform(X.view())?;
        assert_eq!(
            transformed,
            array![[1.0, 2.0], [2.0, 4.0], [3.0, 6.0], [4.0, 8.0]]
        );

        let scores = plugin.feature_scores()?;
        assert_eq!(scores.len(), 3);
        assert_eq!(scores[2], 0.0);

        Ok(())
    }

    #[test]
    fn test_transform_before_fit_fails() {
        let plugin = VarianceThresholdPlugin::new(0.1);
        let X = array![[1.0, 2.0], [3.0, 4.0]];

        assert!(!plugin.is_fitted());
        assert!(plugin.transform(X.view()).is_err());
        assert!(plugin.selected_features().is_err());
        assert!(plugin.feature_scores().is_err());
    }

    #[test]
    fn test_plugin_pipeline() -> Result<()> {
        let registry = Arc::new(PluginRegistry::new());

        registry.register_plugin(Box::new(VarianceThresholdPlugin::new(0.1)))?;
        registry.register_transformation(Box::new(NormalizationTransform))?;
        registry.register_scoring_function(Box::new(CorrelationScoring))?;

        let pipeline = PluginPipeline::with_registry(registry)
            .add_transformation("normalization".to_string(), HashMap::new())
            .add_plugin("variance_threshold".to_string(), HashMap::new());

        let X = array![
            [1.0, 2.0, 3.0],
            [2.0, 4.0, 6.0],
            [3.0, 6.0, 9.0],
            [4.0, 8.0, 12.0],
        ];
        let y = array![1.0, 2.0, 3.0, 4.0];

        let result = pipeline.execute(X.view(), y.view())?;

        // After standardisation every column has unit sample variance, so all survive.
        assert_eq!(result.original_features, 3);
        assert_eq!(result.final_features, 3);
        assert_eq!(result.final_data.ncols(), 3);
        assert_eq!(result.step_results.len(), 2);
        assert_eq!(result.step_results[0].step_type, "Transformation");
        assert_eq!(result.step_results[1].step_type, "Plugin");
        assert_eq!(result.step_results[1].selected_features, vec![0, 1, 2]);

        Ok(())
    }

    #[test]
    fn test_middleware() -> Result<()> {
        let registry = PluginRegistry::new();

        registry.register_middleware(Box::new(LoggingMiddleware::new(LogLevel::Info)))?;
        registry.register_middleware(Box::new(PerformanceMiddleware::new()))?;

        registry.register_plugin(Box::new(VarianceThresholdPlugin::new(0.1)))?;

        let context = PluginContext {
            operation: "test".to_string(),
            data_shape: (100, 10),
            parameters: HashMap::new(),
            start_time: std::time::Instant::now(),
        };

        registry.execute_before_middleware("variance_threshold", &context)?;

        let result = PluginResult {
            success: true,
            execution_time: std::time::Duration::from_millis(10),
            selected_features: vec![0, 1, 2],
            error_message: None,
        };

        registry.execute_after_middleware("variance_threshold", &context, &result)?;

        Ok(())
    }

    #[test]
    fn test_macro_pipeline() -> Result<()> {
        let pipeline = plugin_pipeline! {
            transform("normalization", HashMap::new()),
            plugin("variance_threshold", HashMap::new()),
            scoring("correlation", HashMap::new()),
        };

        assert_eq!(pipeline.steps.len(), 3);
        assert!(matches!(
            &pipeline.steps[0],
            PipelineStep::Transformation { name, .. } if name == "normalization"
        ));
        assert!(matches!(
            &pipeline.steps[1],
            PipelineStep::Plugin { name, .. } if name == "variance_threshold"
        ));
        assert!(matches!(
            &pipeline.steps[2],
            PipelineStep::Scoring { name, .. } if name == "correlation"
        ));

        Ok(())
    }
}