1use crate::{Float, SklResult};
8use scirs2_core::ndarray::{Array1, Array2};
10use serde::{Deserialize, Serialize};
11use std::collections::HashMap;
12use std::fmt::Debug;
13use std::sync::{Arc, RwLock};
14
15pub trait ExplanationPlugin: Debug + Send + Sync {
17 fn plugin_id(&self) -> &str;
19
20 fn plugin_name(&self) -> &str;
22
23 fn plugin_version(&self) -> &str;
25
26 fn plugin_description(&self) -> &str;
28
29 fn plugin_author(&self) -> &str;
31
32 fn supported_input_types(&self) -> Vec<InputType>;
34
35 fn supported_output_types(&self) -> Vec<OutputType>;
37
38 fn capabilities(&self) -> PluginCapabilities;
40
41 fn initialize(&mut self, config: &PluginConfig) -> SklResult<()>;
43
44 fn execute(&self, input: &PluginInput) -> SklResult<PluginOutput>;
46
47 fn validate_input(&self, input: &PluginInput) -> SklResult<()>;
49
50 fn cleanup(&mut self) -> SklResult<()>;
52
53 fn metadata(&self) -> PluginMetadata {
55 PluginMetadata {
56 id: self.plugin_id().to_string(),
57 name: self.plugin_name().to_string(),
58 version: self.plugin_version().to_string(),
59 description: self.plugin_description().to_string(),
60 author: self.plugin_author().to_string(),
61 supported_inputs: self.supported_input_types(),
62 supported_outputs: self.supported_output_types(),
63 capabilities: self.capabilities(),
64 created_at: chrono::Utc::now(),
65 }
66 }
67}
68
69#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
71pub enum InputType {
72 Tabular,
74 TimeSeries,
76 Image,
78 Text,
80 Graph,
82 Custom(u32),
84}
85
86#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
88pub enum OutputType {
89 FeatureImportance,
91 LocalExplanation,
93 GlobalExplanation,
95 CounterfactualExplanation,
97 VisualizationData,
99 Custom(u32),
101}
102
103#[derive(Debug, Clone, Serialize, Deserialize)]
105pub struct PluginCapabilities {
106 pub local_explanations: bool,
108 pub global_explanations: bool,
110 pub counterfactual_explanations: bool,
112 pub uncertainty_quantification: bool,
114 pub model_agnostic: bool,
116 pub parallel_processing: bool,
118 pub real_time: bool,
120 pub streaming: bool,
122 pub max_dataset_size: Option<usize>,
124 pub max_features: Option<usize>,
126 pub estimated_memory_usage: Option<usize>,
128}
129
130impl Default for PluginCapabilities {
131 fn default() -> Self {
132 Self {
133 local_explanations: false,
134 global_explanations: false,
135 counterfactual_explanations: false,
136 uncertainty_quantification: false,
137 model_agnostic: true,
138 parallel_processing: false,
139 real_time: false,
140 streaming: false,
141 max_dataset_size: None,
142 max_features: None,
143 estimated_memory_usage: None,
144 }
145 }
146}
147
148#[derive(Debug, Clone, Serialize, Deserialize)]
150pub struct PluginConfig {
151 pub parameters: HashMap<String, PluginParameter>,
153 pub max_execution_time: Option<u64>,
155 pub memory_limit: Option<usize>,
157 pub num_threads: Option<usize>,
159 pub random_seed: Option<u64>,
161 pub log_level: LogLevel,
163}
164
165impl Default for PluginConfig {
166 fn default() -> Self {
167 Self {
168 parameters: HashMap::new(),
169 max_execution_time: Some(300), memory_limit: Some(1024 * 1024 * 1024), num_threads: Some(1),
172 random_seed: None,
173 log_level: LogLevel::Info,
174 }
175 }
176}
177
178#[derive(Debug, Clone, Serialize, Deserialize)]
180pub enum PluginParameter {
181 Integer(i64),
183 Float(f64),
185 String(String),
187 Boolean(bool),
189 IntegerArray(Vec<i64>),
191 FloatArray(Vec<f64>),
193 StringArray(Vec<String>),
195 BooleanArray(Vec<bool>),
197}
198
199#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
201pub enum LogLevel {
202 Error,
204 Warn,
206 Info,
208 Debug,
210 Trace,
212}
213
214#[derive(Debug, Clone)]
216pub struct PluginInput {
217 pub data: PluginData,
219 pub predictions: Option<Array1<Float>>,
221 pub targets: Option<Array1<Float>>,
223 pub feature_names: Option<Vec<String>>,
225 pub sample_weights: Option<Array1<Float>>,
227 pub metadata: HashMap<String, String>,
229}
230
231#[derive(Debug, Clone)]
233pub enum PluginData {
234 Tabular(Array2<Float>),
236 TimeSeries(Array2<Float>),
238 Image(Array2<Float>),
240 Text(String),
242 Graph(Array2<Float>),
244 Custom(Vec<u8>),
246}
247
248#[derive(Debug, Clone, Serialize, Deserialize)]
250pub struct PluginOutput {
251 pub output_type: OutputType,
253 pub data: PluginOutputData,
255 pub metadata: ExecutionMetadata,
257 pub confidence: Option<Array1<Float>>,
259 pub uncertainty: Option<Array1<Float>>,
261}
262
263#[derive(Debug, Clone, Serialize, Deserialize)]
265pub enum PluginOutputData {
266 FeatureImportance {
268 scores: Vec<Float>,
269
270 feature_names: Vec<String>,
271
272 std_errors: Option<Vec<Float>>,
273 },
274 LocalExplanation {
276 instance_id: usize,
277
278 feature_contributions: Vec<Float>,
279 feature_names: Vec<String>,
280 base_value: Float,
281 },
282 GlobalExplanation {
284 feature_effects: Vec<Float>,
285 feature_names: Vec<String>,
286 interaction_effects: Option<Array2<Float>>,
287 },
288 CounterfactualExplanation {
290 counterfactual_instance: Array1<Float>,
291 feature_changes: Vec<(usize, Float, Float)>, distance: Float,
293 feasibility_score: Float,
294 },
295 VisualizationData {
297 plot_type: String,
298 data: serde_json::Value,
299 config: HashMap<String, String>,
300 },
301 Custom(serde_json::Value),
303}
304
305#[derive(Debug, Clone, Serialize, Deserialize)]
307pub struct ExecutionMetadata {
308 pub execution_time_ms: u64,
310 pub memory_usage_bytes: usize,
312 pub iterations: Option<usize>,
314 pub converged: Option<bool>,
316 pub warnings: Vec<String>,
318 pub timestamp: chrono::DateTime<chrono::Utc>,
320}
321
322#[derive(Debug, Clone, Serialize, Deserialize)]
324pub struct PluginMetadata {
325 pub id: String,
327 pub name: String,
329 pub version: String,
331 pub description: String,
333 pub author: String,
335 pub supported_inputs: Vec<InputType>,
337 pub supported_outputs: Vec<OutputType>,
339 pub capabilities: PluginCapabilities,
341 pub created_at: chrono::DateTime<chrono::Utc>,
343}
344
345#[derive(Debug, Default)]
347pub struct PluginRegistry {
348 plugins: Arc<RwLock<HashMap<String, Arc<dyn ExplanationPlugin>>>>,
349 plugin_configs: Arc<RwLock<HashMap<String, PluginConfig>>>,
350 plugin_metadata: Arc<RwLock<HashMap<String, PluginMetadata>>>,
351}
352
353impl PluginRegistry {
354 pub fn new() -> Self {
356 Self::default()
357 }
358
359 pub fn register_plugin<P: ExplanationPlugin + 'static>(
361 &self,
362 mut plugin: P,
363 config: Option<PluginConfig>,
364 ) -> SklResult<()> {
365 let plugin_id = plugin.plugin_id().to_string();
366 let config = config.unwrap_or_default();
367
368 plugin.initialize(&config)?;
370
371 let metadata = plugin.metadata();
373
374 {
376 let mut plugins = self.plugins.write().map_err(|_| {
377 crate::SklearsError::InvalidInput("Failed to acquire plugins lock".to_string())
378 })?;
379 plugins.insert(plugin_id.clone(), Arc::new(plugin));
380 }
381
382 {
383 let mut configs = self.plugin_configs.write().map_err(|_| {
384 crate::SklearsError::InvalidInput("Failed to acquire configs lock".to_string())
385 })?;
386 configs.insert(plugin_id.clone(), config);
387 }
388
389 {
390 let mut metadata_store = self.plugin_metadata.write().map_err(|_| {
391 crate::SklearsError::InvalidInput("Failed to acquire metadata lock".to_string())
392 })?;
393 metadata_store.insert(plugin_id, metadata);
394 }
395
396 Ok(())
397 }
398
399 pub fn get_plugin(&self, plugin_id: &str) -> Option<Arc<dyn ExplanationPlugin>> {
401 self.plugins.read().ok()?.get(plugin_id).cloned()
402 }
403
404 pub fn get_plugin_config(&self, plugin_id: &str) -> Option<PluginConfig> {
406 self.plugin_configs.read().ok()?.get(plugin_id).cloned()
407 }
408
409 pub fn get_plugin_metadata(&self, plugin_id: &str) -> Option<PluginMetadata> {
411 self.plugin_metadata.read().ok()?.get(plugin_id).cloned()
412 }
413
414 pub fn list_plugins(&self) -> Vec<String> {
416 self.plugins
417 .read()
418 .ok()
419 .map(|plugins| plugins.keys().cloned().collect())
420 .unwrap_or_default()
421 }
422
423 pub fn list_plugins_by_input_type(&self, input_type: InputType) -> Vec<String> {
425 self.plugin_metadata
426 .read()
427 .ok()
428 .map(|metadata| {
429 metadata
430 .iter()
431 .filter(|(_, meta)| meta.supported_inputs.contains(&input_type))
432 .map(|(id, _)| id.clone())
433 .collect()
434 })
435 .unwrap_or_default()
436 }
437
438 pub fn list_plugins_by_output_type(&self, output_type: OutputType) -> Vec<String> {
440 self.plugin_metadata
441 .read()
442 .ok()
443 .map(|metadata| {
444 metadata
445 .iter()
446 .filter(|(_, meta)| meta.supported_outputs.contains(&output_type))
447 .map(|(id, _)| id.clone())
448 .collect()
449 })
450 .unwrap_or_default()
451 }
452
453 pub fn list_plugins_by_capability(&self, capability: PluginCapabilityFilter) -> Vec<String> {
455 self.plugin_metadata
456 .read()
457 .ok()
458 .map(|metadata| {
459 metadata
460 .iter()
461 .filter(|(_, meta)| capability.matches(&meta.capabilities))
462 .map(|(id, _)| id.clone())
463 .collect()
464 })
465 .unwrap_or_default()
466 }
467
468 pub fn execute_plugin(&self, plugin_id: &str, input: &PluginInput) -> SklResult<PluginOutput> {
470 let plugin = self.get_plugin(plugin_id).ok_or_else(|| {
471 crate::SklearsError::InvalidInput(format!("Plugin '{}' not found", plugin_id))
472 })?;
473
474 plugin.validate_input(input)?;
476
477 let start_time = std::time::Instant::now();
479 let result = plugin.execute(input);
480 let execution_time = start_time.elapsed().as_millis() as u64;
481
482 match result {
484 Ok(mut output) => {
485 output.metadata.execution_time_ms = execution_time;
486 Ok(output)
487 }
488 Err(e) => Err(e),
489 }
490 }
491
492 pub fn unregister_plugin(&self, plugin_id: &str) -> SklResult<()> {
494 {
495 let mut plugins = self.plugins.write().map_err(|_| {
496 crate::SklearsError::InvalidInput("Failed to acquire plugins lock".to_string())
497 })?;
498 plugins.remove(plugin_id);
499 }
500
501 {
502 let mut configs = self.plugin_configs.write().map_err(|_| {
503 crate::SklearsError::InvalidInput("Failed to acquire configs lock".to_string())
504 })?;
505 configs.remove(plugin_id);
506 }
507
508 {
509 let mut metadata_store = self.plugin_metadata.write().map_err(|_| {
510 crate::SklearsError::InvalidInput("Failed to acquire metadata lock".to_string())
511 })?;
512 metadata_store.remove(plugin_id);
513 }
514
515 Ok(())
516 }
517
518 pub fn get_statistics(&self) -> PluginRegistryStatistics {
520 let plugins = self.plugins.read().ok();
521 let metadata = self.plugin_metadata.read().ok();
522
523 let total_plugins = plugins.as_ref().map(|p| p.len()).unwrap_or(0);
524
525 let plugins_by_type = metadata
526 .as_ref()
527 .map(|meta| {
528 let mut input_types = HashMap::new();
529 let mut output_types = HashMap::new();
530
531 for (_, plugin_meta) in meta.iter() {
532 for input_type in &plugin_meta.supported_inputs {
533 *input_types.entry(*input_type).or_insert(0) += 1;
534 }
535 for output_type in &plugin_meta.supported_outputs {
536 *output_types.entry(*output_type).or_insert(0) += 1;
537 }
538 }
539
540 (input_types, output_types)
541 })
542 .unwrap_or_default();
543
544 PluginRegistryStatistics {
545 total_plugins,
546 plugins_by_input_type: plugins_by_type.0,
547 plugins_by_output_type: plugins_by_type.1,
548 registry_created_at: chrono::Utc::now(),
549 }
550 }
551}
552
553#[derive(Debug, Clone)]
555pub struct PluginCapabilityFilter {
556 pub local_explanations: Option<bool>,
558 pub global_explanations: Option<bool>,
560 pub counterfactual_explanations: Option<bool>,
562 pub uncertainty_quantification: Option<bool>,
564 pub model_agnostic: Option<bool>,
566 pub parallel_processing: Option<bool>,
568 pub real_time: Option<bool>,
570 pub streaming: Option<bool>,
572 pub max_dataset_size: Option<usize>,
574 pub max_features: Option<usize>,
576}
577
578impl PluginCapabilityFilter {
579 pub fn new() -> Self {
581 Self {
582 local_explanations: None,
583 global_explanations: None,
584 counterfactual_explanations: None,
585 uncertainty_quantification: None,
586 model_agnostic: None,
587 parallel_processing: None,
588 real_time: None,
589 streaming: None,
590 max_dataset_size: None,
591 max_features: None,
592 }
593 }
594
595 pub fn matches(&self, capabilities: &PluginCapabilities) -> bool {
597 if let Some(required) = self.local_explanations {
598 if capabilities.local_explanations != required {
599 return false;
600 }
601 }
602
603 if let Some(required) = self.global_explanations {
604 if capabilities.global_explanations != required {
605 return false;
606 }
607 }
608
609 if let Some(required) = self.counterfactual_explanations {
610 if capabilities.counterfactual_explanations != required {
611 return false;
612 }
613 }
614
615 if let Some(required) = self.uncertainty_quantification {
616 if capabilities.uncertainty_quantification != required {
617 return false;
618 }
619 }
620
621 if let Some(required) = self.model_agnostic {
622 if capabilities.model_agnostic != required {
623 return false;
624 }
625 }
626
627 if let Some(required) = self.parallel_processing {
628 if capabilities.parallel_processing != required {
629 return false;
630 }
631 }
632
633 if let Some(required) = self.real_time {
634 if capabilities.real_time != required {
635 return false;
636 }
637 }
638
639 if let Some(required) = self.streaming {
640 if capabilities.streaming != required {
641 return false;
642 }
643 }
644
645 if let Some(max_size) = self.max_dataset_size {
646 if let Some(cap_size) = capabilities.max_dataset_size {
647 if cap_size < max_size {
648 return false;
649 }
650 } else {
651 return false;
652 }
653 }
654
655 if let Some(max_features) = self.max_features {
656 if let Some(cap_features) = capabilities.max_features {
657 if cap_features < max_features {
658 return false;
659 }
660 } else {
661 return false;
662 }
663 }
664
665 true
666 }
667}
668
669impl Default for PluginCapabilityFilter {
670 fn default() -> Self {
671 Self::new()
672 }
673}
674
675#[derive(Debug, Clone, Serialize, Deserialize)]
677pub struct PluginRegistryStatistics {
678 pub total_plugins: usize,
680 pub plugins_by_input_type: HashMap<InputType, usize>,
682 pub plugins_by_output_type: HashMap<OutputType, usize>,
684 pub registry_created_at: chrono::DateTime<chrono::Utc>,
686}
687
688#[derive(Debug)]
690pub struct PluginManager {
691 registry: PluginRegistry,
692 execution_history: Arc<RwLock<Vec<PluginExecution>>>,
693}
694
695impl PluginManager {
696 pub fn new() -> Self {
698 Self {
699 registry: PluginRegistry::new(),
700 execution_history: Arc::new(RwLock::new(Vec::new())),
701 }
702 }
703
704 pub fn registry(&self) -> &PluginRegistry {
706 &self.registry
707 }
708
709 pub fn execute_with_history(
711 &self,
712 plugin_id: &str,
713 input: &PluginInput,
714 ) -> SklResult<PluginOutput> {
715 let start_time = std::time::Instant::now();
716 let result = self.registry.execute_plugin(plugin_id, input);
717 let execution_time = start_time.elapsed();
718
719 let execution = PluginExecution {
721 plugin_id: plugin_id.to_string(),
722 success: result.is_ok(),
723 execution_time_ms: execution_time.as_millis() as u64,
724 timestamp: chrono::Utc::now(),
725 error_message: result.as_ref().err().map(|e| e.to_string()),
726 };
727
728 if let Ok(mut history) = self.execution_history.write() {
729 history.push(execution);
730 }
731
732 result
733 }
734
735 pub fn get_execution_history(&self) -> Vec<PluginExecution> {
737 self.execution_history
738 .read()
739 .ok()
740 .map(|history| history.clone())
741 .unwrap_or_default()
742 }
743
744 pub fn get_execution_statistics(&self) -> ExecutionStatistics {
746 let history = self.get_execution_history();
747
748 let total_executions = history.len();
749 let successful_executions = history.iter().filter(|e| e.success).count();
750 let failed_executions = total_executions - successful_executions;
751
752 let average_execution_time = if total_executions > 0 {
753 history.iter().map(|e| e.execution_time_ms).sum::<u64>() / total_executions as u64
754 } else {
755 0
756 };
757
758 let plugin_usage = {
759 let mut usage = HashMap::new();
760 for execution in &history {
761 *usage.entry(execution.plugin_id.clone()).or_insert(0) += 1;
762 }
763 usage
764 };
765
766 ExecutionStatistics {
767 total_executions,
768 successful_executions,
769 failed_executions,
770 average_execution_time_ms: average_execution_time,
771 plugin_usage,
772 }
773 }
774
775 pub fn clear_execution_history(&self) {
777 if let Ok(mut history) = self.execution_history.write() {
778 history.clear();
779 }
780 }
781}
782
783impl Default for PluginManager {
784 fn default() -> Self {
785 Self::new()
786 }
787}
788
789#[derive(Debug, Clone, Serialize, Deserialize)]
791pub struct PluginExecution {
792 pub plugin_id: String,
794 pub success: bool,
796 pub execution_time_ms: u64,
798 pub timestamp: chrono::DateTime<chrono::Utc>,
800 pub error_message: Option<String>,
802}
803
804#[derive(Debug, Clone, Serialize, Deserialize)]
806pub struct ExecutionStatistics {
807 pub total_executions: usize,
809 pub successful_executions: usize,
811 pub failed_executions: usize,
813 pub average_execution_time_ms: u64,
815 pub plugin_usage: HashMap<String, usize>,
817}
818
819#[derive(Debug)]
821pub struct ExampleCustomPlugin {
822 id: String,
823 name: String,
824 version: String,
825 description: String,
826 author: String,
827 initialized: bool,
828}
829
830impl ExampleCustomPlugin {
831 pub fn new() -> Self {
833 Self {
834 id: "example_custom_plugin".to_string(),
835 name: "Example Custom Plugin".to_string(),
836 version: "1.0.0".to_string(),
837 description: "An example plugin for demonstration purposes".to_string(),
838 author: "Sklears Team".to_string(),
839 initialized: false,
840 }
841 }
842}
843
844impl ExplanationPlugin for ExampleCustomPlugin {
845 fn plugin_id(&self) -> &str {
846 &self.id
847 }
848
849 fn plugin_name(&self) -> &str {
850 &self.name
851 }
852
853 fn plugin_version(&self) -> &str {
854 &self.version
855 }
856
857 fn plugin_description(&self) -> &str {
858 &self.description
859 }
860
861 fn plugin_author(&self) -> &str {
862 &self.author
863 }
864
865 fn supported_input_types(&self) -> Vec<InputType> {
866 vec![InputType::Tabular, InputType::TimeSeries]
867 }
868
869 fn supported_output_types(&self) -> Vec<OutputType> {
870 vec![OutputType::FeatureImportance, OutputType::LocalExplanation]
871 }
872
873 fn capabilities(&self) -> PluginCapabilities {
874 PluginCapabilities {
875 local_explanations: true,
876 global_explanations: true,
877 counterfactual_explanations: false,
878 uncertainty_quantification: false,
879 model_agnostic: true,
880 parallel_processing: false,
881 real_time: true,
882 streaming: false,
883 max_dataset_size: Some(10000),
884 max_features: Some(1000),
885 estimated_memory_usage: Some(1024 * 1024), }
887 }
888
889 fn initialize(&mut self, _config: &PluginConfig) -> SklResult<()> {
890 self.initialized = true;
891 Ok(())
892 }
893
894 fn execute(&self, input: &PluginInput) -> SklResult<PluginOutput> {
895 if !self.initialized {
896 return Err(crate::SklearsError::InvalidInput(
897 "Plugin not initialized".to_string(),
898 ));
899 }
900
901 let start_time = std::time::Instant::now();
902
903 let feature_importance = match &input.data {
905 PluginData::Tabular(data) => {
906 let n_features = data.ncols();
907 let importance_scores: Vec<Float> = (0..n_features)
908 .map(|i| {
909 let column = data.column(i);
910 column.var(0.0) })
912 .collect();
913
914 let feature_names = input
915 .feature_names
916 .clone()
917 .unwrap_or_else(|| (0..n_features).map(|i| format!("feature_{}", i)).collect());
918
919 PluginOutputData::FeatureImportance {
920 scores: importance_scores,
921 feature_names,
922 std_errors: None,
923 }
924 }
925 _ => {
926 return Err(crate::SklearsError::InvalidInput(
927 "Unsupported input type for this plugin".to_string(),
928 ));
929 }
930 };
931
932 let execution_time = start_time.elapsed().as_millis() as u64;
933
934 Ok(PluginOutput {
935 output_type: OutputType::FeatureImportance,
936 data: feature_importance,
937 metadata: ExecutionMetadata {
938 execution_time_ms: execution_time,
939 memory_usage_bytes: 0,
940 iterations: None,
941 converged: Some(true),
942 warnings: Vec::new(),
943 timestamp: chrono::Utc::now(),
944 },
945 confidence: None,
946 uncertainty: None,
947 })
948 }
949
950 fn validate_input(&self, input: &PluginInput) -> SklResult<()> {
951 match &input.data {
952 PluginData::Tabular(data) => {
953 if data.nrows() == 0 || data.ncols() == 0 {
954 return Err(crate::SklearsError::InvalidInput(
955 "Input data cannot be empty".to_string(),
956 ));
957 }
958
959 if let Some(max_features) = self.capabilities().max_features {
960 if data.ncols() > max_features {
961 return Err(crate::SklearsError::InvalidInput(format!(
962 "Too many features: {} > {}",
963 data.ncols(),
964 max_features
965 )));
966 }
967 }
968
969 if let Some(max_samples) = self.capabilities().max_dataset_size {
970 if data.nrows() > max_samples {
971 return Err(crate::SklearsError::InvalidInput(format!(
972 "Too many samples: {} > {}",
973 data.nrows(),
974 max_samples
975 )));
976 }
977 }
978
979 Ok(())
980 }
981 _ => Err(crate::SklearsError::InvalidInput(
982 "Unsupported input type".to_string(),
983 )),
984 }
985 }
986
987 fn cleanup(&mut self) -> SklResult<()> {
988 self.initialized = false;
989 Ok(())
990 }
991}
992
993impl Default for ExampleCustomPlugin {
994 fn default() -> Self {
995 Self::new()
996 }
997}
998
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array2;

    /// Tabular input of shape `(rows, cols)` filled with `0, 1, 2, …` in row-major order.
    fn tabular_input(rows: usize, cols: usize, feature_names: Option<Vec<String>>) -> PluginInput {
        let values: Vec<Float> = (0..rows * cols).map(|x| x as Float).collect();
        PluginInput {
            data: PluginData::Tabular(Array2::from_shape_vec((rows, cols), values).unwrap()),
            predictions: None,
            targets: None,
            feature_names,
            sample_weights: None,
            metadata: HashMap::new(),
        }
    }

    #[test]
    fn test_plugin_registry() {
        let registry = PluginRegistry::new();

        let result = registry.register_plugin(ExampleCustomPlugin::new(), None);
        assert!(result.is_ok());

        let plugins = registry.list_plugins();
        assert!(plugins.contains(&"example_custom_plugin".to_string()));

        let metadata = registry
            .get_plugin_metadata("example_custom_plugin")
            .expect("metadata is stored on registration");
        assert_eq!(metadata.name, "Example Custom Plugin");
        assert_eq!(metadata.version, "1.0.0");
    }

    #[test]
    fn test_plugin_execution() {
        let registry = PluginRegistry::new();
        registry
            .register_plugin(ExampleCustomPlugin::new(), None)
            .unwrap();

        let names = vec!["f1".to_string(), "f2".to_string(), "f3".to_string()];
        let input = tabular_input(10, 3, Some(names));

        let result = registry.execute_plugin("example_custom_plugin", &input);
        assert!(result.is_ok());

        let output = result.unwrap();
        assert_eq!(output.output_type, OutputType::FeatureImportance);

        match output.data {
            PluginOutputData::FeatureImportance {
                scores,
                feature_names,
                ..
            } => {
                assert_eq!(scores.len(), 3);
                assert_eq!(feature_names.len(), 3);
                assert_eq!(feature_names[0], "f1");
            }
            _ => panic!("Expected feature importance output"),
        }
    }

    #[test]
    fn test_execute_unknown_plugin_fails() {
        let registry = PluginRegistry::new();
        let input = tabular_input(2, 2, None);
        assert!(registry.execute_plugin("missing", &input).is_err());
    }

    #[test]
    fn test_unregister_plugin() {
        let registry = PluginRegistry::new();
        registry
            .register_plugin(ExampleCustomPlugin::new(), None)
            .unwrap();

        assert!(registry.unregister_plugin("example_custom_plugin").is_ok());
        assert!(registry.list_plugins().is_empty());
        assert!(registry.get_plugin("example_custom_plugin").is_none());
        assert!(registry.get_plugin_config("example_custom_plugin").is_none());
        assert!(registry.get_plugin_metadata("example_custom_plugin").is_none());

        // Removing an unknown id is a no-op.
        assert!(registry.unregister_plugin("example_custom_plugin").is_ok());
    }

    #[test]
    fn test_plugin_capability_filter() {
        let capabilities = PluginCapabilities {
            local_explanations: true,
            global_explanations: true,
            counterfactual_explanations: false,
            uncertainty_quantification: false,
            model_agnostic: true,
            parallel_processing: false,
            real_time: true,
            streaming: false,
            max_dataset_size: Some(10000),
            max_features: Some(1000),
            estimated_memory_usage: Some(1024 * 1024),
        };

        let mut filter = PluginCapabilityFilter::new();
        filter.local_explanations = Some(true);
        filter.real_time = Some(true);
        filter.max_dataset_size = Some(5000);

        assert!(filter.matches(&capabilities));

        filter.max_dataset_size = Some(20000);
        assert!(!filter.matches(&capabilities));
    }

    #[test]
    fn test_plugin_manager() {
        let manager = PluginManager::new();
        manager
            .registry()
            .register_plugin(ExampleCustomPlugin::new(), None)
            .unwrap();

        let input = tabular_input(5, 2, None);
        let result = manager.execute_with_history("example_custom_plugin", &input);
        assert!(result.is_ok());

        let history = manager.get_execution_history();
        assert_eq!(history.len(), 1);
        assert!(history[0].success);

        let stats = manager.get_execution_statistics();
        assert_eq!(stats.total_executions, 1);
        assert_eq!(stats.successful_executions, 1);
        assert_eq!(stats.failed_executions, 0);
    }

    #[test]
    fn test_failed_execution_is_recorded() {
        let manager = PluginManager::new();
        let input = tabular_input(2, 2, None);

        assert!(manager.execute_with_history("missing", &input).is_err());

        let history = manager.get_execution_history();
        assert_eq!(history.len(), 1);
        assert!(!history[0].success);
        assert!(history[0].error_message.is_some());

        let stats = manager.get_execution_statistics();
        assert_eq!(stats.failed_executions, 1);
        assert_eq!(stats.plugin_usage.get("missing"), Some(&1));

        manager.clear_execution_history();
        assert!(manager.get_execution_history().is_empty());
    }

    #[test]
    fn test_plugin_list_by_type() {
        let registry = PluginRegistry::new();
        registry
            .register_plugin(ExampleCustomPlugin::new(), None)
            .unwrap();

        let tabular_plugins = registry.list_plugins_by_input_type(InputType::Tabular);
        assert!(tabular_plugins.contains(&"example_custom_plugin".to_string()));

        let image_plugins = registry.list_plugins_by_input_type(InputType::Image);
        assert!(image_plugins.is_empty());

        let importance_plugins =
            registry.list_plugins_by_output_type(OutputType::FeatureImportance);
        assert!(importance_plugins.contains(&"example_custom_plugin".to_string()));

        let counterfactual_plugins =
            registry.list_plugins_by_output_type(OutputType::CounterfactualExplanation);
        assert!(counterfactual_plugins.is_empty());
    }

    #[test]
    fn test_plugin_validation() {
        let plugin = ExampleCustomPlugin::new();

        let empty_input = tabular_input(0, 0, None);
        assert!(plugin.validate_input(&empty_input).is_err());

        let valid_input = tabular_input(5, 2, None);
        assert!(plugin.validate_input(&valid_input).is_ok());
    }

    #[test]
    fn test_plugin_parameter_types() {
        let mut config = PluginConfig::default();

        config
            .parameters
            .insert("integer_param".to_string(), PluginParameter::Integer(42));
        // 2.5 rather than 3.14: clippy's deny-by-default `approx_constant` rejects 3.14.
        config
            .parameters
            .insert("float_param".to_string(), PluginParameter::Float(2.5));
        config.parameters.insert(
            "string_param".to_string(),
            PluginParameter::String("test".to_string()),
        );
        config
            .parameters
            .insert("bool_param".to_string(), PluginParameter::Boolean(true));

        assert_eq!(config.parameters.len(), 4);

        match config.parameters.get("integer_param") {
            Some(PluginParameter::Integer(val)) => assert_eq!(*val, 42),
            _ => panic!("Expected integer parameter"),
        }

        match config.parameters.get("float_param") {
            Some(PluginParameter::Float(val)) => assert_eq!(*val, 2.5),
            _ => panic!("Expected float parameter"),
        }
    }
}