1use scirs2_core::ndarray::{Array1, Array2, ArrayView1, ArrayView2};
8use sklears_core::error::{Result as SklResult, SklearsError};
9use std::any::Any;
10use std::collections::HashMap;
11use std::sync::{Arc, RwLock};
12
/// File-local shorthand for the `sklears_core` result type (`SklResult`).
type Result<T> = SklResult<T>;
14
/// Interface for feature-selection plugins that can be stored in a
/// [`PluginRegistry`] and run by a [`PluginPipeline`].
///
/// Implementors must be `Send + Sync` because boxed plugins are kept behind
/// the registry's `RwLock`.
pub trait FeatureSelectionPlugin: Send + Sync {
    /// Name of the plugin; [`PluginRegistry::register_plugin`] uses it as the
    /// lookup key.
    fn name(&self) -> &str;

    /// Version string of the plugin.
    fn version(&self) -> &str;

    /// Human-readable description of the plugin.
    fn description(&self) -> &str;

    /// Descriptive metadata about the plugin (see [`PluginMetadata`]).
    fn metadata(&self) -> PluginMetadata;

    /// Fits the plugin on the feature matrix `X` and the target `y`.
    fn fit(&mut self, X: ArrayView2<f64>, y: ArrayView1<f64>) -> Result<()>;

    /// Produces the transformed matrix for `X`.
    fn transform(&self, X: ArrayView2<f64>) -> Result<Array2<f64>>;

    /// Indices of the features chosen by the plugin.
    fn selected_features(&self) -> Result<Vec<usize>>;

    /// Scores computed by the plugin for the features.
    fn feature_scores(&self) -> Result<Array1<f64>>;

    /// Whether the plugin has been fitted.
    fn is_fitted(&self) -> bool;

    /// Exposes the plugin as `&dyn Any` so callers can attempt a downcast to
    /// the concrete type.
    fn as_any(&self) -> &dyn Any;

    /// Returns a boxed copy of the plugin; [`PluginRegistry::get_plugin`]
    /// hands out such copies instead of the stored instance.
    fn clone_plugin(&self) -> Box<dyn FeatureSelectionPlugin>;
}
50
51pub trait ScoringFunction: Send + Sync {
53 fn name(&self) -> &str;
55
56 fn score(&self, feature: ArrayView1<f64>, target: ArrayView1<f64>) -> Result<f64>;
58
59 fn score_features(&self, X: ArrayView2<f64>, y: ArrayView1<f64>) -> Result<Array1<f64>> {
61 let mut scores = Array1::zeros(X.ncols());
62 for (i, score) in scores.iter_mut().enumerate() {
63 *score = self.score(X.column(i), y)?;
64 }
65 Ok(scores)
66 }
67
68 fn clone_scoring(&self) -> Box<dyn ScoringFunction>;
70}
71
/// Interface for transformations applied to a feature matrix.
///
/// Implementors must be `Send + Sync` because boxed transformations are kept
/// behind the registry's `RwLock`.
pub trait TransformationFunction: Send + Sync {
    /// Name of the transformation; used as the registry lookup key.
    fn name(&self) -> &str;

    /// Applies the transformation to `X` and returns the resulting matrix.
    fn transform(&self, X: ArrayView2<f64>) -> Result<Array2<f64>>;

    /// Number of output features for `input_features` input features, or
    /// `None` (presumably when it cannot be stated up front; the meaning of
    /// `None` is left to implementors).
    fn output_features(&self, input_features: usize) -> Option<usize>;

    /// Returns a boxed copy of this transformation.
    fn clone_transform(&self) -> Box<dyn TransformationFunction>;
}
86
/// Descriptive metadata returned by [`FeatureSelectionPlugin::metadata`].
///
/// The registry code in this file only passes these values through
/// ([`PluginRegistry::get_plugin_metadata`]); it does not enforce any of them.
#[derive(Debug, Clone)]
pub struct PluginMetadata {
    /// Plugin author.
    pub author: String,
    /// License identifier (the `Default` impl uses `"MIT"`).
    pub license: String,
    /// Category labels for the plugin.
    pub categories: Vec<String>,
    /// Free-form tags for the plugin.
    pub tags: Vec<String>,
    /// Declared minimum number of samples, if any.
    pub min_samples: Option<usize>,
    /// Declared maximum number of features, if any.
    pub max_features: Option<usize>,
    /// Whether the plugin declares support for sparse input.
    pub supports_sparse: bool,
    /// Whether the plugin declares support for multiclass targets.
    pub supports_multiclass: bool,
    /// Whether the plugin declares support for regression targets.
    pub supports_regression: bool,
    /// Declared computational complexity class.
    pub computational_complexity: ComputationalComplexity,
    /// Declared memory complexity class.
    pub memory_complexity: MemoryComplexity,
}
102
103impl Default for PluginMetadata {
104 fn default() -> Self {
105 Self {
106 author: String::new(),
107 license: "MIT".to_string(),
108 categories: Vec::new(),
109 tags: Vec::new(),
110 min_samples: None,
111 max_features: None,
112 supports_sparse: false,
113 supports_multiclass: true,
114 supports_regression: true,
115 computational_complexity: ComputationalComplexity::default(),
116 memory_complexity: MemoryComplexity::default(),
117 }
118 }
119}
120
/// Computational complexity class a plugin declares in its [`PluginMetadata`].
#[derive(Debug, Clone)]
pub enum ComputationalComplexity {
    /// Constant complexity.
    Constant,
    /// Linear complexity (the `Default` value).
    Linear,
    /// Quadratic complexity.
    Quadratic,
    /// Cubic complexity.
    Cubic,
    /// Exponential complexity.
    Exponential,
    /// Any other complexity, described by a free-form string.
    Custom(String),
}
136
137impl Default for ComputationalComplexity {
138 fn default() -> Self {
139 Self::Linear
140 }
141}
142
/// Memory complexity class a plugin declares in its [`PluginMetadata`].
#[derive(Debug, Clone)]
pub enum MemoryComplexity {
    /// Constant memory use.
    Constant,
    /// Linear memory use (the `Default` value).
    Linear,
    /// Quadratic memory use.
    Quadratic,
    /// Any other memory complexity, described by a free-form string.
    Custom(String),
}
154
155impl Default for MemoryComplexity {
156 fn default() -> Self {
157 Self::Linear
158 }
159}
160
/// Registry of plugins, scoring functions, transformations and middleware.
///
/// Each collection sits behind its own `RwLock`, so all registry methods take
/// `&self`.
pub struct PluginRegistry {
    /// Registered plugins, keyed by [`FeatureSelectionPlugin::name`].
    plugins: RwLock<HashMap<String, Box<dyn FeatureSelectionPlugin>>>,
    /// Registered scoring functions, keyed by [`ScoringFunction::name`].
    scoring_functions: RwLock<HashMap<String, Box<dyn ScoringFunction>>>,
    /// Registered transformations, keyed by [`TransformationFunction::name`].
    transformations: RwLock<HashMap<String, Box<dyn TransformationFunction>>>,
    /// Middleware in registration order; invoked in that order.
    middleware: RwLock<Vec<Box<dyn PluginMiddleware>>>,
}
168
169impl Default for PluginRegistry {
170 fn default() -> Self {
171 Self::new()
172 }
173}
174
175impl PluginRegistry {
176 pub fn new() -> Self {
178 Self {
179 plugins: RwLock::new(HashMap::new()),
180 scoring_functions: RwLock::new(HashMap::new()),
181 transformations: RwLock::new(HashMap::new()),
182 middleware: RwLock::new(Vec::new()),
183 }
184 }
185
186 pub fn register_plugin(&self, plugin: Box<dyn FeatureSelectionPlugin>) -> Result<()> {
188 let name = plugin.name().to_string();
189 let mut plugins = self
190 .plugins
191 .write()
192 .map_err(|_| SklearsError::FitError("Failed to acquire write lock".to_string()))?;
193
194 if plugins.contains_key(&name) {
195 return Err(SklearsError::InvalidInput(format!(
196 "Plugin '{}' is already registered",
197 name
198 )));
199 }
200
201 plugins.insert(name, plugin);
202 Ok(())
203 }
204
205 pub fn register_scoring_function(&self, function: Box<dyn ScoringFunction>) -> Result<()> {
207 let name = function.name().to_string();
208 let mut functions = self
209 .scoring_functions
210 .write()
211 .map_err(|_| SklearsError::FitError("Failed to acquire write lock".to_string()))?;
212
213 functions.insert(name, function);
214 Ok(())
215 }
216
217 pub fn register_transformation(
219 &self,
220 transformation: Box<dyn TransformationFunction>,
221 ) -> Result<()> {
222 let name = transformation.name().to_string();
223 let mut transformations = self
224 .transformations
225 .write()
226 .map_err(|_| SklearsError::FitError("Failed to acquire write lock".to_string()))?;
227
228 transformations.insert(name, transformation);
229 Ok(())
230 }
231
232 pub fn register_middleware(&self, middleware: Box<dyn PluginMiddleware>) -> Result<()> {
234 let mut middleware_vec = self
235 .middleware
236 .write()
237 .map_err(|_| SklearsError::FitError("Failed to acquire write lock".to_string()))?;
238
239 middleware_vec.push(middleware);
240 Ok(())
241 }
242
243 pub fn get_plugin(&self, name: &str) -> Result<Box<dyn FeatureSelectionPlugin>> {
245 let plugins = self
246 .plugins
247 .read()
248 .map_err(|_| SklearsError::FitError("Failed to acquire read lock".to_string()))?;
249
250 plugins
251 .get(name)
252 .map(|plugin| plugin.clone_plugin())
253 .ok_or_else(|| SklearsError::InvalidInput(format!("Plugin '{}' not found", name)))
254 }
255
256 pub fn get_scoring_function(&self, name: &str) -> Result<Box<dyn ScoringFunction>> {
258 let functions = self
259 .scoring_functions
260 .read()
261 .map_err(|_| SklearsError::FitError("Failed to acquire read lock".to_string()))?;
262
263 functions
264 .get(name)
265 .map(|func| func.clone_scoring())
266 .ok_or_else(|| {
267 SklearsError::InvalidInput(format!("Scoring function '{}' not found", name))
268 })
269 }
270
271 pub fn get_transformation(&self, name: &str) -> Result<Box<dyn TransformationFunction>> {
273 let transformations = self
274 .transformations
275 .read()
276 .map_err(|_| SklearsError::FitError("Failed to acquire read lock".to_string()))?;
277
278 transformations
279 .get(name)
280 .map(|transform| transform.clone_transform())
281 .ok_or_else(|| {
282 SklearsError::InvalidInput(format!("Transformation '{}' not found", name))
283 })
284 }
285
286 pub fn list_plugins(&self) -> Result<Vec<String>> {
288 let plugins = self
289 .plugins
290 .read()
291 .map_err(|_| SklearsError::FitError("Failed to acquire read lock".to_string()))?;
292
293 Ok(plugins.keys().cloned().collect())
294 }
295
296 pub fn get_plugin_metadata(&self, name: &str) -> Result<PluginMetadata> {
298 let plugins = self
299 .plugins
300 .read()
301 .map_err(|_| SklearsError::FitError("Failed to acquire read lock".to_string()))?;
302
303 plugins
304 .get(name)
305 .map(|plugin| plugin.metadata())
306 .ok_or_else(|| SklearsError::InvalidInput(format!("Plugin '{}' not found", name)))
307 }
308
309 pub fn execute_before_middleware(
311 &self,
312 plugin_name: &str,
313 context: &PluginContext,
314 ) -> Result<()> {
315 let middleware = self
316 .middleware
317 .read()
318 .map_err(|_| SklearsError::FitError("Failed to acquire read lock".to_string()))?;
319
320 for mw in middleware.iter() {
321 mw.before_execution(plugin_name, context)?;
322 }
323
324 Ok(())
325 }
326
327 pub fn execute_after_middleware(
329 &self,
330 plugin_name: &str,
331 context: &PluginContext,
332 result: &PluginResult,
333 ) -> Result<()> {
334 let middleware = self
335 .middleware
336 .read()
337 .map_err(|_| SklearsError::FitError("Failed to acquire read lock".to_string()))?;
338
339 for mw in middleware.iter() {
340 mw.after_execution(plugin_name, context, result)?;
341 }
342
343 Ok(())
344 }
345}
346
/// Hooks invoked around plugin execution.
///
/// Implementors must be `Send + Sync` because boxed middleware is kept behind
/// the registry's `RwLock`.
pub trait PluginMiddleware: Send + Sync {
    /// Called before the plugin named `plugin_name` runs; returning an error
    /// stops the remaining middleware and is propagated to the caller.
    fn before_execution(&self, plugin_name: &str, context: &PluginContext) -> Result<()>;

    /// Called after the plugin named `plugin_name` has run, with the
    /// `result` describing that execution; returning an error stops the
    /// remaining middleware and is propagated to the caller.
    fn after_execution(
        &self,
        plugin_name: &str,
        context: &PluginContext,
        result: &PluginResult,
    ) -> Result<()>;
}
365
/// Information handed to middleware about the plugin step being executed.
#[derive(Debug, Clone)]
pub struct PluginContext {
    /// Label of the operation (the pipeline uses `"plugin_execution"`).
    pub operation: String,
    /// Shape of the input data as `(rows, columns)`.
    pub data_shape: (usize, usize),
    /// String parameters for the step (the pipeline passes the step's config).
    pub parameters: HashMap<String, String>,
    /// Instant at which the step started.
    pub start_time: std::time::Instant,
}
374
/// Outcome of a plugin execution, handed to middleware `after_execution` hooks.
#[derive(Debug, Clone)]
pub struct PluginResult {
    /// Whether the execution succeeded.
    pub success: bool,
    /// How long the execution took.
    pub execution_time: std::time::Duration,
    /// Feature indices selected by the plugin.
    pub selected_features: Vec<usize>,
    /// Error description, if any.
    pub error_message: Option<String>,
}
383
/// Ordered sequence of steps resolved against a shared [`PluginRegistry`].
pub struct PluginPipeline {
    /// Steps in the order they are executed.
    steps: Vec<PipelineStep>,
    /// Registry used to look up each step by name at execution time.
    registry: Arc<PluginRegistry>,
}
389
/// One step of a [`PluginPipeline`]; `name` is the registry key to look up.
#[derive(Clone)]
pub enum PipelineStep {
    /// Fit and apply a registered feature-selection plugin.
    Plugin {
        /// Registry name of the plugin.
        name: String,

        /// Step configuration; `PluginPipeline::execute` forwards it to
        /// middleware as `PluginContext::parameters`.
        config: HashMap<String, String>,
    },
    /// Apply a registered transformation.
    Transformation {
        /// Registry name of the transformation.
        name: String,

        /// Step configuration (currently ignored by `PluginPipeline::execute`).
        config: HashMap<String, String>,
    },
    /// Run a registered scoring function.
    Scoring {
        /// Registry name of the scoring function.
        name: String,
        /// Step configuration (currently ignored by `PluginPipeline::execute`).
        config: HashMap<String, String>,
    },
}
410
411impl Default for PluginPipeline {
412 fn default() -> Self {
413 Self::new()
414 }
415}
416
417impl PluginPipeline {
418 pub fn new() -> Self {
420 Self {
421 steps: Vec::new(),
422 registry: Arc::new(PluginRegistry::new()),
423 }
424 }
425
426 pub fn with_registry(registry: Arc<PluginRegistry>) -> Self {
428 Self {
429 steps: Vec::new(),
430 registry,
431 }
432 }
433
434 pub fn add_plugin(mut self, name: String, config: HashMap<String, String>) -> Self {
436 self.steps.push(PipelineStep::Plugin { name, config });
437 self
438 }
439
440 pub fn add_transformation(mut self, name: String, config: HashMap<String, String>) -> Self {
442 self.steps
443 .push(PipelineStep::Transformation { name, config });
444 self
445 }
446
447 pub fn add_scoring(mut self, name: String, config: HashMap<String, String>) -> Self {
449 self.steps.push(PipelineStep::Scoring { name, config });
450 self
451 }
452
453 pub fn execute(&self, X: ArrayView2<f64>, y: ArrayView1<f64>) -> Result<PipelineResult> {
455 let start_time = std::time::Instant::now();
456 let mut current_X = X.to_owned();
457 let mut step_results = Vec::new();
458
459 for (step_index, step) in self.steps.iter().enumerate() {
460 let step_start = std::time::Instant::now();
461
462 match step {
463 PipelineStep::Plugin { name, config } => {
464 let context = PluginContext {
465 operation: "plugin_execution".to_string(),
466 data_shape: (current_X.nrows(), current_X.ncols()),
467 parameters: config.clone(),
468 start_time: step_start,
469 };
470
471 self.registry.execute_before_middleware(name, &context)?;
472
473 let mut plugin = self.registry.get_plugin(name)?;
474 plugin.fit(current_X.view(), y.view())?;
475 current_X = plugin.transform(current_X.view())?;
476 let selected_features = plugin.selected_features()?;
477
478 let result = PluginResult {
479 success: true,
480 execution_time: step_start.elapsed(),
481 selected_features: selected_features.clone(),
482 error_message: None,
483 };
484
485 self.registry
486 .execute_after_middleware(name, &context, &result)?;
487
488 step_results.push(StepResult {
489 step_index,
490 step_type: "Plugin".to_string(),
491 step_name: name.clone(),
492 execution_time: step_start.elapsed(),
493 input_features: context.data_shape.1,
494 output_features: current_X.ncols(),
495 selected_features,
496 });
497 }
498 PipelineStep::Transformation { name, config: _ } => {
499 let transformation = self.registry.get_transformation(name)?;
500 let input_features = current_X.ncols();
501 current_X = transformation.transform(current_X.view())?;
502
503 step_results.push(StepResult {
504 step_index,
505 step_type: "Transformation".to_string(),
506 step_name: name.clone(),
507 execution_time: step_start.elapsed(),
508 input_features,
509 output_features: current_X.ncols(),
510 selected_features: (0..current_X.ncols()).collect(),
511 });
512 }
513 PipelineStep::Scoring { name, config: _ } => {
514 let scoring_function = self.registry.get_scoring_function(name)?;
515 let _scores = scoring_function.score_features(current_X.view(), y.view())?;
516
517 step_results.push(StepResult {
518 step_index,
519 step_type: "Scoring".to_string(),
520 step_name: name.clone(),
521 execution_time: step_start.elapsed(),
522 input_features: current_X.ncols(),
523 output_features: current_X.ncols(),
524 selected_features: (0..current_X.ncols()).collect(),
525 });
526 }
527 }
528 }
529
530 Ok(PipelineResult {
531 final_data: current_X.clone(),
532 step_results,
533 total_execution_time: start_time.elapsed(),
534 original_features: X.ncols(),
535 final_features: current_X.ncols(),
536 })
537 }
538}
539
/// Result of [`PluginPipeline::execute`].
#[derive(Debug, Clone)]
pub struct PipelineResult {
    /// Data after the last step.
    pub final_data: Array2<f64>,
    /// One entry per executed step, in execution order.
    pub step_results: Vec<StepResult>,
    /// Wall-clock time spent in `execute`.
    pub total_execution_time: std::time::Duration,
    /// Number of columns in the input matrix.
    pub original_features: usize,
    /// Number of columns in `final_data`.
    pub final_features: usize,
}
549
/// Summary of a single executed pipeline step.
#[derive(Debug, Clone)]
pub struct StepResult {
    /// Zero-based position of the step in the pipeline.
    pub step_index: usize,
    /// Kind of step: `"Plugin"`, `"Transformation"` or `"Scoring"`.
    pub step_type: String,
    /// Registry name the step was resolved with.
    pub step_name: String,
    /// Time elapsed for the step.
    pub execution_time: std::time::Duration,
    /// Number of columns going into the step.
    pub input_features: usize,
    /// Number of columns coming out of the step.
    pub output_features: usize,
    /// For plugin steps, the indices reported by the plugin; for other steps,
    /// every output column index (`0..output_features`).
    pub selected_features: Vec<usize>,
}
561
562pub mod builtin {
564 use super::*;
565
566 #[derive(Debug, Clone)]
568 pub struct VarianceThresholdPlugin {
569 threshold: f64,
570 feature_variances: Option<Array1<f64>>,
571 selected_indices: Option<Vec<usize>>,
572 fitted: bool,
573 }
574
575 impl VarianceThresholdPlugin {
576 pub fn new(threshold: f64) -> Self {
577 Self {
578 threshold,
579 feature_variances: None,
580 selected_indices: None,
581 fitted: false,
582 }
583 }
584 }
585
586 impl FeatureSelectionPlugin for VarianceThresholdPlugin {
587 fn name(&self) -> &str {
588 "variance_threshold"
589 }
590
591 fn version(&self) -> &str {
592 "1.0.0"
593 }
594
595 fn description(&self) -> &str {
596 "Removes features with variance below threshold"
597 }
598
599 fn metadata(&self) -> PluginMetadata {
600 PluginMetadata {
601 author: "Sklears Team".to_string(),
602 license: "MIT".to_string(),
603 categories: vec!["filter".to_string(), "univariate".to_string()],
604 tags: vec!["variance".to_string(), "threshold".to_string()],
605 min_samples: None,
606 max_features: None,
607 supports_sparse: true,
608 supports_multiclass: true,
609 supports_regression: true,
610 computational_complexity: ComputationalComplexity::Linear,
611 memory_complexity: MemoryComplexity::Linear,
612 }
613 }
614
615 fn fit(&mut self, X: ArrayView2<f64>, _y: ArrayView1<f64>) -> Result<()> {
616 let mut variances = Array1::zeros(X.ncols());
617 for (i, var) in variances.iter_mut().enumerate() {
618 *var = X.column(i).var(1.0);
619 }
620
621 let selected_indices: Vec<usize> = variances
622 .iter()
623 .enumerate()
624 .filter_map(|(i, &var)| if var > self.threshold { Some(i) } else { None })
625 .collect();
626
627 self.feature_variances = Some(variances);
628 self.selected_indices = Some(selected_indices);
629 self.fitted = true;
630
631 Ok(())
632 }
633
634 fn transform(&self, X: ArrayView2<f64>) -> Result<Array2<f64>> {
635 if !self.fitted {
636 return Err(SklearsError::FitError("Plugin not fitted".to_string()));
637 }
638
639 let selected_indices = self.selected_indices.as_ref().unwrap();
640 if selected_indices.is_empty() {
641 return Err(SklearsError::InvalidInput(
642 "No features selected".to_string(),
643 ));
644 }
645
646 let mut result = Array2::zeros((X.nrows(), selected_indices.len()));
647 for (new_col, &old_col) in selected_indices.iter().enumerate() {
648 for row in 0..X.nrows() {
649 result[[row, new_col]] = X[[row, old_col]];
650 }
651 }
652
653 Ok(result)
654 }
655
656 fn selected_features(&self) -> Result<Vec<usize>> {
657 self.selected_indices
658 .clone()
659 .ok_or_else(|| SklearsError::FitError("Plugin not fitted".to_string()))
660 }
661
662 fn feature_scores(&self) -> Result<Array1<f64>> {
663 self.feature_variances
664 .clone()
665 .ok_or_else(|| SklearsError::FitError("Plugin not fitted".to_string()))
666 }
667
668 fn is_fitted(&self) -> bool {
669 self.fitted
670 }
671
672 fn as_any(&self) -> &dyn Any {
673 self
674 }
675
676 fn clone_plugin(&self) -> Box<dyn FeatureSelectionPlugin> {
677 Box::new(self.clone())
678 }
679 }
680
681 #[derive(Debug, Clone)]
683 pub struct CorrelationScoring;
684
685 impl ScoringFunction for CorrelationScoring {
686 fn name(&self) -> &str {
687 "correlation"
688 }
689
690 fn score(&self, feature: ArrayView1<f64>, target: ArrayView1<f64>) -> Result<f64> {
691 if feature.len() != target.len() {
692 return Err(SklearsError::InvalidInput(
693 "Feature and target length mismatch".to_string(),
694 ));
695 }
696
697 let correlation = crate::performance::SIMDStats::correlation_auto(feature, target);
698 Ok(correlation.abs())
699 }
700
701 fn clone_scoring(&self) -> Box<dyn ScoringFunction> {
702 Box::new(self.clone())
703 }
704 }
705
706 #[derive(Debug, Clone)]
708 pub struct NormalizationTransform;
709
710 impl TransformationFunction for NormalizationTransform {
711 fn name(&self) -> &str {
712 "normalization"
713 }
714
715 fn transform(&self, X: ArrayView2<f64>) -> Result<Array2<f64>> {
716 let mut result = X.to_owned();
717
718 for col in 0..result.ncols() {
719 let column = result.column(col);
720 let mean = column.mean().unwrap_or(0.0);
721 let std = column.var(1.0).sqrt();
722
723 if std > 1e-10 {
724 for row in 0..result.nrows() {
725 result[[row, col]] = (result[[row, col]] - mean) / std;
726 }
727 }
728 }
729
730 Ok(result)
731 }
732
733 fn output_features(&self, input_features: usize) -> Option<usize> {
734 Some(input_features)
735 }
736
737 fn clone_transform(&self) -> Box<dyn TransformationFunction> {
738 Box::new(self.clone())
739 }
740 }
741}
742
/// Middleware that prints plugin execution details to standard output.
#[derive(Debug, Clone)]
pub struct LoggingMiddleware {
    /// Verbosity; output is produced only for `Debug` and `Info`.
    log_level: LogLevel,
}
748
/// Verbosity levels understood by [`LoggingMiddleware`].
#[derive(Debug, Clone)]
pub enum LogLevel {
    /// Debug level; `LoggingMiddleware` prints at this level.
    Debug,
    /// Info level; `LoggingMiddleware` prints at this level.
    Info,
    /// Warning level; `LoggingMiddleware` prints nothing at this level.
    Warning,
    /// Error level; `LoggingMiddleware` prints nothing at this level.
    Error,
}
760
761impl LoggingMiddleware {
762 pub fn new(log_level: LogLevel) -> Self {
763 Self { log_level }
764 }
765}
766
767impl PluginMiddleware for LoggingMiddleware {
768 fn before_execution(&self, plugin_name: &str, context: &PluginContext) -> Result<()> {
769 match self.log_level {
770 LogLevel::Debug | LogLevel::Info => {
771 println!(
772 "Executing plugin '{}' with operation '{}'",
773 plugin_name, context.operation
774 );
775 println!(" Data shape: {:?}", context.data_shape);
776 }
777 _ => {}
778 }
779 Ok(())
780 }
781
782 fn after_execution(
783 &self,
784 plugin_name: &str,
785 _context: &PluginContext,
786 result: &PluginResult,
787 ) -> Result<()> {
788 match self.log_level {
789 LogLevel::Debug | LogLevel::Info => {
790 println!(
791 "Plugin '{}' completed in {:?}",
792 plugin_name, result.execution_time
793 );
794 println!(" Selected {} features", result.selected_features.len());
795 if let Some(ref error) = result.error_message {
796 println!(" Error: {}", error);
797 }
798 }
799 _ => {}
800 }
801 Ok(())
802 }
803}
804
/// Middleware that accumulates per-plugin timing statistics.
#[derive(Debug)]
pub struct PerformanceMiddleware {
    /// Timing statistics keyed by plugin name.
    metrics: Arc<RwLock<HashMap<String, PerformanceMetrics>>>,
}
810
/// Timing statistics collected for one plugin by [`PerformanceMiddleware`].
#[derive(Debug, Clone)]
pub struct PerformanceMetrics {
    /// Number of recorded executions.
    pub total_executions: usize,
    /// Sum of all recorded execution times.
    pub total_time: std::time::Duration,
    /// `total_time` divided by `total_executions`.
    pub average_time: std::time::Duration,
    /// Shortest recorded execution time.
    pub min_time: std::time::Duration,
    /// Longest recorded execution time.
    pub max_time: std::time::Duration,
}
819
820impl Default for PerformanceMiddleware {
821 fn default() -> Self {
822 Self::new()
823 }
824}
825
826impl PerformanceMiddleware {
827 pub fn new() -> Self {
828 Self {
829 metrics: Arc::new(RwLock::new(HashMap::new())),
830 }
831 }
832
833 pub fn get_metrics(&self) -> Result<HashMap<String, PerformanceMetrics>> {
834 let metrics = self
835 .metrics
836 .read()
837 .map_err(|_| SklearsError::FitError("Failed to acquire read lock".to_string()))?;
838 Ok(metrics.clone())
839 }
840}
841
842impl PluginMiddleware for PerformanceMiddleware {
843 fn before_execution(&self, _plugin_name: &str, _context: &PluginContext) -> Result<()> {
844 Ok(())
846 }
847
848 fn after_execution(
849 &self,
850 plugin_name: &str,
851 _context: &PluginContext,
852 result: &PluginResult,
853 ) -> Result<()> {
854 let mut metrics = self
855 .metrics
856 .write()
857 .map_err(|_| SklearsError::FitError("Failed to acquire write lock".to_string()))?;
858
859 let entry = metrics
860 .entry(plugin_name.to_string())
861 .or_insert_with(|| PerformanceMetrics {
862 total_executions: 0,
863 total_time: std::time::Duration::from_secs(0),
864 average_time: std::time::Duration::from_secs(0),
865 min_time: std::time::Duration::from_secs(u64::MAX),
866 max_time: std::time::Duration::from_secs(0),
867 });
868
869 entry.total_executions += 1;
870 entry.total_time += result.execution_time;
871 entry.average_time = entry.total_time / entry.total_executions as u32;
872 entry.min_time = entry.min_time.min(result.execution_time);
873 entry.max_time = entry.max_time.max(result.execution_time);
874
875 Ok(())
876 }
877}
878
/// Boxes `$plugin` and registers it with `$registry` via `register_plugin`.
///
/// The expansion ends with `?`, so the macro must be used as a statement
/// inside a function whose return type can absorb the registration error.
#[macro_export]
macro_rules! register_plugin {
    ($registry:expr, $plugin:expr) => {
        $registry.register_plugin(::std::boxed::Box::new($plugin))?;
    };
}
886
/// Builds a `PluginPipeline` from a comma-separated list of steps.
///
/// Each step is written `kind(name, config)`, where `kind` is one of
/// `plugin`, `transform` or `scoring`, `name` is anything with
/// `to_string()`, and `config` is the step's configuration map. Steps are
/// added in the order written. `PluginPipeline` must be in scope at the call
/// site.
///
/// The step kind is resolved at compile time: an unknown kind is rejected
/// with a compile error instead of panicking when the pipeline is built.
#[macro_export]
macro_rules! plugin_pipeline {
    // Internal rules: dispatch one step on its kind.
    (@step $pipeline:ident, plugin, $name:expr, $config:expr) => {
        $pipeline.add_plugin($name.to_string(), $config)
    };
    (@step $pipeline:ident, transform, $name:expr, $config:expr) => {
        $pipeline.add_transformation($name.to_string(), $config)
    };
    (@step $pipeline:ident, scoring, $name:expr, $config:expr) => {
        $pipeline.add_scoring($name.to_string(), $config)
    };
    (@step $pipeline:ident, $unknown:ident, $name:expr, $config:expr) => {
        compile_error!(concat!("Unknown step type: ", stringify!($unknown)))
    };
    // Public entry point.
    ($($step_type:ident($name:expr, $config:expr)),+ $(,)?) => {
        {
            let mut pipeline = PluginPipeline::new();
            $(
                pipeline = $crate::plugin_pipeline!(@step pipeline, $step_type, $name, $config);
            )+
            pipeline
        }
    };
}
905
// `X` follows the conventional upper-case name for a feature matrix.
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::builtin::*;
    use super::*;
    use scirs2_core::ndarray::array;

    /// Registration and lookup round-trip for every registry collection.
    #[test]
    fn test_plugin_registry() -> Result<()> {
        let registry = PluginRegistry::new();

        registry.register_plugin(Box::new(VarianceThresholdPlugin::new(0.1)))?;
        registry.register_scoring_function(Box::new(CorrelationScoring))?;
        registry.register_transformation(Box::new(NormalizationTransform))?;

        let retrieved_plugin = registry.get_plugin("variance_threshold")?;
        assert_eq!(retrieved_plugin.name(), "variance_threshold");

        let retrieved_scoring = registry.get_scoring_function("correlation")?;
        assert_eq!(retrieved_scoring.name(), "correlation");

        let retrieved_transform = registry.get_transformation("normalization")?;
        assert_eq!(retrieved_transform.name(), "normalization");

        assert_eq!(
            registry.list_plugins()?,
            vec!["variance_threshold".to_string()]
        );

        // Duplicate plugin names are rejected; unknown names are reported.
        assert!(registry
            .register_plugin(Box::new(VarianceThresholdPlugin::new(0.5)))
            .is_err());
        assert!(registry.get_plugin("missing").is_err());
        assert!(registry.get_scoring_function("missing").is_err());
        assert!(registry.get_transformation("missing").is_err());

        Ok(())
    }

    /// The constant third column must be dropped; the other two are kept.
    #[test]
    fn test_plugin_execution() -> Result<()> {
        let X = array![
            [1.0, 2.0, 0.0],
            [2.0, 4.0, 0.0],
            [3.0, 6.0, 0.0],
            [4.0, 8.0, 0.0],
        ];
        let y = array![1.0, 2.0, 3.0, 4.0];

        let mut plugin = VarianceThresholdPlugin::new(0.1);
        assert!(!plugin.is_fitted());
        assert!(plugin.transform(X.view()).is_err());

        plugin.fit(X.view(), y.view())?;
        assert!(plugin.is_fitted());

        let selected_features = plugin.selected_features()?;
        assert_eq!(selected_features, vec![0, 1]);
        assert_eq!(plugin.feature_scores()?.len(), 3);

        let transformed = plugin.transform(X.view())?;
        assert_eq!(
            transformed,
            array![[1.0, 2.0], [2.0, 4.0], [3.0, 6.0], [4.0, 8.0]]
        );

        Ok(())
    }

    /// After standardisation every column has unit variance, so all survive.
    #[test]
    fn test_plugin_pipeline() -> Result<()> {
        let registry = Arc::new(PluginRegistry::new());

        registry.register_plugin(Box::new(VarianceThresholdPlugin::new(0.1)))?;
        registry.register_transformation(Box::new(NormalizationTransform))?;
        registry.register_scoring_function(Box::new(CorrelationScoring))?;

        let pipeline = PluginPipeline::with_registry(registry)
            .add_transformation("normalization".to_string(), HashMap::new())
            .add_plugin("variance_threshold".to_string(), HashMap::new());

        let X = array![
            [1.0, 2.0, 3.0],
            [2.0, 4.0, 6.0],
            [3.0, 6.0, 9.0],
            [4.0, 8.0, 12.0],
        ];
        let y = array![1.0, 2.0, 3.0, 4.0];

        let result = pipeline.execute(X.view(), y.view())?;
        assert_eq!(result.original_features, 3);
        assert_eq!(result.final_features, 3);
        assert_eq!(result.final_data.nrows(), 4);
        assert_eq!(result.final_data.ncols(), 3);

        assert_eq!(result.step_results.len(), 2);
        assert_eq!(result.step_results[0].step_type, "Transformation");
        assert_eq!(result.step_results[0].step_name, "normalization");
        assert_eq!(result.step_results[1].step_type, "Plugin");
        assert_eq!(result.step_results[1].step_name, "variance_threshold");
        assert_eq!(result.step_results[1].selected_features, vec![0, 1, 2]);

        Ok(())
    }

    /// Both middleware hooks run without error for a registered plugin.
    #[test]
    fn test_middleware() -> Result<()> {
        let registry = PluginRegistry::new();

        registry.register_middleware(Box::new(LoggingMiddleware::new(LogLevel::Info)))?;
        registry.register_middleware(Box::new(PerformanceMiddleware::new()))?;
        registry.register_plugin(Box::new(VarianceThresholdPlugin::new(0.1)))?;

        let context = PluginContext {
            operation: "test".to_string(),
            data_shape: (100, 10),
            parameters: HashMap::new(),
            start_time: std::time::Instant::now(),
        };

        registry.execute_before_middleware("variance_threshold", &context)?;

        let result = PluginResult {
            success: true,
            execution_time: std::time::Duration::from_millis(10),
            selected_features: vec![0, 1, 2],
            error_message: None,
        };

        registry.execute_after_middleware("variance_threshold", &context, &result)?;

        Ok(())
    }

    /// The macro adds one step per entry, in the order written.
    #[test]
    fn test_macro_pipeline() -> Result<()> {
        let pipeline = plugin_pipeline! {
            transform("normalization", HashMap::new()),
            plugin("variance_threshold", HashMap::new()),
            scoring("correlation", HashMap::new()),
        };

        assert_eq!(pipeline.steps.len(), 3);
        assert!(matches!(
            pipeline.steps[0],
            PipelineStep::Transformation { .. }
        ));
        assert!(matches!(pipeline.steps[1], PipelineStep::Plugin { .. }));
        assert!(matches!(pipeline.steps[2], PipelineStep::Scoring { .. }));

        Ok(())
    }
}