1use scirs2_core::ndarray::{Array1, Array2, ArrayView1, ArrayView2};
11use sklears_core::error::{Result, SklearsError};
12use sklears_core::traits::Estimator;
13use std::any::Any;
14use std::collections::HashMap;
15use std::fmt::Debug;
16use std::sync::{Arc, Mutex, RwLock};
17
/// A pluggable baseline component that can be registered with a [`PluginRegistry`].
///
/// Implementors must be `Send + Sync + Debug`; the registry stores them as
/// `Arc<dyn BaselinePlugin>` so they can be shared.
pub trait BaselinePlugin: Send + Sync + Debug {
    /// Unique name of the plugin; the registry uses it as the lookup key.
    fn name(&self) -> &str;

    /// Version string reported by the plugin.
    fn version(&self) -> &str;

    /// Human-readable description of the plugin.
    fn description(&self) -> &str;

    /// Prepares the plugin using `config`.
    ///
    /// NOTE(review): nothing in this file calls `initialize`; confirm where the
    /// plugin lifecycle is driven from.
    fn initialize(&mut self, config: &PluginConfig) -> Result<()>;

    /// Releases whatever the plugin acquired in `initialize`.
    fn shutdown(&mut self) -> Result<()>;

    /// Returns `true` if the plugin can handle data described by `data_info`.
    fn is_compatible(&self, data_info: &DataInfo) -> bool;

    /// Returns descriptive metadata (author, license, supported tasks, ...).
    fn metadata(&self) -> PluginMetadata;
}
41
/// Configuration stored alongside a plugin when it is registered.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct PluginConfig {
    /// Arbitrary named parameters for the plugin.
    pub parameters: HashMap<String, PluginParameter>,
    /// Resource limits and caching options.
    pub resources: ResourceConfig,
    /// Logging options.
    pub logging: LoggingConfig,
}
53
/// Resource settings for a plugin.
///
/// NOTE(review): these values are only stored; nothing in this file enforces them.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct ResourceConfig {
    /// Memory budget (presumably megabytes, per the `_mb` suffix).
    pub max_memory_mb: usize,
    /// Maximum number of CPU cores the plugin should use.
    pub max_cpu_cores: usize,
    /// Directory for temporary files.
    pub temp_directory: String,
    /// Whether caching is enabled.
    pub cache_enabled: bool,
}
66
/// Logging options for a plugin.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct LoggingConfig {
    /// Configured log level.
    pub level: LogLevel,
    /// Optional log file path; `None` presumably means no file output — confirm.
    pub output_file: Option<String>,
    /// Whether log lines should carry timestamps.
    pub include_timestamps: bool,
}
77
/// Logging verbosity levels, declared from `Error` through `Trace`.
///
/// Derives `PartialEq`/`Eq`/`Hash` so levels can be compared with `==` and used
/// as map or set keys (consistent with [`TaskType`]'s comparability).
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub enum LogLevel {
    /// Error messages.
    Error,
    /// Warning messages.
    Warn,
    /// Informational messages.
    Info,
    /// Debugging output.
    Debug,
    /// Tracing output.
    Trace,
}
92
/// A dynamically typed plugin parameter value.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub enum PluginParameter {
    /// Signed integer value.
    Integer(i64),
    /// Floating-point value.
    Float(f64),
    /// String value.
    String(String),
    /// Boolean value.
    Boolean(bool),
    /// List of values; elements may themselves be arrays.
    Array(Vec<PluginParameter>),
}
107
/// Descriptive information a plugin reports via `BaselinePlugin::metadata`.
#[derive(Debug, Clone)]
pub struct PluginMetadata {
    /// Plugin author.
    pub author: String,
    /// License identifier or text.
    pub license: String,
    /// Project homepage URL.
    pub homepage: String,
    /// Task types the plugin claims to support.
    pub supported_tasks: Vec<TaskType>,
    /// Requirements declared by the plugin.
    pub requirements: Vec<String>,
}
122
/// Kind of machine-learning task a plugin or strategy addresses.
///
/// Derives `Eq` and `Hash` in addition to `PartialEq` so task types can be
/// stored in `HashSet`s / used as `HashMap` keys (e.g. for supported-task lookups).
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum TaskType {
    /// Predicting discrete labels.
    Classification,
    /// Predicting continuous values.
    Regression,
    /// Grouping samples.
    Clustering,
    /// Reducing the number of features.
    DimensionalityReduction,
}
134
/// Summary of a dataset, passed to `BaselinePlugin::is_compatible`.
#[derive(Debug, Clone)]
pub struct DataInfo {
    /// Number of samples.
    pub n_samples: usize,
    /// Number of features.
    pub n_features: usize,
    /// Type of each feature (presumably one entry per feature — confirm).
    pub feature_types: Vec<FeatureType>,
    /// Kind of prediction target.
    pub target_type: TargetType,
    /// Whether the data contains missing values.
    pub missing_values: bool,
    /// Whether the data is sparse.
    pub sparse: bool,
}
151
/// Kind of values held by a single feature.
///
/// Derives `PartialEq`/`Eq`/`Hash` so feature types can be compared and counted
/// (for example, checking whether any feature is categorical).
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum FeatureType {
    /// Real-valued feature.
    Continuous,
    /// Unordered categories.
    Categorical,
    /// Two-valued feature.
    Binary,
    /// Ordered categories.
    Ordinal,
    /// Free text.
    Text,
}
165
/// Kind of prediction target in a dataset.
///
/// Derives `PartialEq`/`Eq`/`Hash` so target types can be compared with `==`
/// and used as map or set keys.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum TargetType {
    /// Real-valued target.
    Continuous,
    /// Two-class target.
    Binary,
    /// More than two mutually exclusive classes.
    Multiclass,
    /// Several labels per sample.
    Multilabel,
}
177
/// Registry of [`BaselinePlugin`]s together with their configuration and
/// activation state.
///
/// All three maps are keyed by plugin name and each sits behind its own
/// `RwLock`, so the registry can be used through `&self`.
pub struct PluginRegistry {
    /// Registered plugins keyed by name.
    plugins: RwLock<HashMap<String, Arc<dyn BaselinePlugin>>>,
    /// Configuration supplied at registration, keyed by plugin name.
    plugin_configs: RwLock<HashMap<String, PluginConfig>>,
    /// Activation flag per plugin name (`false` right after registration).
    active_plugins: RwLock<HashMap<String, bool>>,
}
184
185impl Default for PluginRegistry {
186 fn default() -> Self {
187 Self::new()
188 }
189}
190
191impl PluginRegistry {
192 pub fn new() -> Self {
194 Self {
195 plugins: RwLock::new(HashMap::new()),
196 plugin_configs: RwLock::new(HashMap::new()),
197 active_plugins: RwLock::new(HashMap::new()),
198 }
199 }
200
201 pub fn register_plugin(
203 &self,
204 plugin: Arc<dyn BaselinePlugin>,
205 config: PluginConfig,
206 ) -> Result<()> {
207 let name = plugin.name().to_string();
208
209 if self.plugins.read().unwrap().contains_key(&name) {
211 return Err(SklearsError::InvalidInput(format!(
212 "Plugin '{}' already registered",
213 name
214 )));
215 }
216
217 self.plugins.write().unwrap().insert(name.clone(), plugin);
219 self.plugin_configs
220 .write()
221 .unwrap()
222 .insert(name.clone(), config);
223 self.active_plugins.write().unwrap().insert(name, false);
224
225 Ok(())
226 }
227
228 pub fn activate_plugin(&self, name: &str) -> Result<()> {
230 let plugins = self.plugins.read().unwrap();
231 let plugin_configs = self.plugin_configs.write().unwrap();
232 let mut active_plugins = self.active_plugins.write().unwrap();
233
234 if let Some(plugin) = plugins.get(name) {
235 if let Some(config) = plugin_configs.get(name) {
236 active_plugins.insert(name.to_string(), true);
238 Ok(())
239 } else {
240 Err(SklearsError::InvalidInput(format!(
241 "No configuration found for plugin '{}'",
242 name
243 )))
244 }
245 } else {
246 Err(SklearsError::InvalidInput(format!(
247 "Plugin '{}' not found",
248 name
249 )))
250 }
251 }
252
253 pub fn deactivate_plugin(&self, name: &str) -> Result<()> {
255 let mut active_plugins = self.active_plugins.write().unwrap();
256
257 if active_plugins.contains_key(name) {
258 active_plugins.insert(name.to_string(), false);
259 Ok(())
260 } else {
261 Err(SklearsError::InvalidInput(format!(
262 "Plugin '{}' not found",
263 name
264 )))
265 }
266 }
267
268 pub fn list_plugins(&self) -> Vec<String> {
270 self.plugins.read().unwrap().keys().cloned().collect()
271 }
272
273 pub fn get_active_plugins(&self) -> Vec<String> {
275 self.active_plugins
276 .read()
277 .unwrap()
278 .iter()
279 .filter_map(|(name, &active)| if active { Some(name.clone()) } else { None })
280 .collect()
281 }
282
283 pub fn find_compatible_plugins(&self, data_info: &DataInfo) -> Vec<String> {
285 let plugins = self.plugins.read().unwrap();
286 let active_plugins = self.active_plugins.read().unwrap();
287
288 plugins
289 .iter()
290 .filter(|(name, plugin)| {
291 *active_plugins.get(*name).unwrap_or(&false) && plugin.is_compatible(data_info)
292 })
293 .map(|(name, _)| name.clone())
294 .collect()
295 }
296}
297
/// Lists of lifecycle hooks invoked around fitting, prediction and errors.
///
/// Each list sits behind its own `Mutex`, so hooks can be added through `&self`.
pub struct HookSystem {
    /// Hooks run before fitting.
    pre_fit_hooks: Arc<Mutex<Vec<Box<dyn PreFitHook>>>>,
    /// Hooks run after fitting.
    post_fit_hooks: Arc<Mutex<Vec<Box<dyn PostFitHook>>>>,
    /// Hooks run before prediction.
    pre_predict_hooks: Arc<Mutex<Vec<Box<dyn PrePredictHook>>>>,
    /// Hooks run after prediction; they may modify the predictions.
    post_predict_hooks: Arc<Mutex<Vec<Box<dyn PostPredictHook>>>>,
    /// Hooks run when an error is reported.
    error_hooks: Arc<Mutex<Vec<Box<dyn ErrorHook>>>>,
}
306
/// Hook run before fitting; receives the [`FitContext`] mutably.
pub trait PreFitHook: Send + Sync + Debug {
    /// Runs the hook; returning `Err` stops the remaining pre-fit hooks.
    fn execute(&self, context: &mut FitContext) -> Result<()>;
    /// Ordering key used by `HookSystem`; larger values run earlier. Defaults to `0`.
    fn priority(&self) -> i32 {
        0
    }
}
314
/// Hook run after fitting, with read-only access to the context and result.
pub trait PostFitHook: Send + Sync + Debug {
    /// Runs the hook; returning `Err` stops the remaining post-fit hooks.
    fn execute(&self, context: &FitContext, result: &FitResult) -> Result<()>;
    /// Ordering key used by `HookSystem`; larger values run earlier. Defaults to `0`.
    fn priority(&self) -> i32 {
        0
    }
}
322
/// Hook run before prediction; receives the [`PredictContext`] mutably.
pub trait PrePredictHook: Send + Sync + Debug {
    /// Runs the hook; returning `Err` stops the remaining pre-predict hooks.
    fn execute(&self, context: &mut PredictContext) -> Result<()>;
    /// Ordering key used by `HookSystem`; larger values run earlier. Defaults to `0`.
    fn priority(&self) -> i32 {
        0
    }
}
330
/// Hook run after prediction; may modify the predictions in place.
pub trait PostPredictHook: Send + Sync + Debug {
    /// Runs the hook; returning `Err` stops the remaining post-predict hooks.
    fn execute(&self, context: &PredictContext, predictions: &mut Array1<f64>) -> Result<()>;
    /// Ordering key used by `HookSystem`; larger values run earlier. Defaults to `0`.
    fn priority(&self) -> i32 {
        0
    }
}
338
/// Hook notified about errors via an [`ErrorContext`].
pub trait ErrorHook: Send + Sync + Debug {
    /// Runs the hook; returning `Err` stops the remaining error hooks.
    fn execute(&self, context: &ErrorContext) -> Result<()>;
    /// Ordering key used by `HookSystem`; larger values run earlier. Defaults to `0`.
    fn priority(&self) -> i32 {
        0
    }
}
346
/// State for a fit call: mutable for pre-fit hooks, read-only for post-fit hooks.
#[derive(Debug)]
pub struct FitContext {
    /// Name of the estimator being fitted.
    pub estimator_name: String,
    /// Name of the strategy in use (the tests use `"mean"`).
    pub strategy: String,
    /// Training inputs; presumably shape `(n_samples, n_features)` — confirm with callers.
    pub x: Array2<f64>,
    /// Training targets; presumably one value per row of `x`.
    pub y: Array1<f64>,
    /// Free-form string annotations.
    pub metadata: HashMap<String, String>,
    /// When the fit started.
    pub start_time: std::time::Instant,
}
363
/// Outcome of a fit, handed to post-fit hooks.
#[derive(Debug)]
pub struct FitResult {
    /// Whether fitting succeeded.
    pub success: bool,
    /// Time spent fitting.
    pub duration: std::time::Duration,
    /// Fitted parameter values by name.
    pub parameters: HashMap<String, f64>,
    /// Named numeric metrics.
    pub metrics: HashMap<String, f64>,
}
376
/// State for a predict call: mutable for pre-predict hooks, read-only for post-predict hooks.
#[derive(Debug)]
pub struct PredictContext {
    /// Name of the estimator making predictions.
    pub estimator_name: String,
    /// Name of the strategy in use.
    pub strategy: String,
    /// Inputs to predict on; presumably shape `(n_samples, n_features)` — confirm.
    pub x: Array2<f64>,
    /// Free-form string annotations.
    pub metadata: HashMap<String, String>,
    /// When prediction started.
    pub start_time: std::time::Instant,
}
391
/// Description of an error, handed to error hooks.
#[derive(Debug)]
pub struct ErrorContext {
    /// Operation during which the error occurred.
    pub operation: String,
    /// Error message text.
    pub error: String,
    /// Name of the estimator involved.
    pub estimator_name: String,
    /// Additional string context.
    pub context_data: HashMap<String, String>,
    /// When the error was recorded.
    pub timestamp: std::time::Instant,
}
406
407impl Default for HookSystem {
408 fn default() -> Self {
409 Self::new()
410 }
411}
412
413impl HookSystem {
414 pub fn new() -> Self {
416 Self {
417 pre_fit_hooks: Arc::new(Mutex::new(Vec::new())),
418 post_fit_hooks: Arc::new(Mutex::new(Vec::new())),
419 pre_predict_hooks: Arc::new(Mutex::new(Vec::new())),
420 post_predict_hooks: Arc::new(Mutex::new(Vec::new())),
421 error_hooks: Arc::new(Mutex::new(Vec::new())),
422 }
423 }
424
425 pub fn add_pre_fit_hook(&self, hook: Box<dyn PreFitHook>) {
427 let mut hooks = self.pre_fit_hooks.lock().unwrap();
428 hooks.push(hook);
429 hooks.sort_by_key(|h| -h.priority()); }
431
432 pub fn add_post_fit_hook(&self, hook: Box<dyn PostFitHook>) {
434 let mut hooks = self.post_fit_hooks.lock().unwrap();
435 hooks.push(hook);
436 hooks.sort_by_key(|h| -h.priority());
437 }
438
439 pub fn add_pre_predict_hook(&self, hook: Box<dyn PrePredictHook>) {
441 let mut hooks = self.pre_predict_hooks.lock().unwrap();
442 hooks.push(hook);
443 hooks.sort_by_key(|h| -h.priority());
444 }
445
446 pub fn add_post_predict_hook(&self, hook: Box<dyn PostPredictHook>) {
448 let mut hooks = self.post_predict_hooks.lock().unwrap();
449 hooks.push(hook);
450 hooks.sort_by_key(|h| -h.priority());
451 }
452
453 pub fn add_error_hook(&self, hook: Box<dyn ErrorHook>) {
455 let mut hooks = self.error_hooks.lock().unwrap();
456 hooks.push(hook);
457 hooks.sort_by_key(|h| -h.priority());
458 }
459
460 pub fn execute_pre_fit_hooks(&self, context: &mut FitContext) -> Result<()> {
462 let hooks = self.pre_fit_hooks.lock().unwrap();
463 for hook in hooks.iter() {
464 hook.execute(context)?;
465 }
466 Ok(())
467 }
468
469 pub fn execute_post_fit_hooks(&self, context: &FitContext, result: &FitResult) -> Result<()> {
471 let hooks = self.post_fit_hooks.lock().unwrap();
472 for hook in hooks.iter() {
473 hook.execute(context, result)?;
474 }
475 Ok(())
476 }
477
478 pub fn execute_pre_predict_hooks(&self, context: &mut PredictContext) -> Result<()> {
480 let hooks = self.pre_predict_hooks.lock().unwrap();
481 for hook in hooks.iter() {
482 hook.execute(context)?;
483 }
484 Ok(())
485 }
486
487 pub fn execute_post_predict_hooks(
489 &self,
490 context: &PredictContext,
491 predictions: &mut Array1<f64>,
492 ) -> Result<()> {
493 let hooks = self.post_predict_hooks.lock().unwrap();
494 for hook in hooks.iter() {
495 hook.execute(context, predictions)?;
496 }
497 Ok(())
498 }
499
500 pub fn execute_error_hooks(&self, context: &ErrorContext) -> Result<()> {
502 let hooks = self.error_hooks.lock().unwrap();
503 for hook in hooks.iter() {
504 hook.execute(context)?;
505 }
506 Ok(())
507 }
508}
509
/// A layer wrapped around an operation executed by a [`MiddlewarePipeline`].
pub trait PipelineMiddleware: Send + Sync + Debug {
    /// Short identifier; the default `on_error` includes it in its message.
    fn name(&self) -> &str;

    /// Called before the operation; returning `Err` aborts the pipeline run.
    fn before(&self, context: &mut MiddlewareContext) -> Result<()>;

    /// Called after a successful operation; may adjust `result`.
    fn after(&self, context: &mut MiddlewareContext, result: &mut MiddlewareResult) -> Result<()>;

    /// Notification that an error occurred during the pipeline run.
    ///
    /// The default prints the error to stderr via `eprintln!` and returns `Ok(())`.
    fn on_error(&self, context: &MiddlewareContext, error: &SklearsError) -> Result<()> {
        eprintln!("Middleware '{}' error: {:?}", self.name(), error);
        Ok(())
    }
}
528
/// State threaded through a pipeline run.
///
/// Passed mutably to each middleware's `before`/`after` and to the wrapped operation.
#[derive(Debug)]
pub struct MiddlewareContext {
    /// Name of the operation being executed.
    pub operation: String,
    /// Named parameters for the operation.
    pub parameters: HashMap<String, MiddlewareParameter>,
    /// Type-erased values; read them back with `downcast_ref`.
    pub data: HashMap<String, Box<dyn Any + Send>>,
    /// Named numeric metrics.
    pub metrics: HashMap<String, f64>,
    /// Start time of the operation.
    pub start_time: std::time::Instant,
}
543
/// Result produced by a pipeline operation; middleware `after` hooks may edit it.
#[derive(Debug)]
pub struct MiddlewareResult {
    /// Whether the operation succeeded.
    pub success: bool,
    /// Time the operation took.
    pub duration: std::time::Duration,
    /// Type-erased output values.
    pub data: HashMap<String, Box<dyn Any + Send>>,
    /// Named numeric metrics.
    pub metrics: HashMap<String, f64>,
    /// Non-fatal warnings collected during the run.
    pub warnings: Vec<String>,
}
558
/// A dynamically typed operation parameter carried in a [`MiddlewareContext`].
#[derive(Debug, Clone)]
pub enum MiddlewareParameter {
    /// Signed integer value.
    Integer(i64),
    /// Floating-point value.
    Float(f64),
    /// String value.
    String(String),
    /// Boolean value.
    Boolean(bool),
    /// List of values; elements may themselves be arrays.
    Array(Vec<MiddlewareParameter>),
}
572
/// Ordered middleware wrapped around an operation, plus an optional
/// pipeline-wide error handler.
pub struct MiddlewarePipeline {
    /// Middleware in registration order.
    middleware: Vec<Box<dyn PipelineMiddleware>>,
    /// Optional callback invoked with errors raised during `execute`.
    error_handler: Option<Box<dyn Fn(&SklearsError) -> Result<()> + Send + Sync>>,
}
578
579impl Default for MiddlewarePipeline {
580 fn default() -> Self {
581 Self::new()
582 }
583}
584
585impl MiddlewarePipeline {
586 pub fn new() -> Self {
588 Self {
589 middleware: Vec::new(),
590 error_handler: None,
591 }
592 }
593
594 pub fn add_middleware(&mut self, middleware: Box<dyn PipelineMiddleware>) {
596 self.middleware.push(middleware);
597 }
598
599 pub fn set_error_handler<F>(&mut self, handler: F)
601 where
602 F: Fn(&SklearsError) -> Result<()> + Send + Sync + 'static,
603 {
604 self.error_handler = Some(Box::new(handler));
605 }
606
607 pub fn execute<F>(
609 &self,
610 mut context: MiddlewareContext,
611 operation: F,
612 ) -> Result<MiddlewareResult>
613 where
614 F: FnOnce(&mut MiddlewareContext) -> Result<MiddlewareResult>,
615 {
616 for middleware in &self.middleware {
618 if let Err(e) = middleware.before(&mut context) {
619 middleware.on_error(&context, &e)?;
620 if let Some(handler) = &self.error_handler {
621 handler(&e)?;
622 }
623 return Err(e);
624 }
625 }
626
627 let mut result = match operation(&mut context) {
629 Ok(result) => result,
630 Err(e) => {
631 for middleware in &self.middleware {
633 middleware.on_error(&context, &e)?;
634 }
635 if let Some(handler) = &self.error_handler {
636 handler(&e)?;
637 }
638 return Err(e);
639 }
640 };
641
642 for middleware in self.middleware.iter().rev() {
644 if let Err(e) = middleware.after(&mut context, &mut result) {
645 middleware.on_error(&context, &e)?;
646 if let Some(handler) = &self.error_handler {
647 handler(&e)?;
648 }
649 return Err(e);
650 }
651 }
652
653 Ok(result)
654 }
655}
656
657pub mod middleware {
659 use super::*;
660
661 #[derive(Debug)]
663 pub struct LoggingMiddleware {
664 pub log_level: LogLevel,
666 pub include_timing: bool,
668 pub include_parameters: bool,
670 }
671
672 impl LoggingMiddleware {
673 pub fn new(log_level: LogLevel) -> Self {
674 Self {
675 log_level,
676 include_timing: true,
677 include_parameters: true,
678 }
679 }
680 }
681
682 impl PipelineMiddleware for LoggingMiddleware {
683 fn name(&self) -> &str {
684 "logging"
685 }
686
687 fn before(&self, context: &mut MiddlewareContext) -> Result<()> {
688 println!("Starting operation: {}", context.operation);
689 if self.include_parameters {
690 println!("Parameters: {:?}", context.parameters);
691 }
692 Ok(())
693 }
694
695 fn after(
696 &self,
697 context: &mut MiddlewareContext,
698 result: &mut MiddlewareResult,
699 ) -> Result<()> {
700 if self.include_timing {
701 println!(
702 "Operation '{}' completed in {:?}",
703 context.operation, result.duration
704 );
705 }
706 if !result.warnings.is_empty() {
707 println!("Warnings: {:?}", result.warnings);
708 }
709 Ok(())
710 }
711 }
712
713 #[derive(Debug)]
715 pub struct ValidationMiddleware {
716 pub validate_inputs: bool,
718 pub validate_outputs: bool,
720 pub strict_mode: bool,
722 }
723
724 impl Default for ValidationMiddleware {
725 fn default() -> Self {
726 Self::new()
727 }
728 }
729
730 impl ValidationMiddleware {
731 pub fn new() -> Self {
732 Self {
733 validate_inputs: true,
734 validate_outputs: true,
735 strict_mode: false,
736 }
737 }
738 }
739
740 impl PipelineMiddleware for ValidationMiddleware {
741 fn name(&self) -> &str {
742 "validation"
743 }
744
745 fn before(&self, context: &mut MiddlewareContext) -> Result<()> {
746 if self.validate_inputs {
747 for (key, value) in &context.parameters {
749 match value {
750 MiddlewareParameter::Float(f) => {
751 if f.is_nan() || f.is_infinite() {
752 return Err(SklearsError::InvalidInput(format!(
753 "Invalid float value for parameter '{}': {}",
754 key, f
755 )));
756 }
757 }
758 _ => {} }
760 }
761 }
762 Ok(())
763 }
764
765 fn after(
766 &self,
767 _context: &mut MiddlewareContext,
768 result: &mut MiddlewareResult,
769 ) -> Result<()> {
770 if self.validate_outputs && !result.success {
771 if self.strict_mode {
772 return Err(SklearsError::InvalidInput(
773 "Operation failed validation".to_string(),
774 ));
775 } else {
776 result
777 .warnings
778 .push("Operation validation failed".to_string());
779 }
780 }
781 Ok(())
782 }
783 }
784
785 #[derive(Debug)]
787 pub struct PerformanceMiddleware {
788 pub collect_metrics: bool,
790 pub memory_tracking: bool,
792 }
793
794 impl Default for PerformanceMiddleware {
795 fn default() -> Self {
796 Self::new()
797 }
798 }
799
800 impl PerformanceMiddleware {
801 pub fn new() -> Self {
802 Self {
803 collect_metrics: true,
804 memory_tracking: true,
805 }
806 }
807 }
808
809 impl PipelineMiddleware for PerformanceMiddleware {
810 fn name(&self) -> &str {
811 "performance"
812 }
813
814 fn before(&self, context: &mut MiddlewareContext) -> Result<()> {
815 context.start_time = std::time::Instant::now();
816 if self.memory_tracking {
817 context.metrics.insert("memory_before_mb".to_string(), 0.0);
819 }
820 Ok(())
821 }
822
823 fn after(
824 &self,
825 context: &mut MiddlewareContext,
826 result: &mut MiddlewareResult,
827 ) -> Result<()> {
828 result.duration = context.start_time.elapsed();
829
830 if self.collect_metrics {
831 result.metrics.insert(
832 "duration_ms".to_string(),
833 result.duration.as_millis() as f64,
834 );
835 }
836
837 if self.memory_tracking {
838 result.metrics.insert("memory_after_mb".to_string(), 0.0);
840 result.metrics.insert("memory_delta_mb".to_string(), 0.0);
841 }
842
843 Ok(())
844 }
845 }
846}
847
/// Registry of user-supplied classification and regression strategies.
///
/// Strategy names are the keys of all three maps. `strategy_metadata` is a
/// single map shared by both strategy kinds.
pub struct CustomStrategyRegistry {
    /// Classification strategies keyed by name.
    classification_strategies: RwLock<HashMap<String, Box<dyn CustomClassificationStrategy>>>,
    /// Regression strategies keyed by name.
    regression_strategies: RwLock<HashMap<String, Box<dyn CustomRegressionStrategy>>>,
    /// Metadata for every registered strategy, keyed by name.
    strategy_metadata: RwLock<HashMap<String, StrategyMetadata>>,
}
854
/// A user-supplied classification strategy storable in a [`CustomStrategyRegistry`].
pub trait CustomClassificationStrategy: Send + Sync + Debug {
    /// Unique name; used as the registry key.
    fn name(&self) -> &str;
    /// Predicts integer class labels for `x`.
    fn predict(&self, x: &ArrayView2<f64>) -> Result<Array1<i32>>;
    /// Returns class probabilities for `x` (presumably samples × classes — confirm).
    fn predict_proba(&self, x: &ArrayView2<f64>) -> Result<Array2<f64>>;
    /// Fits the strategy to inputs `x` and integer labels `y`.
    fn fit(&mut self, x: &ArrayView2<f64>, y: &ArrayView1<i32>) -> Result<()>;
    /// Reports whether the strategy has been fitted.
    fn is_fitted(&self) -> bool;
    /// Named numeric parameters of the strategy.
    fn get_parameters(&self) -> HashMap<String, f64>;
}
864
/// A user-supplied regression strategy storable in a [`CustomStrategyRegistry`].
pub trait CustomRegressionStrategy: Send + Sync + Debug {
    /// Unique name; used as the registry key.
    fn name(&self) -> &str;
    /// Predicts continuous targets for `x`.
    fn predict(&self, x: &ArrayView2<f64>) -> Result<Array1<f64>>;
    /// Fits the strategy to inputs `x` and targets `y`.
    fn fit(&mut self, x: &ArrayView2<f64>, y: &ArrayView1<f64>) -> Result<()>;
    /// Reports whether the strategy has been fitted.
    fn is_fitted(&self) -> bool;
    /// Named numeric parameters of the strategy.
    fn get_parameters(&self) -> HashMap<String, f64>;
}
873
/// Descriptive information stored alongside a custom strategy.
#[derive(Debug, Clone)]
pub struct StrategyMetadata {
    /// Author of the strategy.
    pub author: String,
    /// Version string.
    pub version: String,
    /// Human-readable description.
    pub description: String,
    /// Task the strategy solves; registration requires it to match the strategy kind.
    pub task_type: TaskType,
    /// Declared computational complexity.
    pub complexity: StrategyComplexity,
    /// Requirements declared by the strategy.
    pub requirements: Vec<String>,
}
890
/// Declared computational complexity class of a strategy.
///
/// Derives `PartialEq`/`Eq`/`Hash` so complexities can be compared with `==`
/// and used as map or set keys.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum StrategyComplexity {
    /// Constant cost.
    Constant,
    /// Linear cost.
    Linear,
    /// Quadratic cost.
    Quadratic,
    /// Exponential cost.
    Exponential,
}
902
903impl Default for CustomStrategyRegistry {
904 fn default() -> Self {
905 Self::new()
906 }
907}
908
909impl CustomStrategyRegistry {
910 pub fn new() -> Self {
912 Self {
913 classification_strategies: RwLock::new(HashMap::new()),
914 regression_strategies: RwLock::new(HashMap::new()),
915 strategy_metadata: RwLock::new(HashMap::new()),
916 }
917 }
918
919 pub fn register_classification_strategy(
921 &self,
922 strategy: Box<dyn CustomClassificationStrategy>,
923 metadata: StrategyMetadata,
924 ) -> Result<()> {
925 let name = strategy.name().to_string();
926
927 if metadata.task_type != TaskType::Classification {
929 return Err(SklearsError::InvalidInput(
930 "Strategy task type must be Classification".to_string(),
931 ));
932 }
933
934 if self
936 .classification_strategies
937 .read()
938 .unwrap()
939 .contains_key(&name)
940 {
941 return Err(SklearsError::InvalidInput(format!(
942 "Classification strategy '{}' already registered",
943 name
944 )));
945 }
946
947 self.classification_strategies
949 .write()
950 .unwrap()
951 .insert(name.clone(), strategy);
952 self.strategy_metadata
953 .write()
954 .unwrap()
955 .insert(name, metadata);
956
957 Ok(())
958 }
959
960 pub fn register_regression_strategy(
962 &self,
963 strategy: Box<dyn CustomRegressionStrategy>,
964 metadata: StrategyMetadata,
965 ) -> Result<()> {
966 let name = strategy.name().to_string();
967
968 if metadata.task_type != TaskType::Regression {
970 return Err(SklearsError::InvalidInput(
971 "Strategy task type must be Regression".to_string(),
972 ));
973 }
974
975 if self
977 .regression_strategies
978 .read()
979 .unwrap()
980 .contains_key(&name)
981 {
982 return Err(SklearsError::InvalidInput(format!(
983 "Regression strategy '{}' already registered",
984 name
985 )));
986 }
987
988 self.regression_strategies
990 .write()
991 .unwrap()
992 .insert(name.clone(), strategy);
993 self.strategy_metadata
994 .write()
995 .unwrap()
996 .insert(name, metadata);
997
998 Ok(())
999 }
1000
1001 pub fn list_classification_strategies(&self) -> Vec<String> {
1003 self.classification_strategies
1004 .read()
1005 .unwrap()
1006 .keys()
1007 .cloned()
1008 .collect()
1009 }
1010
1011 pub fn list_regression_strategies(&self) -> Vec<String> {
1013 self.regression_strategies
1014 .read()
1015 .unwrap()
1016 .keys()
1017 .cloned()
1018 .collect()
1019 }
1020
1021 pub fn get_strategy_metadata(&self, name: &str) -> Option<StrategyMetadata> {
1023 self.strategy_metadata.read().unwrap().get(name).cloned()
1024 }
1025
1026 pub fn get_classification_strategy(
1028 &self,
1029 name: &str,
1030 ) -> Option<Box<dyn CustomClassificationStrategy>> {
1031 if self
1034 .classification_strategies
1035 .read()
1036 .unwrap()
1037 .contains_key(name)
1038 {
1039 None
1041 } else {
1042 None
1043 }
1044 }
1045
1046 pub fn get_regression_strategy(&self, name: &str) -> Option<Box<dyn CustomRegressionStrategy>> {
1048 if self
1051 .regression_strategies
1052 .read()
1053 .unwrap()
1054 .contains_key(name)
1055 {
1056 None
1058 } else {
1059 None
1060 }
1061 }
1062
1063 pub fn unregister_strategy(&self, name: &str) -> Result<()> {
1065 let mut metadata = self.strategy_metadata.write().unwrap();
1066
1067 if let Some(meta) = metadata.remove(name) {
1068 match meta.task_type {
1069 TaskType::Classification => {
1070 self.classification_strategies.write().unwrap().remove(name);
1071 }
1072 TaskType::Regression => {
1073 self.regression_strategies.write().unwrap().remove(name);
1074 }
1075 _ => {
1076 return Err(SklearsError::InvalidInput(
1077 "Unsupported task type for strategy removal".to_string(),
1078 ))
1079 }
1080 }
1081 Ok(())
1082 } else {
1083 Err(SklearsError::InvalidInput(format!(
1084 "Strategy '{}' not found",
1085 name
1086 )))
1087 }
1088 }
1089}
1090
/// Name-keyed collections of evaluation frameworks and metric computers.
///
/// Unlike the other registries in this file, it uses plain `HashMap`s, so
/// registration requires `&mut self`.
pub struct EvaluationIntegration {
    /// Evaluation frameworks keyed by name.
    evaluators: HashMap<String, Box<dyn EvaluationFramework>>,
    /// Metric computers keyed by name.
    metrics: HashMap<String, Box<dyn MetricComputer>>,
}
1096
/// An evaluation backend that can be registered with [`EvaluationIntegration`].
pub trait EvaluationFramework: Send + Sync + Debug {
    /// Unique name; used as the registry key.
    fn name(&self) -> &str;
    /// Evaluates a type-erased `estimator` on `test_data`.
    ///
    /// Implementors presumably downcast `estimator` to a concrete type — confirm.
    fn evaluate(&self, estimator: &dyn Any, test_data: &TestData) -> Result<EvaluationResult>;
    /// Task types this framework supports.
    fn supported_tasks(&self) -> Vec<TaskType>;
}
1103
/// A metric computed from true and predicted values.
pub trait MetricComputer: Send + Sync + Debug {
    /// Unique name; used as the registry key.
    fn name(&self) -> &str;
    /// Computes the metric comparing `y_true` against `y_pred`.
    fn compute(&self, y_true: &ArrayView1<f64>, y_pred: &ArrayView1<f64>) -> Result<MetricResult>;
    /// Category of this metric.
    fn metric_type(&self) -> MetricType;
}
1110
/// Held-out data passed to an [`EvaluationFramework`].
#[derive(Debug)]
pub struct TestData {
    /// Inputs; presumably shape `(n_samples, n_features)` — confirm.
    pub x: Array2<f64>,
    /// Targets; presumably one value per row of `x`.
    pub y: Array1<f64>,
    /// Free-form string annotations.
    pub metadata: HashMap<String, String>,
}
1121
/// Result returned by an [`EvaluationFramework`].
#[derive(Debug)]
pub struct EvaluationResult {
    /// Headline metric value.
    pub primary_metric: f64,
    /// Additional named metrics.
    pub metrics: HashMap<String, f64>,
    /// Confidence intervals keyed by metric name, as `(lower, upper)` — presumably.
    pub confidence_intervals: HashMap<String, (f64, f64)>,
    /// Time the evaluation took.
    pub execution_time: std::time::Duration,
    /// Non-fatal warnings.
    pub warnings: Vec<String>,
}
1136
/// Value returned by a [`MetricComputer`].
#[derive(Debug)]
pub struct MetricResult {
    /// Metric value.
    pub value: f64,
    /// Optional confidence interval, presumably `(lower, upper)`.
    pub confidence_interval: Option<(f64, f64)>,
    /// Free-form string annotations.
    pub metadata: HashMap<String, String>,
}
1147
/// Category of a metric reported by a `MetricComputer`.
///
/// Derives `PartialEq`/`Eq`/`Hash` so categories can be compared with `==` and
/// used as map or set keys.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum MetricType {
    /// Accuracy-style score.
    Accuracy,
    /// Loss value.
    Loss,
    /// Similarity measure.
    Similarity,
    /// Distance measure.
    Distance,
    /// Any other metric, identified by name.
    Custom(String),
}
1161
1162impl Default for EvaluationIntegration {
1163 fn default() -> Self {
1164 Self::new()
1165 }
1166}
1167
1168impl EvaluationIntegration {
1169 pub fn new() -> Self {
1171 Self {
1172 evaluators: HashMap::new(),
1173 metrics: HashMap::new(),
1174 }
1175 }
1176
1177 pub fn register_evaluator(&mut self, evaluator: Box<dyn EvaluationFramework>) {
1179 let name = evaluator.name().to_string();
1180 self.evaluators.insert(name, evaluator);
1181 }
1182
1183 pub fn register_metric(&mut self, metric: Box<dyn MetricComputer>) {
1185 let name = metric.name().to_string();
1186 self.metrics.insert(name, metric);
1187 }
1188
1189 pub fn evaluate(
1191 &self,
1192 framework_name: &str,
1193 estimator: &dyn Any,
1194 test_data: &TestData,
1195 ) -> Result<EvaluationResult> {
1196 if let Some(evaluator) = self.evaluators.get(framework_name) {
1197 evaluator.evaluate(estimator, test_data)
1198 } else {
1199 Err(SklearsError::InvalidInput(format!(
1200 "Evaluation framework '{}' not found",
1201 framework_name
1202 )))
1203 }
1204 }
1205
1206 pub fn compute_metric(
1208 &self,
1209 metric_name: &str,
1210 y_true: &ArrayView1<f64>,
1211 y_pred: &ArrayView1<f64>,
1212 ) -> Result<MetricResult> {
1213 if let Some(metric) = self.metrics.get(metric_name) {
1214 metric.compute(y_true, y_pred)
1215 } else {
1216 Err(SklearsError::InvalidInput(format!(
1217 "Metric '{}' not found",
1218 metric_name
1219 )))
1220 }
1221 }
1222
1223 pub fn list_evaluators(&self) -> Vec<String> {
1225 self.evaluators.keys().cloned().collect()
1226 }
1227
1228 pub fn list_metrics(&self) -> Vec<String> {
1230 self.metrics.keys().cloned().collect()
1231 }
1232}
1233
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;

    /// Plugin stub that accepts any data and succeeds at every lifecycle call.
    #[derive(Debug)]
    struct MockPlugin {
        name: String,
        version: String,
    }

    impl BaselinePlugin for MockPlugin {
        fn name(&self) -> &str {
            self.name.as_str()
        }

        fn version(&self) -> &str {
            self.version.as_str()
        }

        fn description(&self) -> &str {
            "Mock plugin for testing"
        }

        fn initialize(&mut self, _config: &PluginConfig) -> Result<()> {
            Ok(())
        }

        fn shutdown(&mut self) -> Result<()> {
            Ok(())
        }

        fn is_compatible(&self, _data_info: &DataInfo) -> bool {
            true
        }

        fn metadata(&self) -> PluginMetadata {
            PluginMetadata {
                author: String::from("Test Author"),
                license: String::from("MIT"),
                homepage: String::from("https://example.com"),
                supported_tasks: vec![TaskType::Classification, TaskType::Regression],
                requirements: Vec::new(),
            }
        }
    }

    /// Pre-fit hook stub with priority 1 that does nothing.
    #[derive(Debug)]
    struct MockPreFitHook;

    impl PreFitHook for MockPreFitHook {
        fn execute(&self, _context: &mut FitContext) -> Result<()> {
            Ok(())
        }

        fn priority(&self) -> i32 {
            1
        }
    }

    /// Configuration used when registering the mock plugin.
    fn mock_plugin_config() -> PluginConfig {
        let resources = ResourceConfig {
            max_memory_mb: 1024,
            max_cpu_cores: 4,
            temp_directory: String::from("/tmp"),
            cache_enabled: true,
        };
        let logging = LoggingConfig {
            level: LogLevel::Info,
            output_file: None,
            include_timestamps: true,
        };
        PluginConfig {
            parameters: HashMap::new(),
            resources,
            logging,
        }
    }

    #[test]
    fn test_plugin_registry() {
        let registry = PluginRegistry::new();
        let plugin = Arc::new(MockPlugin {
            name: String::from("test_plugin"),
            version: String::from("1.0.0"),
        });
        let expected = String::from("test_plugin");

        assert!(registry.register_plugin(plugin, mock_plugin_config()).is_ok());
        assert!(registry.list_plugins().contains(&expected));

        assert!(registry.activate_plugin("test_plugin").is_ok());
        assert!(registry.get_active_plugins().contains(&expected));
    }

    #[test]
    fn test_hook_system() {
        let hooks = HookSystem::new();
        hooks.add_pre_fit_hook(Box::new(MockPreFitHook));

        let mut context = FitContext {
            estimator_name: String::from("test"),
            strategy: String::from("mean"),
            x: Array2::zeros((10, 5)),
            y: Array1::zeros(10),
            metadata: HashMap::new(),
            start_time: std::time::Instant::now(),
        };

        assert!(hooks.execute_pre_fit_hooks(&mut context).is_ok());
    }

    #[test]
    fn test_middleware_pipeline() {
        let mut pipeline = MiddlewarePipeline::new();
        pipeline.add_middleware(Box::new(middleware::LoggingMiddleware::new(LogLevel::Info)));

        let context = MiddlewareContext {
            operation: String::from("test_operation"),
            parameters: HashMap::new(),
            data: HashMap::new(),
            metrics: HashMap::new(),
            start_time: std::time::Instant::now(),
        };

        let outcome = pipeline
            .execute(context, |_ctx| {
                Ok(MiddlewareResult {
                    success: true,
                    duration: std::time::Duration::from_millis(100),
                    data: HashMap::new(),
                    metrics: HashMap::new(),
                    warnings: Vec::new(),
                })
            })
            .expect("pipeline execution should succeed");

        assert!(outcome.success);
    }

    #[test]
    fn test_evaluation_integration() {
        let integration = EvaluationIntegration::new();

        assert!(integration.list_evaluators().is_empty());
        assert!(integration.list_metrics().is_empty());
    }
}