1#[cfg(feature = "parallel")]
7use crate::ParallelConfig;
8use crate::{Float, SklResult};
9
/// Stand-in for `crate::ParallelConfig` when the `parallel` feature is disabled.
///
/// The type carries no settings. It is `pub` because it appears in the public
/// fields and method signatures of the builders and configs in this module.
#[cfg(not(feature = "parallel"))]
#[derive(Debug, Clone, Default)]
pub struct ParallelConfig;

#[cfg(not(feature = "parallel"))]
impl ParallelConfig {
    /// No-op counterpart of the thread-count setter: the argument is ignored.
    ///
    /// Exists so that `ExplanationBuilder::with_threads` compiles without the
    /// `parallel` feature.
    pub fn with_threads(self, _n_threads: usize) -> Self {
        self
    }

    /// No-op counterpart of the sequential-mode setter.
    ///
    /// Exists so that `ExplanationBuilder::sequential` compiles without the
    /// `parallel` feature.
    pub fn sequential(self) -> Self {
        self
    }
}
13use scirs2_core::ndarray::Array1;
15use scirs2_core::random::Rng;
16use std::marker::PhantomData;
17
18#[derive(Debug, Clone)]
20pub struct ExplanationBuilder<T> {
21 target_type: PhantomData<T>,
22 n_samples: Option<usize>,
23 n_features: Option<usize>,
24 random_state: Option<u64>,
25 parallel_config: ParallelConfig,
26 validation_enabled: bool,
27 preprocessing_enabled: bool,
28 postprocessing_enabled: bool,
29}
30
31impl<T> Default for ExplanationBuilder<T> {
32 fn default() -> Self {
33 Self {
34 target_type: PhantomData,
35 n_samples: None,
36 n_features: None,
37 random_state: None,
38 parallel_config: ParallelConfig::default(),
39 validation_enabled: true,
40 preprocessing_enabled: false,
41 postprocessing_enabled: false,
42 }
43 }
44}
45
46impl<T> ExplanationBuilder<T> {
47 pub fn new() -> Self {
49 Self::default()
50 }
51
52 pub fn with_n_samples(mut self, n_samples: usize) -> Self {
54 self.n_samples = Some(n_samples);
55 self
56 }
57
58 pub fn with_n_features(mut self, n_features: usize) -> Self {
60 self.n_features = Some(n_features);
61 self
62 }
63
64 pub fn with_random_state(mut self, random_state: u64) -> Self {
66 self.random_state = Some(random_state);
67 self
68 }
69
70 pub fn with_parallel_config(mut self, config: ParallelConfig) -> Self {
72 self.parallel_config = config;
73 self
74 }
75
76 pub fn with_validation(mut self, enabled: bool) -> Self {
78 self.validation_enabled = enabled;
79 self
80 }
81
82 pub fn with_preprocessing(mut self) -> Self {
84 self.preprocessing_enabled = true;
85 self
86 }
87
88 pub fn with_postprocessing(mut self) -> Self {
90 self.postprocessing_enabled = true;
91 self
92 }
93
94 pub fn with_threads(mut self, n_threads: usize) -> Self {
96 self.parallel_config = self.parallel_config.with_threads(n_threads);
97 self
98 }
99
100 pub fn sequential(mut self) -> Self {
102 self.parallel_config = self.parallel_config.sequential();
103 self
104 }
105
106 pub fn build_shap_config(self) -> ShapConfig {
108 ShapConfig {
109 n_samples: self.n_samples.unwrap_or(1000),
110 random_state: self.random_state,
111 parallel_config: self.parallel_config,
112 validation_enabled: self.validation_enabled,
113 }
114 }
115
116 pub fn build_lime_config(self) -> LimeConfig {
118 LimeConfig {
119 n_samples: self.n_samples.unwrap_or(5000),
120 random_state: self.random_state,
121 parallel_config: self.parallel_config,
122 kernel_width: 0.75,
123 feature_selection: FeatureSelection::Auto,
124 }
125 }
126
127 pub fn build_permutation_config(self) -> PermutationConfig {
129 PermutationConfig {
130 n_repeats: self.n_samples.unwrap_or(10),
131 random_state: self.random_state,
132 parallel_config: self.parallel_config,
133 score_function: ScoreFunction::Accuracy,
134 }
135 }
136
137 pub fn build_counterfactual_config(self) -> CounterfactualConfig {
139 CounterfactualConfig {
140 max_iterations: self.n_samples.unwrap_or(1000),
141 random_state: self.random_state,
142 distance_threshold: 0.1,
143 optimization_method: OptimizationMethod::GradientDescent,
144 }
145 }
146}
147
148#[derive(Debug, Clone)]
150pub struct ShapConfig {
151 pub n_samples: usize,
153 pub random_state: Option<u64>,
155 pub parallel_config: ParallelConfig,
157 pub validation_enabled: bool,
159}
160
161#[derive(Debug, Clone)]
163pub struct LimeConfig {
164 pub n_samples: usize,
166 pub random_state: Option<u64>,
168 pub parallel_config: ParallelConfig,
170 pub kernel_width: Float,
172 pub feature_selection: FeatureSelection,
174}
175
/// Feature-selection strategy stored in [`LimeConfig::feature_selection`].
///
/// The variants are plain tags; how each one is interpreted is up to the code
/// that consumes the configuration. Being fieldless, the enum is `Copy` and
/// comparable with `==`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum FeatureSelection {
    /// The value chosen by `ExplanationBuilder::build_lime_config`.
    Auto,
    /// Lasso-based selection (name-derived; confirm against the consumer).
    Lasso,
    /// Forward selection (name-derived; confirm against the consumer).
    Forward,
    /// No feature selection (name-derived; confirm against the consumer).
    None,
}
188
189#[derive(Debug, Clone)]
191pub struct PermutationConfig {
192 pub n_repeats: usize,
194 pub random_state: Option<u64>,
196 pub parallel_config: ParallelConfig,
198 pub score_function: ScoreFunction,
200}
201
/// Score function stored in [`PermutationConfig::score_function`].
///
/// The variants are plain tags; the scoring itself is implemented by whatever
/// consumes the configuration. Being fieldless, the enum is `Copy` and
/// comparable with `==`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ScoreFunction {
    /// The value chosen by `ExplanationBuilder::build_permutation_config`.
    Accuracy,
    /// R² score (name-derived; confirm against the consumer).
    R2,
    /// Mean squared error (name-derived; confirm against the consumer).
    MeanSquaredError,
    /// Mean absolute error (name-derived; confirm against the consumer).
    MeanAbsoluteError,
}
214
215#[derive(Debug, Clone)]
217pub struct CounterfactualConfig {
218 pub max_iterations: usize,
220 pub random_state: Option<u64>,
222 pub distance_threshold: Float,
224 pub optimization_method: OptimizationMethod,
226}
227
/// Optimization strategy stored in [`CounterfactualConfig::optimization_method`].
///
/// The variants are plain tags; the optimizer itself is implemented by whatever
/// consumes the configuration. Being fieldless, the enum is `Copy` and
/// comparable with `==`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum OptimizationMethod {
    /// The value chosen by `ExplanationBuilder::build_counterfactual_config`.
    GradientDescent,
    /// Simulated annealing (name-derived; confirm against the consumer).
    SimulatedAnnealing,
    /// Genetic algorithm (name-derived; confirm against the consumer).
    GeneticAlgorithm,
}
238
239#[derive(Debug)]
241pub struct PipelineBuilder<Input> {
242 steps: Vec<PipelineStep>,
243
244 parallel_config: ParallelConfig,
245 _input_type: PhantomData<Input>,
246}
247
248impl<Input> Default for PipelineBuilder<Input> {
249 fn default() -> Self {
250 Self {
251 steps: Vec::new(),
252 parallel_config: ParallelConfig::default(),
253 _input_type: PhantomData,
254 }
255 }
256}
257
258impl<Input> PipelineBuilder<Input>
259where
260 Input: Send + Sync + 'static,
261{
262 pub fn new() -> Self {
264 Self::default()
265 }
266
267 pub fn add_shap(mut self, config: ShapConfig) -> Self {
269 self.steps.push(PipelineStep::Shap(config));
270 self
271 }
272
273 pub fn add_lime(mut self, config: LimeConfig) -> Self {
275 self.steps.push(PipelineStep::Lime(config));
276 self
277 }
278
279 pub fn add_permutation(mut self, config: PermutationConfig) -> Self {
281 self.steps.push(PipelineStep::Permutation(config));
282 self
283 }
284
285 pub fn add_counterfactual(mut self, config: CounterfactualConfig) -> Self {
287 self.steps.push(PipelineStep::Counterfactual(config));
288 self
289 }
290
291 pub fn add_custom(mut self, name: String) -> Self {
293 self.steps.push(PipelineStep::Custom { name });
294 self
295 }
296
297 pub fn add_validation(mut self) -> Self {
299 self.steps.push(PipelineStep::Validation);
300 self
301 }
302
303 pub fn add_normalization(mut self) -> Self {
305 self.steps.push(PipelineStep::Normalization);
306 self
307 }
308
309 pub fn with_parallel_config(mut self, config: ParallelConfig) -> Self {
311 self.parallel_config = config;
312 self
313 }
314
315 pub fn build(self) -> ExplanationPipelineExecutor<Input> {
317 ExplanationPipelineExecutor {
318 steps: self.steps,
319 parallel_config: self.parallel_config,
320 _input_type: PhantomData,
321 }
322 }
323}
324
325#[derive(Debug, Clone)]
327pub enum PipelineStep {
328 Shap(ShapConfig),
330 Lime(LimeConfig),
332 Permutation(PermutationConfig),
334 Counterfactual(CounterfactualConfig),
336 Validation,
338 Normalization,
340 Custom { name: String },
342}
343
344pub struct ExplanationPipelineExecutor<Input> {
346 steps: Vec<PipelineStep>,
347 parallel_config: ParallelConfig,
348 _input_type: PhantomData<Input>,
349}
350
351impl<Input> ExplanationPipelineExecutor<Input>
352where
353 Input: Send + Sync,
354{
355 pub fn execute(&self, input: &Input) -> SklResult<PipelineExecutionResult> {
357 let mut results: Vec<Array1<Float>> = Vec::new();
358 let mut metadata: Vec<StepMetadata> = Vec::new();
359
360 for (i, step) in self.steps.iter().enumerate() {
361 let step_name = format!("Step_{}", i);
362 let start_time = std::time::Instant::now();
363
364 let result = match step {
365 PipelineStep::Shap(_config) => {
366 Ok::<Array1<Float>, crate::SklearsError>(Array1::zeros(10)) }
369 PipelineStep::Lime(_config) => {
370 Ok::<Array1<Float>, crate::SklearsError>(Array1::zeros(10)) }
373 PipelineStep::Permutation(_config) => {
374 Ok::<Array1<Float>, crate::SklearsError>(Array1::zeros(10)) }
377 PipelineStep::Counterfactual(_config) => {
378 Ok::<Array1<Float>, crate::SklearsError>(Array1::zeros(10)) }
381 PipelineStep::Validation => {
382 continue;
384 }
385 PipelineStep::Normalization => {
386 if let Some(last_result) = results.last_mut() {
388 let sum = last_result.sum();
389 if sum != 0.0 {
390 *last_result = last_result.mapv(|x| x / sum);
391 }
392 }
393 continue;
394 }
395 PipelineStep::Custom { name: _ } => {
396 Ok::<Array1<Float>, crate::SklearsError>(Array1::zeros(10)) }
399 };
400
401 let execution_time = start_time.elapsed();
402
403 match result {
404 Ok(explanation) => {
405 results.push(explanation);
406 metadata.push(StepMetadata {
407 step_name,
408 execution_time,
409 success: true,
410 error_message: None,
411 });
412 }
413 Err(e) => {
414 metadata.push(StepMetadata {
415 step_name,
416 execution_time,
417 success: false,
418 error_message: Some(e.to_string()),
419 });
420 return Err(e);
421 }
422 }
423 }
424
425 Ok(PipelineExecutionResult {
426 explanations: results,
427 metadata,
428 })
429 }
430}
431
432#[derive(Debug, Clone)]
434pub struct PipelineExecutionResult {
435 pub explanations: Vec<Array1<Float>>,
437 pub metadata: Vec<StepMetadata>,
439}
440
/// Bookkeeping recorded for a single pipeline step.
#[derive(Clone, Debug)]
pub struct StepMetadata {
    /// Step label; the executor uses `Step_{index}`.
    pub step_name: String,
    /// Wall-clock time measured around the step.
    pub execution_time: std::time::Duration,
    /// `true` when the step completed without error.
    pub success: bool,
    /// Error text for a failed step, `None` otherwise.
    pub error_message: Option<String>,
}
453
454#[derive(Debug, Default)]
456pub struct ComparisonStudyBuilder {
457 methods: Vec<String>,
458 datasets: Vec<String>,
459 metrics: Vec<String>,
460 parallel_config: ParallelConfig,
461}
462
463impl ComparisonStudyBuilder {
464 pub fn new() -> Self {
466 Self::default()
467 }
468
469 pub fn add_method<S: Into<String>>(mut self, method: S) -> Self {
471 self.methods.push(method.into());
472 self
473 }
474
475 pub fn add_dataset<S: Into<String>>(mut self, dataset: S) -> Self {
477 self.datasets.push(dataset.into());
478 self
479 }
480
481 pub fn add_metric<S: Into<String>>(mut self, metric: S) -> Self {
483 self.metrics.push(metric.into());
484 self
485 }
486
487 pub fn with_parallel_config(mut self, config: ParallelConfig) -> Self {
489 self.parallel_config = config;
490 self
491 }
492
493 pub fn build(self) -> ComparisonStudy {
495 ComparisonStudy {
496 methods: self.methods,
497 datasets: self.datasets,
498 metrics: self.metrics,
499 parallel_config: self.parallel_config,
500 }
501 }
502}
503
504#[derive(Debug, Clone)]
506pub struct ComparisonStudy {
507 pub methods: Vec<String>,
509 pub datasets: Vec<String>,
511 pub metrics: Vec<String>,
513 pub parallel_config: ParallelConfig,
515}
516
517impl ComparisonStudy {
518 pub fn execute(&self) -> SklResult<ComparisonResults> {
520 let mut results = Vec::new();
521
522 for method in &self.methods {
523 for dataset in &self.datasets {
524 for metric in &self.metrics {
525 let score = scirs2_core::random::thread_rng().random::<Float>(); results.push(ComparisonResult {
528 method: method.clone(),
529 dataset: dataset.clone(),
530 metric: metric.clone(),
531 score,
532 });
533 }
534 }
535 }
536
537 Ok(ComparisonResults { results })
538 }
539}
540
541#[derive(Debug, Clone)]
543pub struct ComparisonResults {
544 pub results: Vec<ComparisonResult>,
546}
547
548#[derive(Debug, Clone)]
550pub struct ComparisonResult {
551 pub method: String,
553 pub dataset: String,
555 pub metric: String,
557 pub score: Float,
559}
560
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;
    use sklears_core::prelude::ArrayView1;

    /// A fresh builder has nothing requested and validation enabled.
    #[test]
    fn test_explanation_builder_creation() {
        let builder: ExplanationBuilder<ArrayView1<Float>> = ExplanationBuilder::new();
        assert!(builder.n_samples.is_none());
        assert!(builder.random_state.is_none());
        assert!(builder.validation_enabled);
    }

    /// Chained setters store the values they were given.
    #[test]
    fn test_explanation_builder_fluent_api() {
        let builder: ExplanationBuilder<ArrayView1<Float>> = ExplanationBuilder::new()
            .with_n_samples(1000)
            .with_random_state(42)
            .with_threads(4)
            .with_validation(false);

        assert_eq!(builder.n_samples, Some(1000));
        assert_eq!(builder.random_state, Some(42));
        assert!(!builder.validation_enabled);
    }

    /// Explicit settings are carried into the SHAP config.
    #[test]
    fn test_shap_config_building() {
        let config = ExplanationBuilder::<ArrayView1<Float>>::new()
            .with_n_samples(2000)
            .with_random_state(123)
            .build_shap_config();

        assert_eq!(config.n_samples, 2000);
        assert_eq!(config.random_state, Some(123));
        assert!(config.validation_enabled);
    }

    /// The LIME config carries the sample count plus the fixed kernel settings.
    #[test]
    fn test_lime_config_building() {
        let config = ExplanationBuilder::<ArrayView1<Float>>::new()
            .with_n_samples(5000)
            .build_lime_config();

        assert_eq!(config.n_samples, 5000);
        assert_eq!(config.kernel_width, 0.75);
        assert!(matches!(config.feature_selection, FeatureSelection::Auto));
    }

    /// A fresh pipeline builder has no steps.
    #[test]
    fn test_pipeline_builder_creation() {
        let builder: PipelineBuilder<ArrayView1<Float>> = PipelineBuilder::new();
        assert_eq!(builder.steps.len(), 0);
    }

    /// Every `add_*` call contributes exactly one step.
    #[test]
    fn test_pipeline_builder_fluent_api() {
        let shap_config = ExplanationBuilder::<ArrayView1<Float>>::new().build_shap_config();
        let lime_config = ExplanationBuilder::<ArrayView1<Float>>::new().build_lime_config();

        let pipeline = PipelineBuilder::<ArrayView1<Float>>::new()
            .add_shap(shap_config)
            .add_lime(lime_config)
            .add_validation()
            .add_normalization()
            .build();

        assert_eq!(pipeline.steps.len(), 4);
    }

    /// Executing a pipeline records output and metadata only for
    /// explanation-producing steps; validation and normalization add none.
    #[test]
    fn test_pipeline_execution_records_explanation_steps_only() {
        let shap_config = ExplanationBuilder::<Array1<Float>>::new().build_shap_config();

        let pipeline = PipelineBuilder::<Array1<Float>>::new()
            .add_shap(shap_config)
            .add_validation()
            .add_normalization()
            .build();

        let input: Array1<Float> = array![0.5, 1.5, 2.0];
        let outcome = pipeline
            .execute(&input)
            .expect("pipeline steps in this test cannot fail");

        assert_eq!(outcome.explanations.len(), 1);
        assert_eq!(outcome.explanations[0].len(), 10);
        // The all-zero explanation sums to zero, so normalization leaves it alone.
        assert!(outcome.explanations[0].iter().all(|&v| v == 0.0));

        assert_eq!(outcome.metadata.len(), 1);
        assert_eq!(outcome.metadata[0].step_name, "Step_0");
        assert!(outcome.metadata[0].success);
        assert!(outcome.metadata[0].error_message.is_none());
    }

    /// The study builder keeps every method, dataset and metric it was given.
    #[test]
    fn test_comparison_study_builder() {
        let study = ComparisonStudyBuilder::new()
            .add_method("SHAP")
            .add_method("LIME")
            .add_dataset("iris")
            .add_dataset("wine")
            .add_metric("fidelity")
            .add_metric("stability")
            .build();

        assert_eq!(study.methods.len(), 2);
        assert_eq!(study.datasets.len(), 2);
        assert_eq!(study.metrics.len(), 2);
    }

    /// A 1x1x1 study yields exactly one labelled result.
    #[test]
    fn test_comparison_study_execution() {
        let study = ComparisonStudyBuilder::new()
            .add_method("SHAP")
            .add_dataset("iris")
            .add_metric("fidelity")
            .build();

        let results = study.execute();
        assert!(results.is_ok());

        let comparison_results = results.unwrap();
        assert_eq!(comparison_results.results.len(), 1);
        assert_eq!(comparison_results.results[0].method, "SHAP");
        assert_eq!(comparison_results.results[0].dataset, "iris");
        assert_eq!(comparison_results.results[0].metric, "fidelity");
    }

    #[test]
    fn test_score_function_variants() {
        assert!(matches!(ScoreFunction::Accuracy, ScoreFunction::Accuracy));
        assert!(matches!(ScoreFunction::R2, ScoreFunction::R2));
        assert!(matches!(
            ScoreFunction::MeanSquaredError,
            ScoreFunction::MeanSquaredError
        ));
        assert!(matches!(
            ScoreFunction::MeanAbsoluteError,
            ScoreFunction::MeanAbsoluteError
        ));
    }

    #[test]
    fn test_optimization_method_variants() {
        assert!(matches!(
            OptimizationMethod::GradientDescent,
            OptimizationMethod::GradientDescent
        ));
        assert!(matches!(
            OptimizationMethod::SimulatedAnnealing,
            OptimizationMethod::SimulatedAnnealing
        ));
        assert!(matches!(
            OptimizationMethod::GeneticAlgorithm,
            OptimizationMethod::GeneticAlgorithm
        ));
    }

    #[test]
    fn test_feature_selection_variants() {
        assert!(matches!(FeatureSelection::Auto, FeatureSelection::Auto));
        assert!(matches!(FeatureSelection::Lasso, FeatureSelection::Lasso));
        assert!(matches!(
            FeatureSelection::Forward,
            FeatureSelection::Forward
        ));
        assert!(matches!(FeatureSelection::None, FeatureSelection::None));
    }
}