sklears_inspection/
plugins.rs

1//! Plugin Architecture for Custom Explanation Methods
2//!
3//! This module provides a comprehensive plugin system for registering and managing
4//! custom explanation methods, allowing users to extend the library with their own
5//! interpretability algorithms.
6
7use crate::{Float, SklResult};
8// ✅ SciRS2 Policy Compliant Import
9use scirs2_core::ndarray::{Array1, Array2};
10use serde::{Deserialize, Serialize};
11use std::collections::HashMap;
12use std::fmt::Debug;
13use std::sync::{Arc, RwLock};
14
15/// Trait for plugin explanation methods
16pub trait ExplanationPlugin: Debug + Send + Sync {
17    /// Plugin identifier
18    fn plugin_id(&self) -> &str;
19
20    /// Plugin name
21    fn plugin_name(&self) -> &str;
22
23    /// Plugin version
24    fn plugin_version(&self) -> &str;
25
26    /// Plugin description
27    fn plugin_description(&self) -> &str;
28
29    /// Plugin author
30    fn plugin_author(&self) -> &str;
31
32    /// Supported input types
33    fn supported_input_types(&self) -> Vec<InputType>;
34
35    /// Supported output types
36    fn supported_output_types(&self) -> Vec<OutputType>;
37
38    /// Plugin capabilities
39    fn capabilities(&self) -> PluginCapabilities;
40
41    /// Initialize the plugin
42    fn initialize(&mut self, config: &PluginConfig) -> SklResult<()>;
43
44    /// Execute the explanation method
45    fn execute(&self, input: &PluginInput) -> SklResult<PluginOutput>;
46
47    /// Validate input before execution
48    fn validate_input(&self, input: &PluginInput) -> SklResult<()>;
49
50    /// Cleanup resources
51    fn cleanup(&mut self) -> SklResult<()>;
52
53    /// Get plugin metadata
54    fn metadata(&self) -> PluginMetadata {
55        PluginMetadata {
56            id: self.plugin_id().to_string(),
57            name: self.plugin_name().to_string(),
58            version: self.plugin_version().to_string(),
59            description: self.plugin_description().to_string(),
60            author: self.plugin_author().to_string(),
61            supported_inputs: self.supported_input_types(),
62            supported_outputs: self.supported_output_types(),
63            capabilities: self.capabilities(),
64            created_at: chrono::Utc::now(),
65        }
66    }
67}
68
69/// Plugin input types
70#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
71pub enum InputType {
72    /// Tabular data (features x samples)
73    Tabular,
74    /// Time series data
75    TimeSeries,
76    /// Image data
77    Image,
78    /// Text data
79    Text,
80    /// Graph data
81    Graph,
82    /// Custom data type
83    Custom(u32),
84}
85
86/// Plugin output types
87#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
88pub enum OutputType {
89    /// Feature importance scores
90    FeatureImportance,
91    /// Local explanations
92    LocalExplanation,
93    /// Global explanations
94    GlobalExplanation,
95    /// Counterfactual explanations
96    CounterfactualExplanation,
97    /// Visualization data
98    VisualizationData,
99    /// Custom output type
100    Custom(u32),
101}
102
103/// Plugin capabilities
104#[derive(Debug, Clone, Serialize, Deserialize)]
105pub struct PluginCapabilities {
106    /// Supports local explanations
107    pub local_explanations: bool,
108    /// Supports global explanations
109    pub global_explanations: bool,
110    /// Supports counterfactual explanations
111    pub counterfactual_explanations: bool,
112    /// Supports uncertainty quantification
113    pub uncertainty_quantification: bool,
114    /// Supports model-agnostic explanations
115    pub model_agnostic: bool,
116    /// Supports parallel processing
117    pub parallel_processing: bool,
118    /// Supports real-time explanations
119    pub real_time: bool,
120    /// Supports streaming data
121    pub streaming: bool,
122    /// Maximum dataset size (number of samples)
123    pub max_dataset_size: Option<usize>,
124    /// Maximum number of features
125    pub max_features: Option<usize>,
126    /// Estimated memory usage in bytes
127    pub estimated_memory_usage: Option<usize>,
128}
129
130impl Default for PluginCapabilities {
131    fn default() -> Self {
132        Self {
133            local_explanations: false,
134            global_explanations: false,
135            counterfactual_explanations: false,
136            uncertainty_quantification: false,
137            model_agnostic: true,
138            parallel_processing: false,
139            real_time: false,
140            streaming: false,
141            max_dataset_size: None,
142            max_features: None,
143            estimated_memory_usage: None,
144        }
145    }
146}
147
148/// Plugin configuration
149#[derive(Debug, Clone, Serialize, Deserialize)]
150pub struct PluginConfig {
151    /// Plugin-specific parameters
152    pub parameters: HashMap<String, PluginParameter>,
153    /// Maximum execution time in seconds
154    pub max_execution_time: Option<u64>,
155    /// Memory limit in bytes
156    pub memory_limit: Option<usize>,
157    /// Number of threads for parallel processing
158    pub num_threads: Option<usize>,
159    /// Random seed for reproducibility
160    pub random_seed: Option<u64>,
161    /// Logging level
162    pub log_level: LogLevel,
163}
164
165impl Default for PluginConfig {
166    fn default() -> Self {
167        Self {
168            parameters: HashMap::new(),
169            max_execution_time: Some(300),          // 5 minutes
170            memory_limit: Some(1024 * 1024 * 1024), // 1GB
171            num_threads: Some(1),
172            random_seed: None,
173            log_level: LogLevel::Info,
174        }
175    }
176}
177
178/// Plugin parameter values
179#[derive(Debug, Clone, Serialize, Deserialize)]
180pub enum PluginParameter {
181    /// Integer
182    Integer(i64),
183    /// Float
184    Float(f64),
185    /// String
186    String(String),
187    /// Boolean
188    Boolean(bool),
189    /// IntegerArray
190    IntegerArray(Vec<i64>),
191    /// FloatArray
192    FloatArray(Vec<f64>),
193    /// StringArray
194    StringArray(Vec<String>),
195    /// BooleanArray
196    BooleanArray(Vec<bool>),
197}
198
199/// Logging levels
200#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
201pub enum LogLevel {
202    /// Error
203    Error,
204    /// Warn
205    Warn,
206    /// Info
207    Info,
208    /// Debug
209    Debug,
210    /// Trace
211    Trace,
212}
213
214/// Plugin input data
215#[derive(Debug, Clone)]
216pub struct PluginInput {
217    /// Input data
218    pub data: PluginData,
219    /// Model predictions (if available)
220    pub predictions: Option<Array1<Float>>,
221    /// Target values (if available)
222    pub targets: Option<Array1<Float>>,
223    /// Feature names
224    pub feature_names: Option<Vec<String>>,
225    /// Sample weights
226    pub sample_weights: Option<Array1<Float>>,
227    /// Additional metadata
228    pub metadata: HashMap<String, String>,
229}
230
231/// Plugin data types
232#[derive(Debug, Clone)]
233pub enum PluginData {
234    /// Tabular data (samples x features)
235    Tabular(Array2<Float>),
236    /// Time series data
237    TimeSeries(Array2<Float>),
238    /// Image data (height x width x channels)
239    Image(Array2<Float>),
240    /// Text data (as string)
241    Text(String),
242    /// Graph data (adjacency matrix)
243    Graph(Array2<Float>),
244    /// Custom data (serialized)
245    Custom(Vec<u8>),
246}
247
248/// Plugin output data
249#[derive(Debug, Clone, Serialize, Deserialize)]
250pub struct PluginOutput {
251    /// Output type
252    pub output_type: OutputType,
253    /// Output data
254    pub data: PluginOutputData,
255    /// Execution metadata
256    pub metadata: ExecutionMetadata,
257    /// Confidence scores (if available)
258    pub confidence: Option<Array1<Float>>,
259    /// Uncertainty estimates (if available)
260    pub uncertainty: Option<Array1<Float>>,
261}
262
263/// Plugin output data types
264#[derive(Debug, Clone, Serialize, Deserialize)]
265pub enum PluginOutputData {
266    /// Feature importance scores
267    FeatureImportance {
268        scores: Vec<Float>,
269
270        feature_names: Vec<String>,
271
272        std_errors: Option<Vec<Float>>,
273    },
274    /// Local explanation
275    LocalExplanation {
276        instance_id: usize,
277
278        feature_contributions: Vec<Float>,
279        feature_names: Vec<String>,
280        base_value: Float,
281    },
282    /// Global explanation
283    GlobalExplanation {
284        feature_effects: Vec<Float>,
285        feature_names: Vec<String>,
286        interaction_effects: Option<Array2<Float>>,
287    },
288    /// Counterfactual explanation
289    CounterfactualExplanation {
290        counterfactual_instance: Array1<Float>,
291        feature_changes: Vec<(usize, Float, Float)>, // (feature_idx, original, new)
292        distance: Float,
293        feasibility_score: Float,
294    },
295    /// Visualization data
296    VisualizationData {
297        plot_type: String,
298        data: serde_json::Value,
299        config: HashMap<String, String>,
300    },
301    /// Custom output
302    Custom(serde_json::Value),
303}
304
305/// Execution metadata
306#[derive(Debug, Clone, Serialize, Deserialize)]
307pub struct ExecutionMetadata {
308    /// Execution time in milliseconds
309    pub execution_time_ms: u64,
310    /// Memory usage in bytes
311    pub memory_usage_bytes: usize,
312    /// Number of iterations (if applicable)
313    pub iterations: Option<usize>,
314    /// Convergence status
315    pub converged: Option<bool>,
316    /// Warning messages
317    pub warnings: Vec<String>,
318    /// Execution timestamp
319    pub timestamp: chrono::DateTime<chrono::Utc>,
320}
321
322/// Plugin metadata
323#[derive(Debug, Clone, Serialize, Deserialize)]
324pub struct PluginMetadata {
325    /// Plugin ID
326    pub id: String,
327    /// Plugin name
328    pub name: String,
329    /// Plugin version
330    pub version: String,
331    /// Plugin description
332    pub description: String,
333    /// Plugin author
334    pub author: String,
335    /// Supported input types
336    pub supported_inputs: Vec<InputType>,
337    /// Supported output types
338    pub supported_outputs: Vec<OutputType>,
339    /// Plugin capabilities
340    pub capabilities: PluginCapabilities,
341    /// Creation timestamp
342    pub created_at: chrono::DateTime<chrono::Utc>,
343}
344
345/// Plugin registry for managing plugins
346#[derive(Debug, Default)]
347pub struct PluginRegistry {
348    plugins: Arc<RwLock<HashMap<String, Arc<dyn ExplanationPlugin>>>>,
349    plugin_configs: Arc<RwLock<HashMap<String, PluginConfig>>>,
350    plugin_metadata: Arc<RwLock<HashMap<String, PluginMetadata>>>,
351}
352
353impl PluginRegistry {
354    /// Create a new plugin registry
355    pub fn new() -> Self {
356        Self::default()
357    }
358
359    /// Register a new plugin
360    pub fn register_plugin<P: ExplanationPlugin + 'static>(
361        &self,
362        mut plugin: P,
363        config: Option<PluginConfig>,
364    ) -> SklResult<()> {
365        let plugin_id = plugin.plugin_id().to_string();
366        let config = config.unwrap_or_default();
367
368        // Initialize the plugin
369        plugin.initialize(&config)?;
370
371        // Get metadata
372        let metadata = plugin.metadata();
373
374        // Store plugin, config, and metadata
375        {
376            let mut plugins = self.plugins.write().map_err(|_| {
377                crate::SklearsError::InvalidInput("Failed to acquire plugins lock".to_string())
378            })?;
379            plugins.insert(plugin_id.clone(), Arc::new(plugin));
380        }
381
382        {
383            let mut configs = self.plugin_configs.write().map_err(|_| {
384                crate::SklearsError::InvalidInput("Failed to acquire configs lock".to_string())
385            })?;
386            configs.insert(plugin_id.clone(), config);
387        }
388
389        {
390            let mut metadata_store = self.plugin_metadata.write().map_err(|_| {
391                crate::SklearsError::InvalidInput("Failed to acquire metadata lock".to_string())
392            })?;
393            metadata_store.insert(plugin_id, metadata);
394        }
395
396        Ok(())
397    }
398
399    /// Get a plugin by ID
400    pub fn get_plugin(&self, plugin_id: &str) -> Option<Arc<dyn ExplanationPlugin>> {
401        self.plugins.read().ok()?.get(plugin_id).cloned()
402    }
403
404    /// Get plugin configuration
405    pub fn get_plugin_config(&self, plugin_id: &str) -> Option<PluginConfig> {
406        self.plugin_configs.read().ok()?.get(plugin_id).cloned()
407    }
408
409    /// Get plugin metadata
410    pub fn get_plugin_metadata(&self, plugin_id: &str) -> Option<PluginMetadata> {
411        self.plugin_metadata.read().ok()?.get(plugin_id).cloned()
412    }
413
414    /// List all registered plugins
415    pub fn list_plugins(&self) -> Vec<String> {
416        self.plugins
417            .read()
418            .ok()
419            .map(|plugins| plugins.keys().cloned().collect())
420            .unwrap_or_default()
421    }
422
423    /// List plugins by input type
424    pub fn list_plugins_by_input_type(&self, input_type: InputType) -> Vec<String> {
425        self.plugin_metadata
426            .read()
427            .ok()
428            .map(|metadata| {
429                metadata
430                    .iter()
431                    .filter(|(_, meta)| meta.supported_inputs.contains(&input_type))
432                    .map(|(id, _)| id.clone())
433                    .collect()
434            })
435            .unwrap_or_default()
436    }
437
438    /// List plugins by output type
439    pub fn list_plugins_by_output_type(&self, output_type: OutputType) -> Vec<String> {
440        self.plugin_metadata
441            .read()
442            .ok()
443            .map(|metadata| {
444                metadata
445                    .iter()
446                    .filter(|(_, meta)| meta.supported_outputs.contains(&output_type))
447                    .map(|(id, _)| id.clone())
448                    .collect()
449            })
450            .unwrap_or_default()
451    }
452
453    /// List plugins by capability
454    pub fn list_plugins_by_capability(&self, capability: PluginCapabilityFilter) -> Vec<String> {
455        self.plugin_metadata
456            .read()
457            .ok()
458            .map(|metadata| {
459                metadata
460                    .iter()
461                    .filter(|(_, meta)| capability.matches(&meta.capabilities))
462                    .map(|(id, _)| id.clone())
463                    .collect()
464            })
465            .unwrap_or_default()
466    }
467
468    /// Execute a plugin
469    pub fn execute_plugin(&self, plugin_id: &str, input: &PluginInput) -> SklResult<PluginOutput> {
470        let plugin = self.get_plugin(plugin_id).ok_or_else(|| {
471            crate::SklearsError::InvalidInput(format!("Plugin '{}' not found", plugin_id))
472        })?;
473
474        // Validate input
475        plugin.validate_input(input)?;
476
477        // Execute plugin
478        let start_time = std::time::Instant::now();
479        let result = plugin.execute(input);
480        let execution_time = start_time.elapsed().as_millis() as u64;
481
482        // Add timing information to result
483        match result {
484            Ok(mut output) => {
485                output.metadata.execution_time_ms = execution_time;
486                Ok(output)
487            }
488            Err(e) => Err(e),
489        }
490    }
491
492    /// Unregister a plugin
493    pub fn unregister_plugin(&self, plugin_id: &str) -> SklResult<()> {
494        {
495            let mut plugins = self.plugins.write().map_err(|_| {
496                crate::SklearsError::InvalidInput("Failed to acquire plugins lock".to_string())
497            })?;
498            plugins.remove(plugin_id);
499        }
500
501        {
502            let mut configs = self.plugin_configs.write().map_err(|_| {
503                crate::SklearsError::InvalidInput("Failed to acquire configs lock".to_string())
504            })?;
505            configs.remove(plugin_id);
506        }
507
508        {
509            let mut metadata_store = self.plugin_metadata.write().map_err(|_| {
510                crate::SklearsError::InvalidInput("Failed to acquire metadata lock".to_string())
511            })?;
512            metadata_store.remove(plugin_id);
513        }
514
515        Ok(())
516    }
517
518    /// Get plugin statistics
519    pub fn get_statistics(&self) -> PluginRegistryStatistics {
520        let plugins = self.plugins.read().ok();
521        let metadata = self.plugin_metadata.read().ok();
522
523        let total_plugins = plugins.as_ref().map(|p| p.len()).unwrap_or(0);
524
525        let plugins_by_type = metadata
526            .as_ref()
527            .map(|meta| {
528                let mut input_types = HashMap::new();
529                let mut output_types = HashMap::new();
530
531                for (_, plugin_meta) in meta.iter() {
532                    for input_type in &plugin_meta.supported_inputs {
533                        *input_types.entry(*input_type).or_insert(0) += 1;
534                    }
535                    for output_type in &plugin_meta.supported_outputs {
536                        *output_types.entry(*output_type).or_insert(0) += 1;
537                    }
538                }
539
540                (input_types, output_types)
541            })
542            .unwrap_or_default();
543
544        PluginRegistryStatistics {
545            total_plugins,
546            plugins_by_input_type: plugins_by_type.0,
547            plugins_by_output_type: plugins_by_type.1,
548            registry_created_at: chrono::Utc::now(),
549        }
550    }
551}
552
/// Criteria for selecting plugins by their advertised capabilities.
///
/// Every field is optional; a field left as `None` places no constraint on
/// the corresponding capability.
#[derive(Debug, Clone)]
pub struct PluginCapabilityFilter {
    /// Required value of the local-explanations flag.
    pub local_explanations: Option<bool>,
    /// Required value of the global-explanations flag.
    pub global_explanations: Option<bool>,
    /// Required value of the counterfactual-explanations flag.
    pub counterfactual_explanations: Option<bool>,
    /// Required value of the uncertainty-quantification flag.
    pub uncertainty_quantification: Option<bool>,
    /// Required value of the model-agnostic flag.
    pub model_agnostic: Option<bool>,
    /// Required value of the parallel-processing flag.
    pub parallel_processing: Option<bool>,
    /// Required value of the real-time flag.
    pub real_time: Option<bool>,
    /// Required value of the streaming flag.
    pub streaming: Option<bool>,
    /// Dataset size (number of samples) the plugin must be able to handle.
    pub max_dataset_size: Option<usize>,
    /// Number of features the plugin must be able to handle.
    pub max_features: Option<usize>,
}
577
578impl PluginCapabilityFilter {
579    /// Create a new capability filter
580    pub fn new() -> Self {
581        Self {
582            local_explanations: None,
583            global_explanations: None,
584            counterfactual_explanations: None,
585            uncertainty_quantification: None,
586            model_agnostic: None,
587            parallel_processing: None,
588            real_time: None,
589            streaming: None,
590            max_dataset_size: None,
591            max_features: None,
592        }
593    }
594
595    /// Check if capabilities match the filter
596    pub fn matches(&self, capabilities: &PluginCapabilities) -> bool {
597        if let Some(required) = self.local_explanations {
598            if capabilities.local_explanations != required {
599                return false;
600            }
601        }
602
603        if let Some(required) = self.global_explanations {
604            if capabilities.global_explanations != required {
605                return false;
606            }
607        }
608
609        if let Some(required) = self.counterfactual_explanations {
610            if capabilities.counterfactual_explanations != required {
611                return false;
612            }
613        }
614
615        if let Some(required) = self.uncertainty_quantification {
616            if capabilities.uncertainty_quantification != required {
617                return false;
618            }
619        }
620
621        if let Some(required) = self.model_agnostic {
622            if capabilities.model_agnostic != required {
623                return false;
624            }
625        }
626
627        if let Some(required) = self.parallel_processing {
628            if capabilities.parallel_processing != required {
629                return false;
630            }
631        }
632
633        if let Some(required) = self.real_time {
634            if capabilities.real_time != required {
635                return false;
636            }
637        }
638
639        if let Some(required) = self.streaming {
640            if capabilities.streaming != required {
641                return false;
642            }
643        }
644
645        if let Some(max_size) = self.max_dataset_size {
646            if let Some(cap_size) = capabilities.max_dataset_size {
647                if cap_size < max_size {
648                    return false;
649                }
650            } else {
651                return false;
652            }
653        }
654
655        if let Some(max_features) = self.max_features {
656            if let Some(cap_features) = capabilities.max_features {
657                if cap_features < max_features {
658                    return false;
659                }
660            } else {
661                return false;
662            }
663        }
664
665        true
666    }
667}
668
669impl Default for PluginCapabilityFilter {
670    fn default() -> Self {
671        Self::new()
672    }
673}
674
675/// Plugin registry statistics
676#[derive(Debug, Clone, Serialize, Deserialize)]
677pub struct PluginRegistryStatistics {
678    /// Total number of plugins
679    pub total_plugins: usize,
680    /// Number of plugins by input type
681    pub plugins_by_input_type: HashMap<InputType, usize>,
682    /// Number of plugins by output type
683    pub plugins_by_output_type: HashMap<OutputType, usize>,
684    /// Registry creation timestamp
685    pub registry_created_at: chrono::DateTime<chrono::Utc>,
686}
687
688/// Plugin manager for orchestrating multiple plugins
689#[derive(Debug)]
690pub struct PluginManager {
691    registry: PluginRegistry,
692    execution_history: Arc<RwLock<Vec<PluginExecution>>>,
693}
694
695impl PluginManager {
696    /// Create a new plugin manager
697    pub fn new() -> Self {
698        Self {
699            registry: PluginRegistry::new(),
700            execution_history: Arc::new(RwLock::new(Vec::new())),
701        }
702    }
703
704    /// Get the plugin registry
705    pub fn registry(&self) -> &PluginRegistry {
706        &self.registry
707    }
708
709    /// Execute a plugin with history tracking
710    pub fn execute_with_history(
711        &self,
712        plugin_id: &str,
713        input: &PluginInput,
714    ) -> SklResult<PluginOutput> {
715        let start_time = std::time::Instant::now();
716        let result = self.registry.execute_plugin(plugin_id, input);
717        let execution_time = start_time.elapsed();
718
719        // Record execution
720        let execution = PluginExecution {
721            plugin_id: plugin_id.to_string(),
722            success: result.is_ok(),
723            execution_time_ms: execution_time.as_millis() as u64,
724            timestamp: chrono::Utc::now(),
725            error_message: result.as_ref().err().map(|e| e.to_string()),
726        };
727
728        if let Ok(mut history) = self.execution_history.write() {
729            history.push(execution);
730        }
731
732        result
733    }
734
735    /// Get execution history
736    pub fn get_execution_history(&self) -> Vec<PluginExecution> {
737        self.execution_history
738            .read()
739            .ok()
740            .map(|history| history.clone())
741            .unwrap_or_default()
742    }
743
744    /// Get execution statistics
745    pub fn get_execution_statistics(&self) -> ExecutionStatistics {
746        let history = self.get_execution_history();
747
748        let total_executions = history.len();
749        let successful_executions = history.iter().filter(|e| e.success).count();
750        let failed_executions = total_executions - successful_executions;
751
752        let average_execution_time = if total_executions > 0 {
753            history.iter().map(|e| e.execution_time_ms).sum::<u64>() / total_executions as u64
754        } else {
755            0
756        };
757
758        let plugin_usage = {
759            let mut usage = HashMap::new();
760            for execution in &history {
761                *usage.entry(execution.plugin_id.clone()).or_insert(0) += 1;
762            }
763            usage
764        };
765
766        ExecutionStatistics {
767            total_executions,
768            successful_executions,
769            failed_executions,
770            average_execution_time_ms: average_execution_time,
771            plugin_usage,
772        }
773    }
774
775    /// Clear execution history
776    pub fn clear_execution_history(&self) {
777        if let Ok(mut history) = self.execution_history.write() {
778            history.clear();
779        }
780    }
781}
782
783impl Default for PluginManager {
784    fn default() -> Self {
785        Self::new()
786    }
787}
788
789/// Plugin execution record
790#[derive(Debug, Clone, Serialize, Deserialize)]
791pub struct PluginExecution {
792    /// Plugin ID
793    pub plugin_id: String,
794    /// Whether execution was successful
795    pub success: bool,
796    /// Execution time in milliseconds
797    pub execution_time_ms: u64,
798    /// Execution timestamp
799    pub timestamp: chrono::DateTime<chrono::Utc>,
800    /// Error message (if failed)
801    pub error_message: Option<String>,
802}
803
804/// Execution statistics
805#[derive(Debug, Clone, Serialize, Deserialize)]
806pub struct ExecutionStatistics {
807    /// Total number of executions
808    pub total_executions: usize,
809    /// Number of successful executions
810    pub successful_executions: usize,
811    /// Number of failed executions
812    pub failed_executions: usize,
813    /// Average execution time in milliseconds
814    pub average_execution_time_ms: u64,
815    /// Plugin usage counts
816    pub plugin_usage: HashMap<String, usize>,
817}
818
/// Example plugin implementation, provided for demonstration.
#[derive(Debug)]
pub struct ExampleCustomPlugin {
    /// Plugin identifier.
    id: String,
    /// Human-readable plugin name.
    name: String,
    /// Plugin version string.
    version: String,
    /// Free-text plugin description.
    description: String,
    /// Plugin author.
    author: String,
    /// Whether the plugin is currently initialized.
    initialized: bool,
}
829
830impl ExampleCustomPlugin {
831    /// Create a new example plugin
832    pub fn new() -> Self {
833        Self {
834            id: "example_custom_plugin".to_string(),
835            name: "Example Custom Plugin".to_string(),
836            version: "1.0.0".to_string(),
837            description: "An example plugin for demonstration purposes".to_string(),
838            author: "Sklears Team".to_string(),
839            initialized: false,
840        }
841    }
842}
843
844impl ExplanationPlugin for ExampleCustomPlugin {
845    fn plugin_id(&self) -> &str {
846        &self.id
847    }
848
849    fn plugin_name(&self) -> &str {
850        &self.name
851    }
852
853    fn plugin_version(&self) -> &str {
854        &self.version
855    }
856
857    fn plugin_description(&self) -> &str {
858        &self.description
859    }
860
861    fn plugin_author(&self) -> &str {
862        &self.author
863    }
864
865    fn supported_input_types(&self) -> Vec<InputType> {
866        vec![InputType::Tabular, InputType::TimeSeries]
867    }
868
869    fn supported_output_types(&self) -> Vec<OutputType> {
870        vec![OutputType::FeatureImportance, OutputType::LocalExplanation]
871    }
872
873    fn capabilities(&self) -> PluginCapabilities {
874        PluginCapabilities {
875            local_explanations: true,
876            global_explanations: true,
877            counterfactual_explanations: false,
878            uncertainty_quantification: false,
879            model_agnostic: true,
880            parallel_processing: false,
881            real_time: true,
882            streaming: false,
883            max_dataset_size: Some(10000),
884            max_features: Some(1000),
885            estimated_memory_usage: Some(1024 * 1024), // 1MB
886        }
887    }
888
889    fn initialize(&mut self, _config: &PluginConfig) -> SklResult<()> {
890        self.initialized = true;
891        Ok(())
892    }
893
894    fn execute(&self, input: &PluginInput) -> SklResult<PluginOutput> {
895        if !self.initialized {
896            return Err(crate::SklearsError::InvalidInput(
897                "Plugin not initialized".to_string(),
898            ));
899        }
900
901        let start_time = std::time::Instant::now();
902
903        // Example implementation: compute simple feature importance
904        let feature_importance = match &input.data {
905            PluginData::Tabular(data) => {
906                let n_features = data.ncols();
907                let importance_scores: Vec<Float> = (0..n_features)
908                    .map(|i| {
909                        let column = data.column(i);
910                        column.var(0.0) // Use variance as importance
911                    })
912                    .collect();
913
914                let feature_names = input
915                    .feature_names
916                    .clone()
917                    .unwrap_or_else(|| (0..n_features).map(|i| format!("feature_{}", i)).collect());
918
919                PluginOutputData::FeatureImportance {
920                    scores: importance_scores,
921                    feature_names,
922                    std_errors: None,
923                }
924            }
925            _ => {
926                return Err(crate::SklearsError::InvalidInput(
927                    "Unsupported input type for this plugin".to_string(),
928                ));
929            }
930        };
931
932        let execution_time = start_time.elapsed().as_millis() as u64;
933
934        Ok(PluginOutput {
935            output_type: OutputType::FeatureImportance,
936            data: feature_importance,
937            metadata: ExecutionMetadata {
938                execution_time_ms: execution_time,
939                memory_usage_bytes: 0,
940                iterations: None,
941                converged: Some(true),
942                warnings: Vec::new(),
943                timestamp: chrono::Utc::now(),
944            },
945            confidence: None,
946            uncertainty: None,
947        })
948    }
949
950    fn validate_input(&self, input: &PluginInput) -> SklResult<()> {
951        match &input.data {
952            PluginData::Tabular(data) => {
953                if data.nrows() == 0 || data.ncols() == 0 {
954                    return Err(crate::SklearsError::InvalidInput(
955                        "Input data cannot be empty".to_string(),
956                    ));
957                }
958
959                if let Some(max_features) = self.capabilities().max_features {
960                    if data.ncols() > max_features {
961                        return Err(crate::SklearsError::InvalidInput(format!(
962                            "Too many features: {} > {}",
963                            data.ncols(),
964                            max_features
965                        )));
966                    }
967                }
968
969                if let Some(max_samples) = self.capabilities().max_dataset_size {
970                    if data.nrows() > max_samples {
971                        return Err(crate::SklearsError::InvalidInput(format!(
972                            "Too many samples: {} > {}",
973                            data.nrows(),
974                            max_samples
975                        )));
976                    }
977                }
978
979                Ok(())
980            }
981            _ => Err(crate::SklearsError::InvalidInput(
982                "Unsupported input type".to_string(),
983            )),
984        }
985    }
986
987    fn cleanup(&mut self) -> SklResult<()> {
988        self.initialized = false;
989        Ok(())
990    }
991}
992
993impl Default for ExampleCustomPlugin {
994    fn default() -> Self {
995        Self::new()
996    }
997}
998
999#[cfg(test)]
1000mod tests {
1001    use super::*;
1002    // ✅ SciRS2 Policy Compliant Import
1003    use scirs2_core::ndarray::Array2;
1004
1005    #[test]
1006    fn test_plugin_registry() {
1007        let registry = PluginRegistry::new();
1008
1009        // Register a plugin
1010        let plugin = ExampleCustomPlugin::new();
1011        let result = registry.register_plugin(plugin, None);
1012        assert!(result.is_ok());
1013
1014        // Check plugin is registered
1015        let plugins = registry.list_plugins();
1016        assert!(plugins.contains(&"example_custom_plugin".to_string()));
1017
1018        // Get plugin metadata
1019        let metadata = registry.get_plugin_metadata("example_custom_plugin");
1020        assert!(metadata.is_some());
1021        let metadata = metadata.unwrap();
1022        assert_eq!(metadata.name, "Example Custom Plugin");
1023        assert_eq!(metadata.version, "1.0.0");
1024    }
1025
1026    #[test]
1027    fn test_plugin_execution() {
1028        let registry = PluginRegistry::new();
1029
1030        // Register plugin
1031        let plugin = ExampleCustomPlugin::new();
1032        registry.register_plugin(plugin, None).unwrap();
1033
1034        // Create input data
1035        let data = Array2::from_shape_vec((10, 3), (0..30).map(|x| x as Float).collect()).unwrap();
1036        let input = PluginInput {
1037            data: PluginData::Tabular(data),
1038            predictions: None,
1039            targets: None,
1040            feature_names: Some(vec!["f1".to_string(), "f2".to_string(), "f3".to_string()]),
1041            sample_weights: None,
1042            metadata: HashMap::new(),
1043        };
1044
1045        // Execute plugin
1046        let result = registry.execute_plugin("example_custom_plugin", &input);
1047        assert!(result.is_ok());
1048
1049        let output = result.unwrap();
1050        assert_eq!(output.output_type, OutputType::FeatureImportance);
1051
1052        match output.data {
1053            PluginOutputData::FeatureImportance {
1054                scores,
1055                feature_names,
1056                ..
1057            } => {
1058                assert_eq!(scores.len(), 3);
1059                assert_eq!(feature_names.len(), 3);
1060                assert_eq!(feature_names[0], "f1");
1061            }
1062            _ => panic!("Expected feature importance output"),
1063        }
1064    }
1065
1066    #[test]
1067    fn test_plugin_capability_filter() {
1068        let capabilities = PluginCapabilities {
1069            local_explanations: true,
1070            global_explanations: true,
1071            counterfactual_explanations: false,
1072            uncertainty_quantification: false,
1073            model_agnostic: true,
1074            parallel_processing: false,
1075            real_time: true,
1076            streaming: false,
1077            max_dataset_size: Some(10000),
1078            max_features: Some(1000),
1079            estimated_memory_usage: Some(1024 * 1024),
1080        };
1081
1082        let mut filter = PluginCapabilityFilter::new();
1083        filter.local_explanations = Some(true);
1084        filter.real_time = Some(true);
1085        filter.max_dataset_size = Some(5000);
1086
1087        assert!(filter.matches(&capabilities));
1088
1089        filter.max_dataset_size = Some(20000);
1090        assert!(!filter.matches(&capabilities));
1091    }
1092
1093    #[test]
1094    fn test_plugin_manager() {
1095        let manager = PluginManager::new();
1096
1097        // Register plugin
1098        let plugin = ExampleCustomPlugin::new();
1099        manager.registry().register_plugin(plugin, None).unwrap();
1100
1101        // Create input data
1102        let data = Array2::from_shape_vec((5, 2), (0..10).map(|x| x as Float).collect()).unwrap();
1103        let input = PluginInput {
1104            data: PluginData::Tabular(data),
1105            predictions: None,
1106            targets: None,
1107            feature_names: None,
1108            sample_weights: None,
1109            metadata: HashMap::new(),
1110        };
1111
1112        // Execute with history
1113        let result = manager.execute_with_history("example_custom_plugin", &input);
1114        assert!(result.is_ok());
1115
1116        // Check history
1117        let history = manager.get_execution_history();
1118        assert_eq!(history.len(), 1);
1119        assert!(history[0].success);
1120
1121        // Get statistics
1122        let stats = manager.get_execution_statistics();
1123        assert_eq!(stats.total_executions, 1);
1124        assert_eq!(stats.successful_executions, 1);
1125        assert_eq!(stats.failed_executions, 0);
1126    }
1127
1128    #[test]
1129    fn test_plugin_list_by_type() {
1130        let registry = PluginRegistry::new();
1131
1132        // Register plugin
1133        let plugin = ExampleCustomPlugin::new();
1134        registry.register_plugin(plugin, None).unwrap();
1135
1136        // List by input type
1137        let tabular_plugins = registry.list_plugins_by_input_type(InputType::Tabular);
1138        assert!(tabular_plugins.contains(&"example_custom_plugin".to_string()));
1139
1140        let image_plugins = registry.list_plugins_by_input_type(InputType::Image);
1141        assert!(image_plugins.is_empty());
1142
1143        // List by output type
1144        let importance_plugins =
1145            registry.list_plugins_by_output_type(OutputType::FeatureImportance);
1146        assert!(importance_plugins.contains(&"example_custom_plugin".to_string()));
1147
1148        let counterfactual_plugins =
1149            registry.list_plugins_by_output_type(OutputType::CounterfactualExplanation);
1150        assert!(counterfactual_plugins.is_empty());
1151    }
1152
1153    #[test]
1154    fn test_plugin_validation() {
1155        let plugin = ExampleCustomPlugin::new();
1156
1157        // Test empty data validation
1158        let empty_data = Array2::from_shape_vec((0, 0), vec![]).unwrap();
1159        let input = PluginInput {
1160            data: PluginData::Tabular(empty_data),
1161            predictions: None,
1162            targets: None,
1163            feature_names: None,
1164            sample_weights: None,
1165            metadata: HashMap::new(),
1166        };
1167
1168        let result = plugin.validate_input(&input);
1169        assert!(result.is_err());
1170
1171        // Test valid data
1172        let valid_data =
1173            Array2::from_shape_vec((5, 2), (0..10).map(|x| x as Float).collect()).unwrap();
1174        let input = PluginInput {
1175            data: PluginData::Tabular(valid_data),
1176            predictions: None,
1177            targets: None,
1178            feature_names: None,
1179            sample_weights: None,
1180            metadata: HashMap::new(),
1181        };
1182
1183        let result = plugin.validate_input(&input);
1184        assert!(result.is_ok());
1185    }
1186
1187    #[test]
1188    fn test_plugin_parameter_types() {
1189        let mut config = PluginConfig::default();
1190
1191        config
1192            .parameters
1193            .insert("integer_param".to_string(), PluginParameter::Integer(42));
1194        config
1195            .parameters
1196            .insert("float_param".to_string(), PluginParameter::Float(3.14));
1197        config.parameters.insert(
1198            "string_param".to_string(),
1199            PluginParameter::String("test".to_string()),
1200        );
1201        config
1202            .parameters
1203            .insert("bool_param".to_string(), PluginParameter::Boolean(true));
1204
1205        assert_eq!(config.parameters.len(), 4);
1206
1207        match config.parameters.get("integer_param") {
1208            Some(PluginParameter::Integer(val)) => assert_eq!(*val, 42),
1209            _ => panic!("Expected integer parameter"),
1210        }
1211
1212        match config.parameters.get("float_param") {
1213            Some(PluginParameter::Float(val)) => assert_eq!(*val, 3.14),
1214            _ => panic!("Expected float parameter"),
1215        }
1216    }
1217}