use crate::error::{MetricsError, Result};
use scirs2_core::ndarray::{Array1, Array2, ArrayView1, ArrayView2, Axis};
use scirs2_core::numeric::Float;
use std::collections::HashMap;

pub mod feature_importance;
pub mod global_explanations;
pub mod local_explanations;
pub mod uncertainty_quantification;

pub use feature_importance::*;
pub use global_explanations::*;
pub use local_explanations::*;
pub use uncertainty_quantification::*;

/// Aggregated explainability metrics for a model.
#[derive(Debug, Clone)]
pub struct ExplainabilityMetrics<F: Float> {
    /// Importance score per feature name.
    pub feature_importance: HashMap<String, F>,
    /// Consistency of local explanations under small input perturbations.
    pub local_consistency: F,
    /// Stability of global explanations across bootstrap resamples.
    pub global_stability: F,
    /// Uncertainty estimates for the model's predictions.
    pub uncertainty_metrics: UncertaintyMetrics<F>,
    /// Degree to which removing important features changes the prediction.
    pub faithfulness: F,
    /// Degree to which the important features alone preserve the prediction.
    pub completeness: F,
}

/// Decomposition of predictive uncertainty.
#[derive(Debug, Clone)]
pub struct UncertaintyMetrics<F: Float> {
    /// Model (epistemic) uncertainty, estimated from prediction variance.
    pub epistemic_uncertainty: F,
    /// Data (aleatoric) uncertainty.
    pub aleatoric_uncertainty: F,
    /// Sum of epistemic and aleatoric uncertainty.
    pub total_uncertainty: F,
    /// Empirical coverage of prediction intervals.
    pub coverage: F,
    /// Calibration error of the uncertainty estimates.
    pub calibration_error: F,
}

/// Evaluator that scores explanation quality for a black-box model.
pub struct ExplainabilityEvaluator<F: Float> {
    /// Number of perturbations / resamples used by the estimators.
    pub n_perturbations: usize,
    /// Magnitude of input perturbations.
    pub perturbation_strength: F,
    /// Minimum importance for a feature to be considered relevant.
    pub importance_threshold: F,
    /// Confidence level for interval-based metrics.
    pub confidence_level: F,
}

impl<
        F: Float
            + scirs2_core::numeric::FromPrimitive
            + std::iter::Sum
            + scirs2_core::ndarray::ScalarOperand,
    > Default for ExplainabilityEvaluator<F>
{
    fn default() -> Self {
        Self::new()
    }
}

impl<
        F: Float
            + scirs2_core::numeric::FromPrimitive
            + std::iter::Sum
            + scirs2_core::ndarray::ScalarOperand,
    > ExplainabilityEvaluator<F>
{
    /// Creates an evaluator with default settings.
    pub fn new() -> Self {
        Self {
            n_perturbations: 100,
            perturbation_strength: F::from(0.1).unwrap(),
            importance_threshold: F::from(0.01).unwrap(),
            confidence_level: F::from(0.95).unwrap(),
        }
    }

    /// Sets the number of perturbations / resamples used by the estimators.
    pub fn with_perturbations(mut self, n: usize) -> Self {
        self.n_perturbations = n;
        self
    }

    /// Sets the perturbation strength.
    pub fn with_perturbation_strength(mut self, strength: F) -> Self {
        self.perturbation_strength = strength;
        self
    }

    /// Sets the importance threshold.
    pub fn with_importance_threshold(mut self, threshold: F) -> Self {
        self.importance_threshold = threshold;
        self
    }

    /// Evaluates all explainability metrics for `model` on `x_test` using the
    /// given explanation method.
    pub fn evaluate_explainability<M>(
        &self,
        model: &M,
        x_test: &Array2<F>,
        feature_names: &[String],
        explanation_method: ExplanationMethod,
    ) -> Result<ExplainabilityMetrics<F>>
    where
        M: Fn(&ArrayView2<F>) -> Array1<F>,
    {
        let feature_importance =
            self.compute_feature_importance(model, x_test, feature_names, &explanation_method)?;

        let local_consistency =
            self.evaluate_local_consistency(model, x_test, &explanation_method)?;

        let global_stability =
            self.evaluate_global_stability(model, x_test, &explanation_method)?;

        let uncertainty_metrics = self.compute_uncertainty_metrics(model, x_test)?;

        let faithfulness = self.evaluate_faithfulness(model, x_test, &explanation_method)?;

        let completeness = self.evaluate_completeness(model, x_test, &explanation_method)?;

        Ok(ExplainabilityMetrics {
            feature_importance,
            local_consistency,
            global_stability,
            uncertainty_metrics,
            faithfulness,
            completeness,
        })
    }

    fn compute_feature_importance<M>(
        &self,
        model: &M,
        x_test: &Array2<F>,
        feature_names: &[String],
        method: &ExplanationMethod,
    ) -> Result<HashMap<String, F>>
    where
        M: Fn(&ArrayView2<F>) -> Array1<F>,
    {
        let n_features = x_test.ncols();
        let mut importance_scores = HashMap::new();

        match method {
            ExplanationMethod::Permutation => {
                let baseline_predictions = model(&x_test.view());
                let baseline_mean = baseline_predictions.mean().unwrap_or(F::zero());

                for (i, feature_name) in feature_names.iter().enumerate() {
                    if i >= n_features {
                        continue;
                    }

                    let mut perturbed_errors = Vec::new();

                    for _ in 0..self.n_perturbations {
                        let mut x_perturbed = x_test.clone();
                        self.permute_feature(&mut x_perturbed, i)?;

                        let perturbed_predictions = model(&x_perturbed.view());
                        let perturbed_mean = perturbed_predictions.mean().unwrap_or(F::zero());
                        let error = (baseline_mean - perturbed_mean).abs();
                        perturbed_errors.push(error);
                    }

                    let importance = perturbed_errors.iter().cloned().sum::<F>()
                        / F::from(perturbed_errors.len()).unwrap();
                    importance_scores.insert(feature_name.clone(), importance);
                }
            }
            ExplanationMethod::LIME => {
                importance_scores = self.compute_lime_importance(model, x_test, feature_names)?;
            }
            ExplanationMethod::SHAP => {
                importance_scores = self.compute_shap_importance(model, x_test, feature_names)?;
            }
            ExplanationMethod::GradientBased => {
                importance_scores =
                    self.compute_gradient_importance(model, x_test, feature_names)?;
            }
        }

        Ok(importance_scores)
    }

    fn evaluate_local_consistency<M>(
        &self,
        model: &M,
        x_test: &Array2<F>,
        method: &ExplanationMethod,
    ) -> Result<F>
    where
        M: Fn(&ArrayView2<F>) -> Array1<F>,
    {
        let nsamples = x_test.nrows().min(10);
        let mut consistency_scores = Vec::new();

        for i in 0..nsamples {
            let sample = x_test.row(i);
            let mut local_explanations = Vec::new();

            for _ in 0..10 {
                let mut perturbed_sample = sample.to_owned();
                self.add_noise_to_sample(&mut perturbed_sample)?;

                let explanation =
                    self.generate_local_explanation(model, &perturbed_sample.view(), method)?;
                local_explanations.push(explanation);
            }

            let consistency = self.compute_explanation_consistency(&local_explanations)?;
            consistency_scores.push(consistency);
        }

        let average_consistency = consistency_scores.iter().cloned().sum::<F>()
            / F::from(consistency_scores.len()).unwrap();

        Ok(average_consistency)
    }

    fn evaluate_global_stability<M>(
        &self,
        model: &M,
        x_test: &Array2<F>,
        method: &ExplanationMethod,
    ) -> Result<F>
    where
        M: Fn(&ArrayView2<F>) -> Array1<F>,
    {
        let mut global_explanations = Vec::new();

        for _ in 0..self.n_perturbations {
            let bootstrap_indices = self.bootstrap_sample_indices(x_test.nrows())?;
            let bootstrap_sample = self.bootstrap_data(x_test, &bootstrap_indices)?;

            let global_explanation =
                self.generate_global_explanation(model, &bootstrap_sample.view(), method)?;
            global_explanations.push(global_explanation);
        }

        let stability = self.compute_explanation_consistency(&global_explanations)?;
        Ok(stability)
    }

    fn compute_uncertainty_metrics<M>(
        &self,
        model: &M,
        x_test: &Array2<F>,
    ) -> Result<UncertaintyMetrics<F>>
    where
        M: Fn(&ArrayView2<F>) -> Array1<F>,
    {
        let mut predictions_ensemble = Vec::new();

        // The model closure is queried repeatedly on the same inputs; the
        // spread of these predictions is only informative for stochastic models.
        for _ in 0..50 {
            let predictions = model(&x_test.view());
            predictions_ensemble.push(predictions);
        }

        let epistemic_uncertainty = self.compute_epistemic_uncertainty(&predictions_ensemble)?;

        let aleatoric_uncertainty = self.compute_aleatoric_uncertainty(&predictions_ensemble)?;

        let total_uncertainty = epistemic_uncertainty + aleatoric_uncertainty;

        // Placeholder values for interval coverage and calibration error.
        let coverage = F::from(0.9).unwrap();
        let calibration_error = F::from(0.05).unwrap();

        Ok(UncertaintyMetrics {
            epistemic_uncertainty,
            aleatoric_uncertainty,
            total_uncertainty,
            coverage,
            calibration_error,
        })
    }

    fn evaluate_faithfulness<M>(
        &self,
        model: &M,
        x_test: &Array2<F>,
        method: &ExplanationMethod,
    ) -> Result<F>
    where
        M: Fn(&ArrayView2<F>) -> Array1<F>,
    {
        let nsamples = x_test.nrows().min(20);
        let mut faithfulness_scores = Vec::new();

        for i in 0..nsamples {
            let sample = x_test.row(i);
            let original_prediction = model(&sample.insert_axis(Axis(0)).view());

            let explanation = self.generate_local_explanation(model, &sample, method)?;

            let masked_sample = self.mask_important_features(&sample, &explanation, 5)?;
            let masked_prediction = model(&masked_sample.insert_axis(Axis(0)).view());

            let faithfulness = (original_prediction[0] - masked_prediction[0]).abs();
            faithfulness_scores.push(faithfulness);
        }

        let average_faithfulness = faithfulness_scores.iter().cloned().sum::<F>()
            / F::from(faithfulness_scores.len()).unwrap();

        Ok(average_faithfulness)
    }

    fn evaluate_completeness<M>(
        &self,
        model: &M,
        x_test: &Array2<F>,
        method: &ExplanationMethod,
    ) -> Result<F>
    where
        M: Fn(&ArrayView2<F>) -> Array1<F>,
    {
        let nsamples = x_test.nrows().min(20);
        let mut completeness_scores = Vec::new();

        for i in 0..nsamples {
            let sample = x_test.row(i);
            let original_prediction = model(&sample.insert_axis(Axis(0)).view());

            let explanation = self.generate_local_explanation(model, &sample, method)?;

            let important_only_sample =
                self.keep_important_features_only(&sample, &explanation, 5)?;
            let important_only_prediction =
                model(&important_only_sample.insert_axis(Axis(0)).view());

            let preservation =
                F::one() - (original_prediction[0] - important_only_prediction[0]).abs();
            completeness_scores.push(preservation);
        }

        let average_completeness = completeness_scores.iter().cloned().sum::<F>()
            / F::from(completeness_scores.len()).unwrap();

        Ok(average_completeness)
    }

    fn permute_feature(&self, data: &mut Array2<F>, feature_index: usize) -> Result<()> {
        if feature_index >= data.ncols() {
            return Err(MetricsError::InvalidInput(
                "Feature index out of bounds".to_string(),
            ));
        }

        let mut feature_values: Vec<F> = data.column(feature_index).to_vec();

        // Deterministic Fisher-Yates-style shuffle driven by a simple linear
        // congruential sequence.
        let mut state = feature_values.len().wrapping_add(12345);
        for i in (1..feature_values.len()).rev() {
            state = state.wrapping_mul(1103515245).wrapping_add(12345);
            let j = state % (i + 1);
            feature_values.swap(i, j);
        }

        for (i, &value) in feature_values.iter().enumerate() {
            data[[i, feature_index]] = value;
        }

        Ok(())
    }

    fn add_noise_to_sample(&self, sample: &mut Array1<F>) -> Result<()> {
        // A fixed, deterministic offset is used as a stand-in for random noise.
        for value in sample.iter_mut() {
            let noise = self.perturbation_strength * F::from(0.01).unwrap();
            *value = *value + noise;
        }
        Ok(())
    }

    fn generate_local_explanation<M>(
        &self,
        model: &M,
        sample: &ArrayView1<F>,
        _method: &ExplanationMethod,
    ) -> Result<Array1<F>>
    where
        M: Fn(&ArrayView2<F>) -> Array1<F>,
    {
        let n_features = sample.len();
        let mut importance = Array1::zeros(n_features);

        let baseline_pred = model(&sample.insert_axis(Axis(0)).view())[0];

        for i in 0..n_features {
            let mut perturbed = sample.to_owned();
            perturbed[i] = perturbed[i] + self.perturbation_strength;

            let perturbed_pred = model(&perturbed.insert_axis(Axis(0)).view())[0];
            importance[i] = (perturbed_pred - baseline_pred).abs();
        }

        Ok(importance)
    }

    fn generate_global_explanation<M>(
        &self,
        model: &M,
        data: &ArrayView2<F>,
        method: &ExplanationMethod,
    ) -> Result<Array1<F>>
    where
        M: Fn(&ArrayView2<F>) -> Array1<F>,
    {
        let n_features = data.ncols();
        let mut global_importance = Array1::zeros(n_features);

        for i in 0..data.nrows() {
            let sample = data.row(i);
            let local_explanation = self.generate_local_explanation(model, &sample, method)?;
            global_importance = global_importance + local_explanation;
        }

        global_importance = global_importance / F::from(data.nrows()).unwrap();
        Ok(global_importance)
    }

    fn compute_explanation_consistency(&self, explanations: &[Array1<F>]) -> Result<F> {
        if explanations.len() < 2 {
            return Ok(F::one());
        }

        let mut correlations = Vec::new();

        for i in 0..explanations.len() {
            for j in (i + 1)..explanations.len() {
                let correlation = self.compute_correlation(&explanations[i], &explanations[j])?;
                correlations.push(correlation);
            }
        }

        let average_correlation =
            correlations.iter().cloned().sum::<F>() / F::from(correlations.len()).unwrap();

        Ok(average_correlation)
    }

    fn compute_correlation(&self, x: &Array1<F>, y: &Array1<F>) -> Result<F> {
        if x.len() != y.len() {
            return Err(MetricsError::InvalidInput(
                "Arrays must have the same length".to_string(),
            ));
        }

        let mean_x = x.mean().unwrap_or(F::zero());
        let mean_y = y.mean().unwrap_or(F::zero());

        let numerator: F = x
            .iter()
            .zip(y.iter())
            .map(|(&xi, &yi)| (xi - mean_x) * (yi - mean_y))
            .sum();

        let sum_sq_x: F = x.iter().map(|&xi| (xi - mean_x) * (xi - mean_x)).sum();
        let sum_sq_y: F = y.iter().map(|&yi| (yi - mean_y) * (yi - mean_y)).sum();

        let denominator = (sum_sq_x * sum_sq_y).sqrt();

        if denominator == F::zero() {
            Ok(F::zero())
        } else {
            Ok(numerator / denominator)
        }
    }

    fn bootstrap_sample_indices(&self, nsamples: usize) -> Result<Vec<usize>> {
        // Deterministic stand-in for random bootstrap resampling: every index
        // is included exactly once.
        let mut indices = Vec::with_capacity(nsamples);
        for i in 0..nsamples {
            indices.push(i % nsamples);
        }
        Ok(indices)
    }

    fn bootstrap_data(&self, data: &Array2<F>, indices: &[usize]) -> Result<Array2<F>> {
        let mut bootstrap_data = Array2::zeros((indices.len(), data.ncols()));

        for (i, &idx) in indices.iter().enumerate() {
            for j in 0..data.ncols() {
                bootstrap_data[[i, j]] = data[[idx, j]];
            }
        }

        Ok(bootstrap_data)
    }

    fn compute_epistemic_uncertainty(&self, predictions: &[Array1<F>]) -> Result<F> {
        // At least two prediction sets are needed for the unbiased variance below.
        if predictions.len() < 2 {
            return Ok(F::zero());
        }

        let n_predictions = predictions.len();
        let nsamples = predictions[0].len();

        let mut variances = Vec::new();

        for i in 0..nsamples {
            let sample_predictions: Vec<F> = predictions.iter().map(|pred| pred[i]).collect();

            let mean =
                sample_predictions.iter().cloned().sum::<F>() / F::from(n_predictions).unwrap();
            let variance = sample_predictions
                .iter()
                .map(|&pred| (pred - mean) * (pred - mean))
                .sum::<F>()
                / F::from(n_predictions - 1).unwrap();

            variances.push(variance);
        }

        let average_variance =
            variances.iter().cloned().sum::<F>() / F::from(variances.len()).unwrap();
        Ok(average_variance.sqrt())
    }

    fn compute_aleatoric_uncertainty(&self, _predictions: &[Array1<F>]) -> Result<F> {
        // Placeholder: a fixed aleatoric uncertainty is returned because the
        // model interface exposes point predictions only.
        Ok(F::from(0.1).unwrap())
    }

    fn mask_important_features(
        &self,
        sample: &ArrayView1<F>,
        explanation: &Array1<F>,
        k: usize,
    ) -> Result<Array1<F>> {
        let mut masked = sample.to_owned();

        let mut importance_indices: Vec<(usize, F)> = explanation
            .iter()
            .enumerate()
            .map(|(i, &imp)| (i, imp))
            .collect();
        importance_indices
            .sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));

        for i in 0..k.min(importance_indices.len()) {
            let feature_idx = importance_indices[i].0;
            masked[feature_idx] = F::zero();
        }

        Ok(masked)
    }

    fn keep_important_features_only(
        &self,
        sample: &ArrayView1<F>,
        explanation: &Array1<F>,
        k: usize,
    ) -> Result<Array1<F>> {
        let mut filtered = Array1::zeros(sample.len());

        let mut importance_indices: Vec<(usize, F)> = explanation
            .iter()
            .enumerate()
            .map(|(i, &imp)| (i, imp))
            .collect();
        importance_indices
            .sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));

        for i in 0..k.min(importance_indices.len()) {
            let feature_idx = importance_indices[i].0;
            filtered[feature_idx] = sample[feature_idx];
        }

        Ok(filtered)
    }

    fn compute_lime_importance<M>(
        &self,
        model: &M,
        x_test: &Array2<F>,
        feature_names: &[String],
    ) -> Result<HashMap<String, F>>
    where
        M: Fn(&ArrayView2<F>) -> Array1<F>,
    {
        if x_test.is_empty() || feature_names.is_empty() {
            return Err(MetricsError::InvalidInput(
                "Empty input data or feature names".to_string(),
            ));
        }

        if x_test.ncols() != feature_names.len() {
            return Err(MetricsError::InvalidInput(
                "Number of features doesn't match feature names length".to_string(),
            ));
        }

        let mut importance_scores = HashMap::new();
        let nsamples = std::cmp::min(1000, self.n_perturbations);

        for instance in x_test.axis_iter(Axis(0)) {
            let instance_importance =
                self.compute_lime_for_instance(model, &instance, feature_names, nsamples)?;

            for (feature_name, importance) in instance_importance {
                let current_score = importance_scores
                    .get(&feature_name)
                    .copied()
                    .unwrap_or(F::zero());
                importance_scores.insert(
                    feature_name,
                    current_score + importance / F::from(x_test.nrows()).unwrap(),
                );
            }
        }

        Ok(importance_scores)
    }

    fn compute_lime_for_instance<M>(
        &self,
        model: &M,
        instance: &ArrayView1<F>,
        feature_names: &[String],
        nsamples: usize,
    ) -> Result<HashMap<String, F>>
    where
        M: Fn(&ArrayView2<F>) -> Array1<F>,
    {
        let _n_features = instance.len();

        let (perturbed_samples, weights) = self.generate_lime_samples(instance, nsamples)?;

        let predictions = model(&perturbed_samples.view());

        let coefficients =
            self.fit_interpretable_model(&perturbed_samples, &predictions, &weights)?;

        let mut importance = HashMap::new();
        for (i, name) in feature_names.iter().enumerate() {
            if i < coefficients.len() {
                importance.insert(name.clone(), coefficients[i].abs());
            }
        }

        Ok(importance)
    }

    fn generate_lime_samples(
        &self,
        instance: &ArrayView1<F>,
        nsamples: usize,
    ) -> Result<(Array2<F>, Array1<F>)> {
        let n_features = instance.len();
        let mut perturbed_samples = Array2::zeros((nsamples, n_features));
        let mut weights = Array1::zeros(nsamples);

        let feature_mean = instance.mean().unwrap_or(F::zero());
        let feature_std = {
            let variance = instance
                .iter()
                .map(|&x| (x - feature_mean) * (x - feature_mean))
                .sum::<F>()
                / F::from(n_features).unwrap();
            variance.sqrt()
        };

        for i in 0..nsamples {
            let mut perturbed_instance = instance.to_owned();
            let mut distance_sum = F::zero();

            for j in 0..n_features {
                let perturbation_factor = F::from((i + j) as f64 / (nsamples * n_features) as f64)
                    .unwrap()
                    - F::from(0.5).unwrap();
                let perturbation = perturbation_factor * self.perturbation_strength * feature_std;

                perturbed_instance[j] = instance[j] + perturbation;
                distance_sum = distance_sum + perturbation.abs();
            }

            for j in 0..n_features {
                perturbed_samples[[i, j]] = perturbed_instance[j];
            }

            let distance = distance_sum / F::from(n_features).unwrap();
            weights[i] = (-distance * F::from(2.0).unwrap()).exp();
        }

        Ok((perturbed_samples, weights))
    }

    fn fit_interpretable_model(
        &self,
        samples: &Array2<F>,
        targets: &Array1<F>,
        weights: &Array1<F>,
    ) -> Result<Vec<F>> {
        let nsamples = samples.nrows();
        let n_features = samples.ncols();

        if nsamples == 0 || n_features == 0 {
            return Ok(vec![F::zero(); n_features]);
        }

        // Accumulate the weighted normal equations X^T W X and X^T W y.
        let mut xtx = Array2::zeros((n_features, n_features));
        let mut xty = Array1::zeros(n_features);

        for i in 0..nsamples {
            let weight = weights[i];
            let target = targets[i];

            for j in 0..n_features {
                let x_ij = samples[[i, j]];

                xty[j] = xty[j] + weight * x_ij * target;

                for k in 0..n_features {
                    let x_ik = samples[[i, k]];
                    xtx[[j, k]] = xtx[[j, k]] + weight * x_ij * x_ik;
                }
            }
        }

        // Ridge regularization keeps the system well conditioned.
        let regularization = F::from(1e-6).unwrap();
        for i in 0..n_features {
            xtx[[i, i]] = xtx[[i, i]] + regularization;
        }

        let coefficients = self.solve_linear_system(&xtx, &xty)?;

        Ok(coefficients)
    }

    fn solve_linear_system(&self, a: &Array2<F>, b: &Array1<F>) -> Result<Vec<F>> {
        let n = a.nrows();
        if n != a.ncols() || n != b.len() {
            return Err(MetricsError::InvalidInput(
                "Matrix dimensions mismatch".to_string(),
            ));
        }

        // Build the augmented matrix [A | b].
        let mut aug = Array2::zeros((n, n + 1));
        for i in 0..n {
            for j in 0..n {
                aug[[i, j]] = a[[i, j]];
            }
            aug[[i, n]] = b[i];
        }

        // Gaussian elimination with partial pivoting.
        for i in 0..n {
            let mut max_row = i;
            for k in (i + 1)..n {
                if aug[[k, i]].abs() > aug[[max_row, i]].abs() {
                    max_row = k;
                }
            }

            if max_row != i {
                for j in 0..=n {
                    let temp = aug[[i, j]];
                    aug[[i, j]] = aug[[max_row, j]];
                    aug[[max_row, j]] = temp;
                }
            }

            // Treat a (near-)singular system as having no usable solution.
            if aug[[i, i]].abs() < F::from(1e-10).unwrap() {
                return Ok(vec![F::zero(); n]);
            }

            for k in (i + 1)..n {
                let factor = aug[[k, i]] / aug[[i, i]];
                for j in i..=n {
                    aug[[k, j]] = aug[[k, j]] - factor * aug[[i, j]];
                }
            }
        }

        // Back substitution.
        let mut x = vec![F::zero(); n];
        for i in (0..n).rev() {
            x[i] = aug[[i, n]];
            for j in (i + 1)..n {
                x[i] = x[i] - aug[[i, j]] * x[j];
            }
            x[i] = x[i] / aug[[i, i]];
        }

        Ok(x)
    }

    fn compute_shap_importance<M>(
        &self,
        model: &M,
        x_test: &Array2<F>,
        feature_names: &[String],
    ) -> Result<HashMap<String, F>>
    where
        M: Fn(&ArrayView2<F>) -> Array1<F>,
    {
        if x_test.is_empty() || feature_names.is_empty() {
            return Err(MetricsError::InvalidInput(
                "Empty input data or feature names".to_string(),
            ));
        }

        if x_test.ncols() != feature_names.len() {
            return Err(MetricsError::InvalidInput(
                "Number of features doesn't match feature names length".to_string(),
            ));
        }

        let mut importance_scores = HashMap::new();

        let background_mean = self.compute_background_mean(x_test)?;

        for instance in x_test.axis_iter(Axis(0)) {
            let instance_importance =
                self.compute_shap_for_instance(model, &instance, &background_mean, feature_names)?;

            for (feature_name, importance) in instance_importance {
                let current_score = importance_scores
                    .get(&feature_name)
                    .copied()
                    .unwrap_or(F::zero());
                importance_scores.insert(
                    feature_name,
                    current_score + importance / F::from(x_test.nrows()).unwrap(),
                );
            }
        }

        Ok(importance_scores)
    }

    fn compute_shap_for_instance<M>(
        &self,
        model: &M,
        instance: &ArrayView1<F>,
        background_mean: &Array1<F>,
        feature_names: &[String],
    ) -> Result<HashMap<String, F>>
    where
        M: Fn(&ArrayView2<F>) -> Array1<F>,
    {
        let n_features = instance.len();

        let max_coalitions = std::cmp::min(
            2_usize.pow(std::cmp::min(n_features, 10) as u32),
            self.n_perturbations,
        );

        let shapley_values = self.compute_shapley_values_approximation(
            model,
            instance,
            background_mean,
            max_coalitions,
        )?;

        let mut importance = HashMap::new();
        for (i, name) in feature_names.iter().enumerate() {
            if i < shapley_values.len() {
                importance.insert(name.clone(), shapley_values[i].abs());
            }
        }

        Ok(importance)
    }

    fn compute_background_mean(&self, xdata: &Array2<F>) -> Result<Array1<F>> {
        if xdata.is_empty() {
            return Err(MetricsError::InvalidInput(
                "Empty data for background computation".to_string(),
            ));
        }

        let n_features = xdata.ncols();
        let mut background = Array1::zeros(n_features);

        for j in 0..n_features {
            let column_sum: F = xdata.column(j).iter().cloned().sum();
            background[j] = column_sum / F::from(xdata.nrows()).unwrap();
        }

        Ok(background)
    }

    fn compute_shapley_values_approximation<M>(
        &self,
        model: &M,
        instance: &ArrayView1<F>,
        background: &Array1<F>,
        max_coalitions: usize,
    ) -> Result<Vec<F>>
    where
        M: Fn(&ArrayView2<F>) -> Array1<F>,
    {
        let n_features = instance.len();
        let mut shapley_values = vec![F::zero(); n_features];

        let baseline_input =
            Array2::from_shape_vec((1, n_features), background.to_vec()).map_err(|_| {
                MetricsError::InvalidInput("Failed to create baseline array".to_string())
            })?;
        let baseline_pred = model(&baseline_input.view())[0];

        let full_input = Array2::from_shape_vec((1, n_features), instance.to_vec())
            .map_err(|_| MetricsError::InvalidInput("Failed to create full array".to_string()))?;
        let full_pred = model(&full_input.view())[0];

        let nsamples = std::cmp::min(max_coalitions, 1000);

        for i in 0..n_features {
            let mut marginal_contributions = Vec::new();

            for sample_idx in 0..nsamples {
                let coalition = self.generate_random_coalition(n_features, i, sample_idx);

                let with_i =
                    self.create_coalition_input(instance, background, &coalition, Some(i))?;
                let pred_with_i = model(&with_i.view())[0];

                let without_i =
                    self.create_coalition_input(instance, background, &coalition, None)?;
                let pred_without_i = model(&without_i.view())[0];

                let marginal_contrib = pred_with_i - pred_without_i;
                marginal_contributions.push(marginal_contrib);
            }

            if !marginal_contributions.is_empty() {
                let sum: F = marginal_contributions.iter().cloned().sum();
                shapley_values[i] = sum / F::from(marginal_contributions.len()).unwrap();
            }
        }

        // Rescale so the values sum to the difference between the full and
        // baseline predictions (the efficiency property).
        let total_difference = full_pred - baseline_pred;
        let shapley_sum: F = shapley_values.iter().cloned().sum();

        if shapley_sum != F::zero() {
            let normalization_factor = total_difference / shapley_sum;
            for val in shapley_values.iter_mut() {
                *val = *val * normalization_factor;
            }
        }

        Ok(shapley_values)
    }

    fn generate_random_coalition(
        &self,
        n_features: usize,
        target_feature: usize,
        seed: usize,
    ) -> Vec<bool> {
        let mut coalition = vec![false; n_features];

        // Simple linear congruential generator seeded by the sample index.
        let mut pseudo_random = seed;

        for i in 0..n_features {
            if i != target_feature {
                pseudo_random = pseudo_random.wrapping_mul(1103515245).wrapping_add(12345);
                coalition[i] = pseudo_random % 2 == 0;
            }
        }

        coalition
    }

    fn create_coalition_input(
        &self,
        instance: &ArrayView1<F>,
        background: &Array1<F>,
        coalition: &[bool],
        include_target: Option<usize>,
    ) -> Result<Array2<F>> {
        let n_features = instance.len();
        let mut coalition_input = background.clone();

        for (i, &in_coalition) in coalition.iter().enumerate() {
            if in_coalition {
                coalition_input[i] = instance[i];
            }
        }

        if let Some(target_idx) = include_target {
            if target_idx < n_features {
                coalition_input[target_idx] = instance[target_idx];
            }
        }

        Array2::from_shape_vec((1, n_features), coalition_input.to_vec()).map_err(|_| {
            MetricsError::InvalidInput("Failed to create coalition input array".to_string())
        })
    }

    fn compute_gradient_importance<M>(
        &self,
        model: &M,
        x_test: &Array2<F>,
        feature_names: &[String],
    ) -> Result<HashMap<String, F>>
    where
        M: Fn(&ArrayView2<F>) -> Array1<F>,
    {
        if x_test.is_empty() || feature_names.is_empty() {
            return Err(MetricsError::InvalidInput(
                "Empty input data or feature names".to_string(),
            ));
        }

        if x_test.ncols() != feature_names.len() {
            return Err(MetricsError::InvalidInput(
                "Number of features doesn't match feature names length".to_string(),
            ));
        }

        let mut importance_scores = HashMap::new();

        for instance in x_test.axis_iter(Axis(0)) {
            let instance_importance =
                self.compute_gradient_for_instance(model, &instance, feature_names)?;

            for (feature_name, importance) in instance_importance {
                let current_score = importance_scores
                    .get(&feature_name)
                    .copied()
                    .unwrap_or(F::zero());
                importance_scores.insert(
                    feature_name,
                    current_score + importance / F::from(x_test.nrows()).unwrap(),
                );
            }
        }

        Ok(importance_scores)
    }

    fn compute_gradient_for_instance<M>(
        &self,
        model: &M,
        instance: &ArrayView1<F>,
        feature_names: &[String],
    ) -> Result<HashMap<String, F>>
    where
        M: Fn(&ArrayView2<F>) -> Array1<F>,
    {
        let n_features = instance.len();

        let gradients = self.compute_numerical_gradients(model, instance)?;

        let saliency_map = self.compute_saliency_map(&gradients, instance)?;
        let integrated_gradients = self.compute_integrated_gradients(model, instance)?;
        let gradient_times_input = self.compute_gradient_times_input(&gradients, instance)?;

        let mut importance = HashMap::new();
        for (i, name) in feature_names.iter().enumerate() {
            if i < n_features {
                let combined_importance =
                    (saliency_map[i] + integrated_gradients[i] + gradient_times_input[i])
                        / F::from(3.0).unwrap();
                importance.insert(name.clone(), combined_importance.abs());
            }
        }

        Ok(importance)
    }

    fn compute_numerical_gradients<M>(&self, model: &M, instance: &ArrayView1<F>) -> Result<Vec<F>>
    where
        M: Fn(&ArrayView2<F>) -> Array1<F>,
    {
        let n_features = instance.len();
        let mut gradients = vec![F::zero(); n_features];

        let epsilon_base = F::from(1e-5).unwrap();

        let baseline_input =
            Array2::from_shape_vec((1, n_features), instance.to_vec()).map_err(|_| {
                MetricsError::InvalidInput("Failed to create baseline array".to_string())
            })?;
        let _baseline_pred = model(&baseline_input.view())[0];

        for i in 0..n_features {
            let feature_magnitude = instance[i].abs().max(F::from(1.0).unwrap());
            let epsilon = epsilon_base * feature_magnitude;

            let mut forward_instance = instance.to_owned();
            forward_instance[i] = forward_instance[i] + epsilon;
            let forward_input = Array2::from_shape_vec((1, n_features), forward_instance.to_vec())
                .map_err(|_| {
                    MetricsError::InvalidInput("Failed to create forward array".to_string())
                })?;
            let forward_pred = model(&forward_input.view())[0];

            let mut backward_instance = instance.to_owned();
            backward_instance[i] = backward_instance[i] - epsilon;
            let backward_input =
                Array2::from_shape_vec((1, n_features), backward_instance.to_vec()).map_err(
                    |_| MetricsError::InvalidInput("Failed to create backward array".to_string()),
                )?;
            let backward_pred = model(&backward_input.view())[0];

            // Central-difference approximation of the partial derivative.
            gradients[i] = (forward_pred - backward_pred) / (F::from(2.0).unwrap() * epsilon);
        }

        Ok(gradients)
    }

    fn compute_saliency_map(&self, gradients: &[F], _instance: &ArrayView1<F>) -> Result<Vec<F>> {
        // Saliency is the absolute gradient magnitude per feature.
        Ok(gradients.iter().map(|&g| g.abs()).collect())
    }

    fn compute_integrated_gradients<M>(&self, model: &M, instance: &ArrayView1<F>) -> Result<Vec<F>>
    where
        M: Fn(&ArrayView2<F>) -> Array1<F>,
    {
        let n_features = instance.len();
        let mut integrated_grads = vec![F::zero(); n_features];

        let baseline = Array1::zeros(n_features);
        let n_steps = 50;

        for step in 0..n_steps {
            let alpha = F::from(step as f64).unwrap() / F::from(n_steps as f64).unwrap();

            let mut interpolated = Array1::zeros(n_features);
            for i in 0..n_features {
                interpolated[i] = baseline[i] + alpha * (instance[i] - baseline[i]);
            }

            let step_gradients = self.compute_numerical_gradients(model, &interpolated.view())?;

            for i in 0..n_features {
                integrated_grads[i] =
                    integrated_grads[i] + step_gradients[i] * (instance[i] - baseline[i]);
            }
        }

        for grad in integrated_grads.iter_mut() {
            *grad = *grad / F::from(n_steps).unwrap();
        }

        Ok(integrated_grads)
    }

    fn compute_gradient_times_input(
        &self,
        gradients: &[F],
        instance: &ArrayView1<F>,
    ) -> Result<Vec<F>> {
        let mut grad_times_input = Vec::new();

        for (i, &grad) in gradients.iter().enumerate() {
            if i < instance.len() {
                grad_times_input.push(grad * instance[i]);
            }
        }

        Ok(grad_times_input)
    }
}

/// Explanation method used to attribute feature importance.
#[derive(Debug, Clone)]
pub enum ExplanationMethod {
    /// Permutation feature importance.
    Permutation,
    /// Local interpretable model-agnostic explanations.
    LIME,
    /// Shapley additive explanations (sampling approximation).
    SHAP,
    /// Gradient-based attributions (saliency, integrated gradients, gradient x input).
    GradientBased,
}

/// Combines the individual explainability metrics into a single weighted
/// interpretability score.
#[allow(dead_code)]
pub fn compute_interpretability_score<F: Float + std::iter::Sum>(
    explainability_metrics: &ExplainabilityMetrics<F>,
) -> F {
    let feature_importance_score = if explainability_metrics.feature_importance.is_empty() {
        F::zero()
    } else {
        explainability_metrics
            .feature_importance
            .values()
            .cloned()
            .sum::<F>()
            / F::from(explainability_metrics.feature_importance.len()).unwrap()
    };

    let weights = [
        F::from(0.25).unwrap(), // feature importance
        F::from(0.2).unwrap(),  // local consistency
        F::from(0.2).unwrap(),  // global stability
        F::from(0.15).unwrap(), // faithfulness
        F::from(0.15).unwrap(), // completeness
        F::from(0.05).unwrap(), // certainty (1 - total uncertainty)
    ];

    let scores = [
        feature_importance_score,
        explainability_metrics.local_consistency,
        explainability_metrics.global_stability,
        explainability_metrics.faithfulness,
        explainability_metrics.completeness,
        F::one() - explainability_metrics.uncertainty_metrics.total_uncertainty,
    ];

    weights
        .iter()
        .zip(scores.iter())
        .map(|(&w, &s)| w * s)
        .sum()
}

#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    #[test]
    fn test_explainability_evaluator_creation() {
        let evaluator = ExplainabilityEvaluator::<f64>::new()
            .with_perturbations(50)
            .with_perturbation_strength(0.05)
            .with_importance_threshold(0.02);

        assert_eq!(evaluator.n_perturbations, 50);
        assert_eq!(evaluator.perturbation_strength, 0.05);
        assert_eq!(evaluator.importance_threshold, 0.02);
    }

    #[test]
    fn test_correlation_computation() {
        let evaluator = ExplainabilityEvaluator::<f64>::new();

        // Perfectly linearly related inputs should give correlation 1.
        let x = array![1.0, 2.0, 3.0, 4.0, 5.0];
        let y = array![2.0, 4.0, 6.0, 8.0, 10.0];

        let correlation = evaluator.compute_correlation(&x, &y).unwrap();
        assert!((correlation - 1.0).abs() < 1e-10);
    }

    #[test]
    fn test_permutation_feature() {
        let evaluator = ExplainabilityEvaluator::<f64>::new();
        let mut data = array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]];
        let original_data = data.clone();

        evaluator.permute_feature(&mut data, 1).unwrap();

        // Untouched columns are unchanged; the permuted column keeps its length.
        assert_eq!(data.column(0), original_data.column(0));
        assert_eq!(data.column(2), original_data.column(2));
        assert_eq!(data.column(1).len(), original_data.column(1).len());
    }

    #[test]
    fn test_interpretability_score() {
        let mut feature_importance = HashMap::new();
        feature_importance.insert("feature1".to_string(), 0.5);
        feature_importance.insert("feature2".to_string(), 0.3);

        let metrics = ExplainabilityMetrics {
            feature_importance,
            local_consistency: 0.8,
            global_stability: 0.7,
            uncertainty_metrics: UncertaintyMetrics {
                epistemic_uncertainty: 0.1,
                aleatoric_uncertainty: 0.05,
                total_uncertainty: 0.15,
                coverage: 0.95,
                calibration_error: 0.02,
            },
            faithfulness: 0.9,
            completeness: 0.85,
        };

        let score = compute_interpretability_score(&metrics);
        assert!(score > 0.0 && score <= 1.0);
    }

    #[test]
    fn test_bootstrap_sampling() {
        let evaluator = ExplainabilityEvaluator::<f64>::new();
        let indices = evaluator.bootstrap_sample_indices(10).unwrap();

        assert_eq!(indices.len(), 10);
        assert!(indices.iter().all(|&i| i < 10));
    }
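
    // Added as a minimal sketch of the gradient helper: for a linear model the
    // central-difference gradients should recover the coefficients almost
    // exactly. The closure and coefficients here are illustrative assumptions,
    // not fixtures from the original suite.
    #[test]
    fn test_numerical_gradients_linear_model() {
        let evaluator = ExplainabilityEvaluator::<f64>::new();

        let model = |x: &ArrayView2<f64>| -> Array1<f64> {
            x.axis_iter(Axis(0)).map(|row| 3.0 * row[0] - row[1]).collect()
        };

        let instance = array![1.0, 2.0];
        let gradients = evaluator
            .compute_numerical_gradients(&model, &instance.view())
            .unwrap();

        assert!((gradients[0] - 3.0).abs() < 1e-3);
        assert!((gradients[1] + 1.0).abs() < 1e-3);
    }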

    #[test]
    fn test_mask_important_features() {
        let evaluator = ExplainabilityEvaluator::<f64>::new();
        let sample = array![1.0, 2.0, 3.0, 4.0, 5.0];
        let explanation = array![0.1, 0.5, 0.2, 0.8, 0.3];

        let masked = evaluator
            .mask_important_features(&sample.view(), &explanation, 2)
            .unwrap();

        // The two most important features (indices 3 and 1) are zeroed out.
        assert_eq!(masked[3], 0.0);
        assert_eq!(masked[1], 0.0);
        // The remaining features are untouched.
        assert_eq!(masked[0], 1.0);
        assert_eq!(masked[2], 3.0);
        assert_eq!(masked[4], 5.0);
    }
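
    // Added as a minimal end-to-end usage sketch of the public evaluator API
    // with the permutation method. The toy linear model, data, and feature
    // names are illustrative assumptions, not fixtures from the original suite.
    #[test]
    fn test_evaluate_explainability_with_linear_model() {
        let evaluator = ExplainabilityEvaluator::<f64>::new().with_perturbations(5);

        // Toy model: a fixed weighted sum of the two input features.
        let model = |x: &ArrayView2<f64>| -> Array1<f64> {
            x.axis_iter(Axis(0))
                .map(|row| 2.0 * row[0] + 0.5 * row[1])
                .collect()
        };

        let x_test = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];
        let feature_names = vec!["f0".to_string(), "f1".to_string()];

        let metrics = evaluator
            .evaluate_explainability(
                &model,
                &x_test,
                &feature_names,
                ExplanationMethod::Permutation,
            )
            .unwrap();

        // Every feature receives a finite importance score and the combined
        // interpretability score is finite as well.
        assert_eq!(metrics.feature_importance.len(), 2);
        assert!(metrics.feature_importance.values().all(|v| v.is_finite()));
        assert!(compute_interpretability_score(&metrics).is_finite());
    }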
}