sklears_inspection/
builder.rs

1//! Builder pattern and fluent API for complex explanations
2//!
3//! This module provides builder patterns for constructing complex explanation pipelines
4//! with a fluent API that makes it easy to chain operations and configure explanations.
5
6#[cfg(feature = "parallel")]
7use crate::ParallelConfig;
8use crate::{Float, SklResult};
9
/// Stand-in for `crate::ParallelConfig` when the `parallel` feature is off.
///
/// It carries no state. Its methods mirror the calls this module makes on the
/// real configuration (`with_threads`, `sequential`) so the builders compile
/// the same way with or without the feature. The type is `pub` because it
/// appears in public fields and method signatures of this module.
#[cfg(not(feature = "parallel"))]
#[derive(Debug, Clone, Default)]
pub struct ParallelConfig;

#[cfg(not(feature = "parallel"))]
impl ParallelConfig {
    /// Accepts a thread count and ignores it: there is nothing to configure
    /// on this stateless stand-in.
    pub fn with_threads(self, _n_threads: usize) -> Self {
        self
    }

    /// No-op on this stateless stand-in; returns the value unchanged.
    pub fn sequential(self) -> Self {
        self
    }
}
13// ✅ SciRS2 Policy Compliant Import
14use scirs2_core::ndarray::Array1;
15use scirs2_core::random::Rng;
16use std::marker::PhantomData;
17
18/// Fluent builder for explanation configurations
19#[derive(Debug, Clone)]
20pub struct ExplanationBuilder<T> {
21    target_type: PhantomData<T>,
22    n_samples: Option<usize>,
23    n_features: Option<usize>,
24    random_state: Option<u64>,
25    parallel_config: ParallelConfig,
26    validation_enabled: bool,
27    preprocessing_enabled: bool,
28    postprocessing_enabled: bool,
29}
30
31impl<T> Default for ExplanationBuilder<T> {
32    fn default() -> Self {
33        Self {
34            target_type: PhantomData,
35            n_samples: None,
36            n_features: None,
37            random_state: None,
38            parallel_config: ParallelConfig::default(),
39            validation_enabled: true,
40            preprocessing_enabled: false,
41            postprocessing_enabled: false,
42        }
43    }
44}
45
46impl<T> ExplanationBuilder<T> {
47    /// Create a new explanation builder
48    pub fn new() -> Self {
49        Self::default()
50    }
51
52    /// Set the number of samples to use
53    pub fn with_n_samples(mut self, n_samples: usize) -> Self {
54        self.n_samples = Some(n_samples);
55        self
56    }
57
58    /// Set the number of features
59    pub fn with_n_features(mut self, n_features: usize) -> Self {
60        self.n_features = Some(n_features);
61        self
62    }
63
64    /// Set the random state for reproducibility
65    pub fn with_random_state(mut self, random_state: u64) -> Self {
66        self.random_state = Some(random_state);
67        self
68    }
69
70    /// Configure parallel computation
71    pub fn with_parallel_config(mut self, config: ParallelConfig) -> Self {
72        self.parallel_config = config;
73        self
74    }
75
76    /// Enable/disable validation
77    pub fn with_validation(mut self, enabled: bool) -> Self {
78        self.validation_enabled = enabled;
79        self
80    }
81
82    /// Enable preprocessing
83    pub fn with_preprocessing(mut self) -> Self {
84        self.preprocessing_enabled = true;
85        self
86    }
87
88    /// Enable postprocessing
89    pub fn with_postprocessing(mut self) -> Self {
90        self.postprocessing_enabled = true;
91        self
92    }
93
94    /// Set the number of threads for parallel computation
95    pub fn with_threads(mut self, n_threads: usize) -> Self {
96        self.parallel_config = self.parallel_config.with_threads(n_threads);
97        self
98    }
99
100    /// Force sequential computation
101    pub fn sequential(mut self) -> Self {
102        self.parallel_config = self.parallel_config.sequential();
103        self
104    }
105
106    /// Build a SHAP configuration
107    pub fn build_shap_config(self) -> ShapConfig {
108        ShapConfig {
109            n_samples: self.n_samples.unwrap_or(1000),
110            random_state: self.random_state,
111            parallel_config: self.parallel_config,
112            validation_enabled: self.validation_enabled,
113        }
114    }
115
116    /// Build a LIME configuration
117    pub fn build_lime_config(self) -> LimeConfig {
118        LimeConfig {
119            n_samples: self.n_samples.unwrap_or(5000),
120            random_state: self.random_state,
121            parallel_config: self.parallel_config,
122            kernel_width: 0.75,
123            feature_selection: FeatureSelection::Auto,
124        }
125    }
126
127    /// Build a permutation importance configuration
128    pub fn build_permutation_config(self) -> PermutationConfig {
129        PermutationConfig {
130            n_repeats: self.n_samples.unwrap_or(10),
131            random_state: self.random_state,
132            parallel_config: self.parallel_config,
133            score_function: ScoreFunction::Accuracy,
134        }
135    }
136
137    /// Build a counterfactual configuration
138    pub fn build_counterfactual_config(self) -> CounterfactualConfig {
139        CounterfactualConfig {
140            max_iterations: self.n_samples.unwrap_or(1000),
141            random_state: self.random_state,
142            distance_threshold: 0.1,
143            optimization_method: OptimizationMethod::GradientDescent,
144        }
145    }
146}
147
148/// Configuration for SHAP explanations
149#[derive(Debug, Clone)]
150pub struct ShapConfig {
151    /// n_samples
152    pub n_samples: usize,
153    /// random_state
154    pub random_state: Option<u64>,
155    /// parallel_config
156    pub parallel_config: ParallelConfig,
157    /// validation_enabled
158    pub validation_enabled: bool,
159}
160
161/// Configuration for LIME explanations
162#[derive(Debug, Clone)]
163pub struct LimeConfig {
164    /// n_samples
165    pub n_samples: usize,
166    /// random_state
167    pub random_state: Option<u64>,
168    /// parallel_config
169    pub parallel_config: ParallelConfig,
170    /// kernel_width
171    pub kernel_width: Float,
172    /// feature_selection
173    pub feature_selection: FeatureSelection,
174}
175
/// Feature selection method for LIME.
///
/// The enum is fieldless, so it derives `Copy`, `PartialEq`, `Eq` and `Hash`
/// in addition to `Debug` and `Clone`; values can be compared with `==`
/// instead of `matches!`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum FeatureSelection {
    /// Auto; the value `ExplanationBuilder::build_lime_config` selects.
    Auto,
    /// Lasso
    Lasso,
    /// Forward
    Forward,
    /// None (presumably "no feature selection"; confirm against the LIME
    /// implementation that consumes this enum).
    None,
}
188
189/// Configuration for permutation importance
190#[derive(Debug, Clone)]
191pub struct PermutationConfig {
192    /// n_repeats
193    pub n_repeats: usize,
194    /// random_state
195    pub random_state: Option<u64>,
196    /// parallel_config
197    pub parallel_config: ParallelConfig,
198    /// score_function
199    pub score_function: ScoreFunction,
200}
201
/// Score function for permutation importance.
///
/// The enum is fieldless, so it derives `Copy`, `PartialEq`, `Eq` and `Hash`
/// in addition to `Debug` and `Clone`; values can be compared with `==`
/// instead of `matches!`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ScoreFunction {
    /// Accuracy; the value `ExplanationBuilder::build_permutation_config` selects.
    Accuracy,
    /// R2
    R2,
    /// MeanSquaredError
    MeanSquaredError,
    /// MeanAbsoluteError
    MeanAbsoluteError,
}
214
215/// Configuration for counterfactual explanations
216#[derive(Debug, Clone)]
217pub struct CounterfactualConfig {
218    /// max_iterations
219    pub max_iterations: usize,
220    /// random_state
221    pub random_state: Option<u64>,
222    /// distance_threshold
223    pub distance_threshold: Float,
224    /// optimization_method
225    pub optimization_method: OptimizationMethod,
226}
227
/// Optimization method for counterfactuals.
///
/// The enum is fieldless, so it derives `Copy`, `PartialEq`, `Eq` and `Hash`
/// in addition to `Debug` and `Clone`; values can be compared with `==`
/// instead of `matches!`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum OptimizationMethod {
    /// GradientDescent; the value
    /// `ExplanationBuilder::build_counterfactual_config` selects.
    GradientDescent,
    /// SimulatedAnnealing
    SimulatedAnnealing,
    /// GeneticAlgorithm
    GeneticAlgorithm,
}
238
239/// Fluent builder for complex explanation pipelines
240#[derive(Debug)]
241pub struct PipelineBuilder<Input> {
242    steps: Vec<PipelineStep>,
243
244    parallel_config: ParallelConfig,
245    _input_type: PhantomData<Input>,
246}
247
248impl<Input> Default for PipelineBuilder<Input> {
249    fn default() -> Self {
250        Self {
251            steps: Vec::new(),
252            parallel_config: ParallelConfig::default(),
253            _input_type: PhantomData,
254        }
255    }
256}
257
258impl<Input> PipelineBuilder<Input>
259where
260    Input: Send + Sync + 'static,
261{
262    /// Create a new pipeline builder
263    pub fn new() -> Self {
264        Self::default()
265    }
266
267    /// Add a SHAP explanation step
268    pub fn add_shap(mut self, config: ShapConfig) -> Self {
269        self.steps.push(PipelineStep::Shap(config));
270        self
271    }
272
273    /// Add a LIME explanation step
274    pub fn add_lime(mut self, config: LimeConfig) -> Self {
275        self.steps.push(PipelineStep::Lime(config));
276        self
277    }
278
279    /// Add a permutation importance step
280    pub fn add_permutation(mut self, config: PermutationConfig) -> Self {
281        self.steps.push(PipelineStep::Permutation(config));
282        self
283    }
284
285    /// Add a counterfactual explanation step
286    pub fn add_counterfactual(mut self, config: CounterfactualConfig) -> Self {
287        self.steps.push(PipelineStep::Counterfactual(config));
288        self
289    }
290
291    /// Add a custom explanation step
292    pub fn add_custom(mut self, name: String) -> Self {
293        self.steps.push(PipelineStep::Custom { name });
294        self
295    }
296
297    /// Add a validation step
298    pub fn add_validation(mut self) -> Self {
299        self.steps.push(PipelineStep::Validation);
300        self
301    }
302
303    /// Add a normalization step
304    pub fn add_normalization(mut self) -> Self {
305        self.steps.push(PipelineStep::Normalization);
306        self
307    }
308
309    /// Configure parallel execution
310    pub fn with_parallel_config(mut self, config: ParallelConfig) -> Self {
311        self.parallel_config = config;
312        self
313    }
314
315    /// Build the pipeline
316    pub fn build(self) -> ExplanationPipelineExecutor<Input> {
317        ExplanationPipelineExecutor {
318            steps: self.steps,
319            parallel_config: self.parallel_config,
320            _input_type: PhantomData,
321        }
322    }
323}
324
325/// A step in the explanation pipeline
326#[derive(Debug, Clone)]
327pub enum PipelineStep {
328    /// Shap
329    Shap(ShapConfig),
330    /// Lime
331    Lime(LimeConfig),
332    /// Permutation
333    Permutation(PermutationConfig),
334    /// Counterfactual
335    Counterfactual(CounterfactualConfig),
336    /// Validation
337    Validation,
338    /// Normalization
339    Normalization,
340    /// Custom
341    Custom { name: String },
342}
343
344/// Executor for the explanation pipeline
345pub struct ExplanationPipelineExecutor<Input> {
346    steps: Vec<PipelineStep>,
347    parallel_config: ParallelConfig,
348    _input_type: PhantomData<Input>,
349}
350
351impl<Input> ExplanationPipelineExecutor<Input>
352where
353    Input: Send + Sync,
354{
355    /// Execute the pipeline
356    pub fn execute(&self, input: &Input) -> SklResult<PipelineExecutionResult> {
357        let mut results: Vec<Array1<Float>> = Vec::new();
358        let mut metadata: Vec<StepMetadata> = Vec::new();
359
360        for (i, step) in self.steps.iter().enumerate() {
361            let step_name = format!("Step_{}", i);
362            let start_time = std::time::Instant::now();
363
364            let result = match step {
365                PipelineStep::Shap(_config) => {
366                    // Placeholder SHAP implementation
367                    Ok::<Array1<Float>, crate::SklearsError>(Array1::zeros(10)) // Mock result
368                }
369                PipelineStep::Lime(_config) => {
370                    // Placeholder LIME implementation
371                    Ok::<Array1<Float>, crate::SklearsError>(Array1::zeros(10)) // Mock result
372                }
373                PipelineStep::Permutation(_config) => {
374                    // Placeholder permutation implementation
375                    Ok::<Array1<Float>, crate::SklearsError>(Array1::zeros(10)) // Mock result
376                }
377                PipelineStep::Counterfactual(_config) => {
378                    // Placeholder counterfactual implementation
379                    Ok::<Array1<Float>, crate::SklearsError>(Array1::zeros(10)) // Mock result
380                }
381                PipelineStep::Validation => {
382                    // Validation step doesn't produce explanations
383                    continue;
384                }
385                PipelineStep::Normalization => {
386                    // Normalization step modifies existing results
387                    if let Some(last_result) = results.last_mut() {
388                        let sum = last_result.sum();
389                        if sum != 0.0 {
390                            *last_result = last_result.mapv(|x| x / sum);
391                        }
392                    }
393                    continue;
394                }
395                PipelineStep::Custom { name: _ } => {
396                    // Placeholder custom implementation
397                    Ok::<Array1<Float>, crate::SklearsError>(Array1::zeros(10)) // Mock result
398                }
399            };
400
401            let execution_time = start_time.elapsed();
402
403            match result {
404                Ok(explanation) => {
405                    results.push(explanation);
406                    metadata.push(StepMetadata {
407                        step_name,
408                        execution_time,
409                        success: true,
410                        error_message: None,
411                    });
412                }
413                Err(e) => {
414                    metadata.push(StepMetadata {
415                        step_name,
416                        execution_time,
417                        success: false,
418                        error_message: Some(e.to_string()),
419                    });
420                    return Err(e);
421                }
422            }
423        }
424
425        Ok(PipelineExecutionResult {
426            explanations: results,
427            metadata,
428        })
429    }
430}
431
432/// Result of pipeline execution
433#[derive(Debug, Clone)]
434pub struct PipelineExecutionResult {
435    /// explanations
436    pub explanations: Vec<Array1<Float>>,
437    /// metadata
438    pub metadata: Vec<StepMetadata>,
439}
440
/// Bookkeeping recorded for one executed pipeline step.
#[derive(Clone, Debug)]
pub struct StepMetadata {
    /// Name assigned to the step.
    pub step_name: String,
    /// Time the step took to run.
    pub execution_time: std::time::Duration,
    /// Whether the step succeeded.
    pub success: bool,
    /// Error text when the step failed, `None` otherwise.
    pub error_message: Option<String>,
}
453
454/// Builder for explanation comparison studies
455#[derive(Debug, Default)]
456pub struct ComparisonStudyBuilder {
457    methods: Vec<String>,
458    datasets: Vec<String>,
459    metrics: Vec<String>,
460    parallel_config: ParallelConfig,
461}
462
463impl ComparisonStudyBuilder {
464    /// Create a new comparison study builder
465    pub fn new() -> Self {
466        Self::default()
467    }
468
469    /// Add an explanation method to compare
470    pub fn add_method<S: Into<String>>(mut self, method: S) -> Self {
471        self.methods.push(method.into());
472        self
473    }
474
475    /// Add a dataset to test on
476    pub fn add_dataset<S: Into<String>>(mut self, dataset: S) -> Self {
477        self.datasets.push(dataset.into());
478        self
479    }
480
481    /// Add an evaluation metric
482    pub fn add_metric<S: Into<String>>(mut self, metric: S) -> Self {
483        self.metrics.push(metric.into());
484        self
485    }
486
487    /// Configure parallel execution
488    pub fn with_parallel_config(mut self, config: ParallelConfig) -> Self {
489        self.parallel_config = config;
490        self
491    }
492
493    /// Build the comparison study
494    pub fn build(self) -> ComparisonStudy {
495        ComparisonStudy {
496            methods: self.methods,
497            datasets: self.datasets,
498            metrics: self.metrics,
499            parallel_config: self.parallel_config,
500        }
501    }
502}
503
504/// A comparison study configuration
505#[derive(Debug, Clone)]
506pub struct ComparisonStudy {
507    /// methods
508    pub methods: Vec<String>,
509    /// datasets
510    pub datasets: Vec<String>,
511    /// metrics
512    pub metrics: Vec<String>,
513    /// parallel_config
514    pub parallel_config: ParallelConfig,
515}
516
517impl ComparisonStudy {
518    /// Execute the comparison study
519    pub fn execute(&self) -> SklResult<ComparisonResults> {
520        let mut results = Vec::new();
521
522        for method in &self.methods {
523            for dataset in &self.datasets {
524                for metric in &self.metrics {
525                    // Placeholder comparison implementation
526                    let score = scirs2_core::random::thread_rng().random::<Float>(); // Mock score
527                    results.push(ComparisonResult {
528                        method: method.clone(),
529                        dataset: dataset.clone(),
530                        metric: metric.clone(),
531                        score,
532                    });
533                }
534            }
535        }
536
537        Ok(ComparisonResults { results })
538    }
539}
540
541/// Results of a comparison study
542#[derive(Debug, Clone)]
543pub struct ComparisonResults {
544    /// results
545    pub results: Vec<ComparisonResult>,
546}
547
548/// Individual comparison result
549#[derive(Debug, Clone)]
550pub struct ComparisonResult {
551    /// method
552    pub method: String,
553    /// dataset
554    pub dataset: String,
555    /// metric
556    pub metric: String,
557    /// score
558    pub score: Float,
559}
560
#[cfg(test)]
mod tests {
    use super::*;
    // ✅ SciRS2 Policy Compliant Import
    use scirs2_core::ndarray::array;
    use sklears_core::prelude::ArrayView1;

    /// A fresh builder pins nothing and has validation switched on.
    #[test]
    fn test_explanation_builder_creation() {
        let fresh: ExplanationBuilder<ArrayView1<Float>> = ExplanationBuilder::new();

        assert_eq!(fresh.n_samples, None);
        assert_eq!(fresh.random_state, None);
        assert_eq!(fresh.validation_enabled, true);
    }

    /// Chained setters each leave their mark on the builder.
    #[test]
    fn test_explanation_builder_fluent_api() {
        let configured: ExplanationBuilder<ArrayView1<Float>> = ExplanationBuilder::new()
            .with_n_samples(1000)
            .with_random_state(42)
            .with_threads(4)
            .with_validation(false);

        assert_eq!(configured.n_samples, Some(1000));
        assert_eq!(configured.random_state, Some(42));
        assert_eq!(configured.validation_enabled, false);
    }

    /// Explicit builder values reach the SHAP config; validation stays on.
    #[test]
    fn test_shap_config_building() {
        let shap = ExplanationBuilder::<ArrayView1<Float>>::new()
            .with_n_samples(2000)
            .with_random_state(123)
            .build_shap_config();

        assert_eq!(shap.n_samples, 2000);
        assert_eq!(shap.random_state, Some(123));
        assert!(shap.validation_enabled);
    }

    /// The LIME config carries the sample count plus its fixed defaults.
    #[test]
    fn test_lime_config_building() {
        let lime = ExplanationBuilder::<ArrayView1<Float>>::new()
            .with_n_samples(5000)
            .build_lime_config();

        assert_eq!(lime.n_samples, 5000);
        assert_eq!(lime.kernel_width, 0.75);
        assert!(matches!(lime.feature_selection, FeatureSelection::Auto));
    }

    /// A fresh pipeline builder has no steps.
    #[test]
    fn test_pipeline_builder_creation() {
        let empty: PipelineBuilder<ArrayView1<Float>> = PipelineBuilder::new();
        assert!(empty.steps.is_empty());
    }

    /// Every `add_*` call contributes exactly one step.
    #[test]
    fn test_pipeline_builder_fluent_api() {
        let shap = ExplanationBuilder::<ArrayView1<Float>>::new().build_shap_config();
        let lime = ExplanationBuilder::<ArrayView1<Float>>::new().build_lime_config();

        let executor = PipelineBuilder::<ArrayView1<Float>>::new()
            .add_shap(shap)
            .add_lime(lime)
            .add_validation()
            .add_normalization()
            .build();

        assert_eq!(executor.steps.len(), 4);
    }

    /// Methods, datasets and metrics are collected independently.
    #[test]
    fn test_comparison_study_builder() {
        let built = ComparisonStudyBuilder::new()
            .add_method("SHAP")
            .add_method("LIME")
            .add_dataset("iris")
            .add_dataset("wine")
            .add_metric("fidelity")
            .add_metric("stability")
            .build();

        assert_eq!(built.methods.len(), 2);
        assert_eq!(built.datasets.len(), 2);
        assert_eq!(built.metrics.len(), 2);
    }

    /// One method, dataset and metric yield exactly one labelled result.
    #[test]
    fn test_comparison_study_execution() {
        let study = ComparisonStudyBuilder::new()
            .add_method("SHAP")
            .add_dataset("iris")
            .add_metric("fidelity")
            .build();

        let outcome = study.execute();
        assert!(outcome.is_ok());

        let report = outcome.unwrap();
        assert_eq!(report.results.len(), 1);

        let only = &report.results[0];
        assert_eq!(only.method, "SHAP");
        assert_eq!(only.dataset, "iris");
        assert_eq!(only.metric, "fidelity");
    }

    /// Each score function variant matches its own pattern.
    #[test]
    fn test_score_function_variants() {
        let accuracy = ScoreFunction::Accuracy;
        let r2 = ScoreFunction::R2;
        let mse = ScoreFunction::MeanSquaredError;
        let mae = ScoreFunction::MeanAbsoluteError;

        assert!(matches!(accuracy, ScoreFunction::Accuracy));
        assert!(matches!(r2, ScoreFunction::R2));
        assert!(matches!(mse, ScoreFunction::MeanSquaredError));
        assert!(matches!(mae, ScoreFunction::MeanAbsoluteError));
    }

    /// Each optimization method variant matches its own pattern.
    #[test]
    fn test_optimization_method_variants() {
        let gradient = OptimizationMethod::GradientDescent;
        let annealing = OptimizationMethod::SimulatedAnnealing;
        let genetic = OptimizationMethod::GeneticAlgorithm;

        assert!(matches!(gradient, OptimizationMethod::GradientDescent));
        assert!(matches!(annealing, OptimizationMethod::SimulatedAnnealing));
        assert!(matches!(genetic, OptimizationMethod::GeneticAlgorithm));
    }

    /// Each feature selection variant matches its own pattern.
    #[test]
    fn test_feature_selection_variants() {
        let auto = FeatureSelection::Auto;
        let lasso = FeatureSelection::Lasso;
        let forward = FeatureSelection::Forward;
        let none = FeatureSelection::None;

        assert!(matches!(auto, FeatureSelection::Auto));
        assert!(matches!(lasso, FeatureSelection::Lasso));
        assert!(matches!(forward, FeatureSelection::Forward));
        assert!(matches!(none, FeatureSelection::None));
    }
}
707}