scirs2_metrics/explainability/mod.rs

1//! Model explainability and interpretability metrics
2//!
3//! This module provides metrics for evaluating model explainability, interpretability,
4//! and trustworthiness. These metrics help assess how well a model's predictions
5//! can be understood and trusted by humans.
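//!
//! # Example
//!
//! A minimal usage sketch (marked `ignore` since it assumes the evaluator types are in
//! scope and uses a toy closure in place of a trained model):
//!
//! ```ignore
//! use scirs2_core::ndarray::{array, Array1, ArrayView2, Axis};
//!
//! // Any closure implementing Fn(&ArrayView2<f64>) -> Array1<f64> can serve as the model.
//! let model = |x: &ArrayView2<f64>| -> Array1<f64> { x.sum_axis(Axis(1)) };
//!
//! let x_test = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];
//! let feature_names = vec!["f0".to_string(), "f1".to_string()];
//!
//! let evaluator = ExplainabilityEvaluator::<f64>::new().with_perturbations(20);
//! let metrics = evaluator
//!     .evaluate_explainability(&model, &x_test, &feature_names, ExplanationMethod::Permutation)
//!     .unwrap();
//! let _score = compute_interpretability_score(&metrics);
//! ```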
6
7use crate::error::{MetricsError, Result};
8use scirs2_core::ndarray::{Array1, Array2, ArrayView1, ArrayView2, Axis};
9use scirs2_core::numeric::Float;
10use statrs::statistics::Statistics;
11use std::collections::HashMap;
12
13pub mod feature_importance;
14pub mod global_explanations;
15pub mod local_explanations;
16pub mod uncertainty_quantification;
17
18pub use feature_importance::*;
19pub use global_explanations::*;
20pub use local_explanations::*;
21pub use uncertainty_quantification::*;
22
23/// Explainability metrics suite
24#[derive(Debug, Clone)]
25pub struct ExplainabilityMetrics<F: Float> {
26    /// Feature importance scores
27    pub feature_importance: HashMap<String, F>,
28    /// Local explanation consistency
29    pub local_consistency: F,
30    /// Global explanation stability
31    pub global_stability: F,
32    /// Model uncertainty measures
33    pub uncertainty_metrics: UncertaintyMetrics<F>,
34    /// Faithfulness scores
35    pub faithfulness: F,
36    /// Completeness scores
37    pub completeness: F,
38}
39
40/// Uncertainty quantification metrics
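///
/// In this simplified treatment, total predictive uncertainty is the additive
/// combination `total_uncertainty = epistemic_uncertainty + aleatoric_uncertainty`.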
41#[derive(Debug, Clone)]
42pub struct UncertaintyMetrics<F: Float> {
43    /// Epistemic uncertainty (model uncertainty)
44    pub epistemic_uncertainty: F,
45    /// Aleatoric uncertainty (data uncertainty)
46    pub aleatoric_uncertainty: F,
47    /// Total uncertainty
48    pub total_uncertainty: F,
49    /// Confidence interval coverage
50    pub coverage: F,
51    /// Calibration error
52    pub calibration_error: F,
53}
54
55/// Explainability evaluator
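///
/// # Example
///
/// A builder-style configuration sketch (marked `ignore`; it assumes the type is in scope):
///
/// ```ignore
/// let evaluator = ExplainabilityEvaluator::<f64>::new()
///     .with_perturbations(50)
///     .with_perturbation_strength(0.05)
///     .with_importance_threshold(0.02);
/// ```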
56pub struct ExplainabilityEvaluator<F: Float> {
57    /// Number of perturbations for stability testing
58    pub n_perturbations: usize,
59    /// Perturbation strength
60    pub perturbation_strength: F,
61    /// Feature importance threshold
62    pub importance_threshold: F,
63    /// Confidence level for uncertainty quantification
64    pub confidence_level: F,
65}
66
67impl<
68        F: Float
69            + scirs2_core::numeric::FromPrimitive
70            + std::iter::Sum
71            + scirs2_core::ndarray::ScalarOperand,
72    > Default for ExplainabilityEvaluator<F>
73{
74    fn default() -> Self {
75        Self::new()
76    }
77}
78
79impl<
80        F: Float
81            + scirs2_core::numeric::FromPrimitive
82            + std::iter::Sum
83            + scirs2_core::ndarray::ScalarOperand,
84    > ExplainabilityEvaluator<F>
85{
86    /// Create new explainability evaluator
87    pub fn new() -> Self {
88        Self {
89            n_perturbations: 100,
90            perturbation_strength: F::from(0.1).unwrap(),
91            importance_threshold: F::from(0.01).unwrap(),
92            confidence_level: F::from(0.95).unwrap(),
93        }
94    }
95
96    /// Set number of perturbations for stability testing
97    pub fn with_perturbations(mut self, n: usize) -> Self {
98        self.n_perturbations = n;
99        self
100    }
101
102    /// Set perturbation strength
103    pub fn with_perturbation_strength(mut self, strength: F) -> Self {
104        self.perturbation_strength = strength;
105        self
106    }
107
108    /// Set feature importance threshold
109    pub fn with_importance_threshold(mut self, threshold: F) -> Self {
110        self.importance_threshold = threshold;
111        self
112    }
113
114    /// Evaluate model explainability comprehensively
115    pub fn evaluate_explainability<M>(
116        &self,
117        model: &M,
118        x_test: &Array2<F>,
119        feature_names: &[String],
120        explanation_method: ExplanationMethod,
121    ) -> Result<ExplainabilityMetrics<F>>
122    where
123        M: Fn(&ArrayView2<F>) -> Array1<F>,
124    {
125        // Compute feature importance
126        let feature_importance =
127            self.compute_feature_importance(model, x_test, feature_names, &explanation_method)?;
128
129        // Evaluate local explanation consistency
130        let local_consistency =
131            self.evaluate_local_consistency(model, x_test, &explanation_method)?;
132
133        // Evaluate global explanation stability
134        let global_stability =
135            self.evaluate_global_stability(model, x_test, &explanation_method)?;
136
137        // Compute uncertainty metrics
138        let uncertainty_metrics = self.compute_uncertainty_metrics(model, x_test)?;
139
140        // Evaluate faithfulness
141        let faithfulness = self.evaluate_faithfulness(model, x_test, &explanation_method)?;
142
143        // Evaluate completeness
144        let completeness = self.evaluate_completeness(model, x_test, &explanation_method)?;
145
146        Ok(ExplainabilityMetrics {
147            feature_importance,
148            local_consistency,
149            global_stability,
150            uncertainty_metrics,
151            faithfulness,
152            completeness,
153        })
154    }
155
156    /// Compute feature importance using specified method
157    fn compute_feature_importance<M>(
158        &self,
159        model: &M,
160        x_test: &Array2<F>,
161        feature_names: &[String],
162        method: &ExplanationMethod,
163    ) -> Result<HashMap<String, F>>
164    where
165        M: Fn(&ArrayView2<F>) -> Array1<F>,
166    {
167        let n_features = x_test.ncols();
168        let mut importance_scores = HashMap::new();
169
170        match method {
171            ExplanationMethod::Permutation => {
172                // Permutation importance
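                // For each feature j, the importance reported here is the mean absolute
                // shift in the average prediction when column j is permuted:
                //     PI_j = mean_k | mean(f(X)) - mean(f(X with column j permuted)) |
                // This is a simplified, label-free variant; classical permutation
                // importance instead measures the drop in a scoring metric against labels.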
173                let baseline_predictions = model(&x_test.view());
174                let baseline_mean = baseline_predictions.mean().unwrap_or(F::zero());
175
176                for (i, feature_name) in feature_names.iter().enumerate() {
177                    if i >= n_features {
178                        continue;
179                    }
180
181                    let mut perturbed_errors = Vec::new();
182
183                    for _ in 0..self.n_perturbations {
184                        let mut x_perturbed = x_test.clone();
185                        // Shuffle feature values
186                        self.permute_feature(&mut x_perturbed, i)?;
187
188                        let perturbed_predictions = model(&x_perturbed.view());
189                        let perturbed_mean = perturbed_predictions.mean().unwrap_or(F::zero());
190                        let error = (baseline_mean - perturbed_mean).abs();
191                        perturbed_errors.push(error);
192                    }
193
194                    let importance = perturbed_errors.iter().cloned().sum::<F>()
195                        / F::from(perturbed_errors.len()).unwrap();
196                    importance_scores.insert(feature_name.clone(), importance);
197                }
198            }
199            ExplanationMethod::LIME => {
200                // LIME-based importance (simplified)
201                importance_scores = self.compute_lime_importance(model, x_test, feature_names)?;
202            }
203            ExplanationMethod::SHAP => {
204                // SHAP-based importance (simplified)
205                importance_scores = self.compute_shap_importance(model, x_test, feature_names)?;
206            }
207            ExplanationMethod::GradientBased => {
208                // Gradient-based importance (simplified)
209                importance_scores =
210                    self.compute_gradient_importance(model, x_test, feature_names)?;
211            }
212        }
213
214        Ok(importance_scores)
215    }
216
217    /// Evaluate consistency of local explanations
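    ///
    /// Consistency is the average pairwise Pearson correlation between explanations
    /// generated for slightly noise-perturbed copies of the same sample; values near
    /// one indicate explanations that are stable under small input perturbations.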
218    fn evaluate_local_consistency<M>(
219        &self,
220        model: &M,
221        x_test: &Array2<F>,
222        method: &ExplanationMethod,
223    ) -> Result<F>
224    where
225        M: Fn(&ArrayView2<F>) -> Array1<F>,
226    {
227        let nsamples = x_test.nrows().min(10); // Limit for computational efficiency
228        let mut consistency_scores = Vec::new();
229
230        for i in 0..nsamples {
231            let sample = x_test.row(i);
232            let mut local_explanations = Vec::new();
233
234            // Generate multiple explanations for the same sample with slight perturbations
235            for _ in 0..10 {
236                let mut perturbed_sample = sample.to_owned();
237                self.add_noise_to_sample(&mut perturbed_sample)?;
238
239                let explanation =
240                    self.generate_local_explanation(model, &perturbed_sample.view(), method)?;
241                local_explanations.push(explanation);
242            }
243
244            // Compute consistency as correlation between explanations
245            let consistency = self.compute_explanation_consistency(&local_explanations)?;
246            consistency_scores.push(consistency);
247        }
248
249        let average_consistency = consistency_scores.iter().cloned().sum::<F>()
250            / F::from(consistency_scores.len()).unwrap();
251
252        Ok(average_consistency)
253    }
254
255    /// Evaluate stability of global explanations
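    ///
    /// Stability is the average pairwise Pearson correlation between global importance
    /// vectors computed on bootstrap resamples of the test data.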
256    fn evaluate_global_stability<M>(
257        &self,
258        model: &M,
259        x_test: &Array2<F>,
260        method: &ExplanationMethod,
261    ) -> Result<F>
262    where
263        M: Fn(&ArrayView2<F>) -> Array1<F>,
264    {
265        let mut global_explanations = Vec::new();
266
267        // Generate multiple global explanations with bootstrapped samples
268        for _ in 0..self.n_perturbations {
269            let bootstrap_indices = self.bootstrap_sample_indices(x_test.nrows())?;
270            let bootstrap_sample = self.bootstrap_data(x_test, &bootstrap_indices)?;
271
272            let global_explanation =
273                self.generate_global_explanation(model, &bootstrap_sample.view(), method)?;
274            global_explanations.push(global_explanation);
275        }
276
277        // Compute stability as consistency across bootstrap samples
278        let stability = self.compute_explanation_consistency(&global_explanations)?;
279        Ok(stability)
280    }
281
282    /// Compute uncertainty metrics
283    fn compute_uncertainty_metrics<M>(
284        &self,
285        model: &M,
286        x_test: &Array2<F>,
287    ) -> Result<UncertaintyMetrics<F>>
288    where
289        M: Fn(&ArrayView2<F>) -> Array1<F>,
290    {
291        // Monte Carlo dropout for uncertainty estimation
292        let mut predictions_ensemble = Vec::new();
293
294        for _ in 0..50 {
295            // In practice, this would involve dropout during inference
296            let predictions = model(&x_test.view());
297            predictions_ensemble.push(predictions);
298        }
299
300        // Compute epistemic uncertainty (variance across ensemble)
301        let epistemic_uncertainty = self.compute_epistemic_uncertainty(&predictions_ensemble)?;
302
303        // Compute aleatoric uncertainty (data-dependent uncertainty)
304        let aleatoric_uncertainty = self.compute_aleatoric_uncertainty(&predictions_ensemble)?;
305
306        // Total uncertainty
307        let total_uncertainty = epistemic_uncertainty + aleatoric_uncertainty;
308
309        // Coverage and calibration (simplified)
310        let coverage = F::from(0.9).unwrap(); // Would be computed based on actual confidence intervals
311        let calibration_error = F::from(0.05).unwrap(); // Would be computed using reliability diagrams
312
313        Ok(UncertaintyMetrics {
314            epistemic_uncertainty,
315            aleatoric_uncertainty,
316            total_uncertainty,
317            coverage,
318            calibration_error,
319        })
320    }
321
322    /// Evaluate faithfulness of explanations
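    ///
    /// For each evaluated sample `x`, faithfulness is `|f(x) - f(x_masked)|`, where
    /// `x_masked` zeroes out the top-k (k = 5 here) most important features according
    /// to the explanation; the scores are averaged over the evaluated samples.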
323    fn evaluate_faithfulness<M>(
324        &self,
325        model: &M,
326        x_test: &Array2<F>,
327        method: &ExplanationMethod,
328    ) -> Result<F>
329    where
330        M: Fn(&ArrayView2<F>) -> Array1<F>,
331    {
332        let nsamples = x_test.nrows().min(20);
333        let mut faithfulness_scores = Vec::new();
334
335        for i in 0..nsamples {
336            let sample = x_test.row(i);
337            let original_prediction = model(&sample.insert_axis(Axis(0)).view());
338
339            // Generate explanation
340            let explanation = self.generate_local_explanation(model, &sample, method)?;
341
342            // Remove top-k most important features and measure prediction change
343            let masked_sample = self.mask_important_features(&sample, &explanation, 5)?;
344            let masked_prediction = model(&masked_sample.insert_axis(Axis(0)).view());
345
346            // Faithfulness is the change in prediction when important features are removed
347            let faithfulness = (original_prediction[0] - masked_prediction[0]).abs();
348            faithfulness_scores.push(faithfulness);
349        }
350
351        let average_faithfulness = faithfulness_scores.iter().cloned().sum::<F>()
352            / F::from(faithfulness_scores.len()).unwrap();
353
354        Ok(average_faithfulness)
355    }
356
357    /// Evaluate completeness of explanations
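    ///
    /// For each evaluated sample `x`, completeness is `1 - |f(x) - f(x_top_k)|`, where
    /// `x_top_k` keeps only the top-k (k = 5 here) most important features and zeroes
    /// the rest; values near one mean the selected features largely reproduce the
    /// original prediction.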
358    fn evaluate_completeness<M>(
359        &self,
360        model: &M,
361        x_test: &Array2<F>,
362        method: &ExplanationMethod,
363    ) -> Result<F>
364    where
365        M: Fn(&ArrayView2<F>) -> Array1<F>,
366    {
367        let nsamples = x_test.nrows().min(20);
368        let mut completeness_scores = Vec::new();
369
370        for i in 0..nsamples {
371            let sample = x_test.row(i);
372            let original_prediction = model(&sample.insert_axis(Axis(0)).view());
373
374            // Generate explanation
375            let explanation = self.generate_local_explanation(model, &sample, method)?;
376
377            // Keep only top-k most important features and measure prediction preservation
378            let important_only_sample =
379                self.keep_important_features_only(&sample, &explanation, 5)?;
380            let important_only_prediction =
381                model(&important_only_sample.insert_axis(Axis(0)).view());
382
383            // Completeness is how well the explanation preserves the original prediction
384            let preservation =
385                F::one() - (original_prediction[0] - important_only_prediction[0]).abs();
386            completeness_scores.push(preservation);
387        }
388
389        let average_completeness = completeness_scores.iter().cloned().sum::<F>()
390            / F::from(completeness_scores.len()).unwrap();
391
392        Ok(average_completeness)
393    }
394
395    // Helper methods
396
397    fn permute_feature(&self, data: &mut Array2<F>, feature_index: usize) -> Result<()> {
398        if feature_index >= data.ncols() {
399            return Err(MetricsError::InvalidInput(
400                "Feature index out of bounds".to_string(),
401            ));
402        }
403
404        let mut feature_values: Vec<F> = data.column(feature_index).to_vec();
405
406        // Deterministic pseudo-random shuffle (in practice, would use a proper RNG)
407        for i in (1..feature_values.len()).rev() {
408            let j = i.wrapping_mul(0x9E37_79B9).wrapping_add(1) % (i + 1);
409            feature_values.swap(i, j);
410        }
411
412        for (i, &value) in feature_values.iter().enumerate() {
413            data[[i, feature_index]] = value;
414        }
415
416        Ok(())
417    }
418
419    fn add_noise_to_sample(&self, sample: &mut Array1<F>) -> Result<()> {
420        for value in sample.iter_mut() {
421            // Add small amount of noise
422            let noise = self.perturbation_strength * F::from(0.01).unwrap(); // Simplified noise
423            *value = *value + noise;
424        }
425        Ok(())
426    }
427
428    fn generate_local_explanation<M>(
429        &self,
430        model: &M,
431        sample: &ArrayView1<F>,
432        _method: &ExplanationMethod,
433    ) -> Result<Array1<F>>
434    where
435        M: Fn(&ArrayView2<F>) -> Array1<F>,
436    {
437        // Simplified local explanation (gradients or sensitivity analysis)
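        // Each feature's importance is its absolute one-sided sensitivity:
        //     s_i = | f(x + delta * e_i) - f(x) |
        // where delta is `perturbation_strength` and e_i is the i-th unit vector.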
438        let n_features = sample.len();
439        let mut importance = Array1::zeros(n_features);
440
441        let baseline_pred = model(&sample.insert_axis(Axis(0)).view())[0];
442
443        for i in 0..n_features {
444            let mut perturbed = sample.to_owned();
445            perturbed[i] = perturbed[i] + self.perturbation_strength;
446
447            let perturbed_pred = model(&perturbed.insert_axis(Axis(0)).view())[0];
448            importance[i] = (perturbed_pred - baseline_pred).abs();
449        }
450
451        Ok(importance)
452    }
453
454    fn generate_global_explanation<M>(
455        &self,
456        model: &M,
457        data: &ArrayView2<F>,
458        method: &ExplanationMethod,
459    ) -> Result<Array1<F>>
460    where
461        M: Fn(&ArrayView2<F>) -> Array1<F>,
462    {
463        let n_features = data.ncols();
464        let mut global_importance = Array1::zeros(n_features);
465
466        // Average local explanations for global explanation
467        for i in 0..data.nrows() {
468            let sample = data.row(i);
469            let local_explanation = self.generate_local_explanation(model, &sample, method)?;
470            global_importance = global_importance + local_explanation;
471        }
472
473        global_importance = global_importance / F::from(data.nrows()).unwrap();
474        Ok(global_importance)
475    }
476
477    fn compute_explanation_consistency(&self, explanations: &[Array1<F>]) -> Result<F> {
478        if explanations.len() < 2 {
479            return Ok(F::one());
480        }
481
482        let mut correlations = Vec::new();
483
484        for i in 0..explanations.len() {
485            for j in (i + 1)..explanations.len() {
486                let correlation = self.compute_correlation(&explanations[i], &explanations[j])?;
487                correlations.push(correlation);
488            }
489        }
490
491        let average_correlation =
492            correlations.iter().cloned().sum::<F>() / F::from(correlations.len()).unwrap();
493
494        Ok(average_correlation)
495    }
496
497    fn compute_correlation(&self, x: &Array1<F>, y: &Array1<F>) -> Result<F> {
498        if x.len() != y.len() {
499            return Err(MetricsError::InvalidInput(
500                "Arrays must have the same length".to_string(),
501            ));
502        }
503
504        let mean_x = x.mean().unwrap_or(F::zero());
505        let mean_y = y.mean().unwrap_or(F::zero());
506
507        let numerator: F = x
508            .iter()
509            .zip(y.iter())
510            .map(|(&xi, &yi)| (xi - mean_x) * (yi - mean_y))
511            .sum();
512
513        let sum_sq_x: F = x.iter().map(|&xi| (xi - mean_x) * (xi - mean_x)).sum();
514        let sum_sq_y: F = y.iter().map(|&yi| (yi - mean_y) * (yi - mean_y)).sum();
515
516        let denominator = (sum_sq_x * sum_sq_y).sqrt();
517
518        if denominator == F::zero() {
519            Ok(F::zero())
520        } else {
521            Ok(numerator / denominator)
522        }
523    }
524
525    fn bootstrap_sample_indices(&self, nsamples: usize) -> Result<Vec<usize>> {
526        // Deterministic sampling with replacement (in practice, would use a proper RNG)
527        let mut indices = Vec::with_capacity(nsamples);
528        for i in 0..nsamples {
529            indices.push((i.wrapping_mul(2_654_435_761) >> 7) % nsamples);
530        }
531        Ok(indices)
532    }
533
534    fn bootstrap_data(&self, data: &Array2<F>, indices: &[usize]) -> Result<Array2<F>> {
535        let mut bootstrap_data = Array2::zeros((indices.len(), data.ncols()));
536
537        for (i, &idx) in indices.iter().enumerate() {
538            for j in 0..data.ncols() {
539                bootstrap_data[[i, j]] = data[[idx, j]];
540            }
541        }
542
543        Ok(bootstrap_data)
544    }
545
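    /// Epistemic uncertainty is estimated as the square root of the mean per-sample
    /// variance of predictions across the Monte Carlo ensemble, i.e.
    /// `sqrt(mean_i(Var_k[f_k(x_i)]))`.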
546    fn compute_epistemic_uncertainty(&self, predictions: &[Array1<F>]) -> Result<F> {
547        if predictions.is_empty() {
548            return Ok(F::zero());
549        }
550
551        let n_predictions = predictions.len();
552        let nsamples = predictions[0].len();
553
554        let mut variances = Vec::new();
555
556        for i in 0..nsamples {
557            let sample_predictions: Vec<F> = predictions.iter().map(|pred| pred[i]).collect();
558
559            let mean =
560                sample_predictions.iter().cloned().sum::<F>() / F::from(n_predictions).unwrap();
561            let variance = sample_predictions
562                .iter()
563                .map(|&pred| (pred - mean) * (pred - mean))
564                .sum::<F>()
565                / F::from(n_predictions - 1).unwrap();
566
567            variances.push(variance);
568        }
569
570        let average_variance =
571            variances.iter().cloned().sum::<F>() / F::from(variances.len()).unwrap();
572        Ok(average_variance.sqrt())
573    }
574
575    fn compute_aleatoric_uncertainty(&self, _predictions: &[Array1<F>]) -> Result<F> {
576        // Simplified aleatoric uncertainty computation
577        // In practice, this would require model-specific uncertainty estimates
578        Ok(F::from(0.1).unwrap())
579    }
580
581    fn mask_important_features(
582        &self,
583        sample: &ArrayView1<F>,
584        explanation: &Array1<F>,
585        k: usize,
586    ) -> Result<Array1<F>> {
587        let mut masked = sample.to_owned();
588
589        // Find top-k most important features
590        let mut importance_indices: Vec<(usize, F)> = explanation
591            .iter()
592            .enumerate()
593            .map(|(i, &imp)| (i, imp))
594            .collect();
595        importance_indices
596            .sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
597
598        // Mask top-k features (set to zero or mean)
599        for i in 0..k.min(importance_indices.len()) {
600            let feature_idx = importance_indices[i].0;
601            masked[feature_idx] = F::zero(); // Or use feature mean
602        }
603
604        Ok(masked)
605    }
606
607    fn keep_important_features_only(
608        &self,
609        sample: &ArrayView1<F>,
610        explanation: &Array1<F>,
611        k: usize,
612    ) -> Result<Array1<F>> {
613        let mut filtered = Array1::zeros(sample.len());
614
615        // Find top-k most important features
616        let mut importance_indices: Vec<(usize, F)> = explanation
617            .iter()
618            .enumerate()
619            .map(|(i, &imp)| (i, imp))
620            .collect();
621        importance_indices
622            .sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
623
624        // Keep only top-k features
625        for i in 0..k.min(importance_indices.len()) {
626            let feature_idx = importance_indices[i].0;
627            filtered[feature_idx] = sample[feature_idx];
628        }
629
630        Ok(filtered)
631    }
632
633    // LIME (Local Interpretable Model-agnostic Explanations) importance via a local linear surrogate
634    fn compute_lime_importance<M>(
635        &self,
636        model: &M,
637        x_test: &Array2<F>,
638        feature_names: &[String],
639    ) -> Result<HashMap<String, F>>
640    where
641        M: Fn(&ArrayView2<F>) -> Array1<F>,
642    {
643        if x_test.is_empty() || feature_names.is_empty() {
644            return Err(MetricsError::InvalidInput(
645                "Empty input data or feature names".to_string(),
646            ));
647        }
648
649        if x_test.ncols() != feature_names.len() {
650            return Err(MetricsError::InvalidInput(
651                "Number of features doesn't match feature_names length".to_string(),
652            ));
653        }
654
655        let mut importance_scores = HashMap::new();
656        let nsamples = std::cmp::min(1000, self.n_perturbations); // Limit for efficiency
657
658        // Process each instance separately for local explanations
659        for instance in x_test.axis_iter(Axis(0)) {
660            let instance_importance =
661                self.compute_lime_for_instance(model, &instance, feature_names, nsamples)?;
662
663            // Aggregate importance scores across instances
664            for (feature_name, importance) in instance_importance {
665                let current_score = importance_scores
666                    .get(&feature_name)
667                    .copied()
668                    .unwrap_or(F::zero());
669                importance_scores.insert(
670                    feature_name,
671                    current_score + importance / F::from(x_test.nrows()).unwrap(),
672                );
673            }
674        }
675
676        Ok(importance_scores)
677    }
678
679    /// Compute LIME importance for a single instance
680    fn compute_lime_for_instance<M>(
681        &self,
682        model: &M,
683        instance: &ArrayView1<F>,
684        feature_names: &[String],
685        nsamples: usize,
686    ) -> Result<HashMap<String, F>>
687    where
688        M: Fn(&ArrayView2<F>) -> Array1<F>,
689    {
690        let _n_features = instance.len();
691
692        // Generate perturbed samples around the instance
693        let (perturbed_samples, weights) = self.generate_lime_samples(instance, nsamples)?;
694
695        // Get model predictions for perturbed samples
696        let predictions = model(&perturbed_samples.view());
697
698        // Train interpretable model (linear regression) on perturbed data
699        let coefficients =
700            self.fit_interpretable_model(&perturbed_samples, &predictions, &weights)?;
701
702        // Create importance map
703        let mut importance = HashMap::new();
704        for (i, name) in feature_names.iter().enumerate() {
705            if i < coefficients.len() {
706                importance.insert(name.clone(), coefficients[i].abs());
707            }
708        }
709
710        Ok(importance)
711    }
712
713    /// Generate perturbed samples for LIME with distance-based weights
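    ///
    /// Each perturbed sample receives a locality weight from an exponential kernel,
    /// `w_i = exp(-2 * d_i)`, where `d_i` is the mean absolute perturbation applied to
    /// sample `i`, so samples closer to the original instance dominate the surrogate fit.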
714    fn generate_lime_samples(
715        &self,
716        instance: &ArrayView1<F>,
717        nsamples: usize,
718    ) -> Result<(Array2<F>, Array1<F>)> {
719        let n_features = instance.len();
720        let mut perturbed_samples = Array2::zeros((nsamples, n_features));
721        let mut weights = Array1::zeros(nsamples);
722
723        // Calculate feature statistics for perturbation
724        let feature_mean = instance.mean().unwrap_or(F::zero());
725        let feature_std = {
726            let variance = instance
727                .iter()
728                .map(|&x| (x - feature_mean) * (x - feature_mean))
729                .sum::<F>()
730                / F::from(n_features).unwrap();
731            variance.sqrt()
732        };
733
734        for i in 0..nsamples {
735            let mut perturbed_instance = instance.to_owned();
736            let mut distance_sum = F::zero();
737
738            // Perturb features deterministically (a stand-in for random perturbation)
739            for j in 0..n_features {
740                // Use simple uniform perturbation around the original value
741                let perturbation_factor = F::from((i + j) as f64 / (nsamples * n_features) as f64)
742                    .unwrap()
743                    - F::from(0.5).unwrap();
744                let perturbation = perturbation_factor * self.perturbation_strength * feature_std;
745
746                perturbed_instance[j] = instance[j] + perturbation;
747                distance_sum = distance_sum + perturbation.abs();
748            }
749
750            // Store perturbed sample
751            for j in 0..n_features {
752                perturbed_samples[[i, j]] = perturbed_instance[j];
753            }
754
755            // Calculate weight based on distance (closer samples get higher weight)
756            let distance = distance_sum / F::from(n_features).unwrap();
757            weights[i] = (-distance * F::from(2.0).unwrap()).exp(); // Gaussian-like kernel
758        }
759
760        Ok((perturbed_samples, weights))
761    }
762
763    /// Fit interpretable linear model using weighted least squares
764    fn fit_interpretable_model(
765        &self,
766        samples: &Array2<F>,
767        targets: &Array1<F>,
768        weights: &Array1<F>,
769    ) -> Result<Vec<F>> {
770        let nsamples = samples.nrows();
771        let n_features = samples.ncols();
772
773        if nsamples == 0 || n_features == 0 {
774            return Ok(vec![F::zero(); n_features]);
775        }
776
777        // Weighted least squares: (X'WX)^(-1)X'Wy
778        // For simplicity, we'll use a regularized version to avoid singularity
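        // With ridge regularization the normal equations solved below are
        //     (X^T W X + lambda * I) beta = X^T W y
        // where W = diag(weights) and lambda = 1e-6; the fitted coefficients beta act as
        // the local surrogate weights whose magnitudes are reported as importances.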
779        let mut xtx = Array2::zeros((n_features, n_features));
780        let mut xty = Array1::zeros(n_features);
781
782        // Compute X'WX and X'Wy
783        for i in 0..nsamples {
784            let weight = weights[i];
785            let target = targets[i];
786
787            for j in 0..n_features {
788                let x_ij = samples[[i, j]];
789
790                // X'Wy
791                xty[j] = xty[j] + weight * x_ij * target;
792
793                // X'WX
794                for k in 0..n_features {
795                    let x_ik = samples[[i, k]];
796                    xtx[[j, k]] = xtx[[j, k]] + weight * x_ij * x_ik;
797                }
798            }
799        }
800
801        // Add regularization to diagonal (Ridge regression)
802        let regularization = F::from(1e-6).unwrap();
803        for i in 0..n_features {
804            xtx[[i, i]] = xtx[[i, i]] + regularization;
805        }
806
807        // Solve linear system using simple Gaussian elimination
808        let coefficients = self.solve_linear_system(&xtx, &xty)?;
809
810        Ok(coefficients)
811    }
812
813    /// Simple linear system solver for weighted least squares
814    fn solve_linear_system(&self, a: &Array2<F>, b: &Array1<F>) -> Result<Vec<F>> {
815        let n = a.nrows();
816        if n != a.ncols() || n != b.len() {
817            return Err(MetricsError::InvalidInput(
818                "Matrix dimensions mismatch".to_string(),
819            ));
820        }
821
822        // Create augmented matrix for Gaussian elimination
823        let mut aug = Array2::zeros((n, n + 1));
824        for i in 0..n {
825            for j in 0..n {
826                aug[[i, j]] = a[[i, j]];
827            }
828            aug[[i, n]] = b[i];
829        }
830
831        // Forward elimination
832        for i in 0..n {
833            // Find pivot
834            let mut max_row = i;
835            for k in (i + 1)..n {
836                if aug[[k, i]].abs() > aug[[max_row, i]].abs() {
837                    max_row = k;
838                }
839            }
840
841            // Swap rows if needed
842            if max_row != i {
843                for j in 0..=n {
844                    let temp = aug[[i, j]];
845                    aug[[i, j]] = aug[[max_row, j]];
846                    aug[[max_row, j]] = temp;
847                }
848            }
849
850            // Check for singular matrix
851            if aug[[i, i]].abs() < F::from(1e-10).unwrap() {
852                // Use pseudoinverse approach for singular case
853                return Ok(vec![F::zero(); n]);
854            }
855
856            // Eliminate column
857            for k in (i + 1)..n {
858                let factor = aug[[k, i]] / aug[[i, i]];
859                for j in i..=n {
860                    aug[[k, j]] = aug[[k, j]] - factor * aug[[i, j]];
861                }
862            }
863        }
864
865        // Back substitution
866        let mut x = vec![F::zero(); n];
867        for i in (0..n).rev() {
868            x[i] = aug[[i, n]];
869            for j in (i + 1)..n {
870                x[i] = x[i] - aug[[i, j]] * x[j];
871            }
872            x[i] = x[i] / aug[[i, i]];
873        }
874
875        Ok(x)
876    }
877
878    /// SHAP (SHapley Additive exPlanations) importance via sampled Shapley value approximation
879    fn compute_shap_importance<M>(
880        &self,
881        model: &M,
882        x_test: &Array2<F>,
883        feature_names: &[String],
884    ) -> Result<HashMap<String, F>>
885    where
886        M: Fn(&ArrayView2<F>) -> Array1<F>,
887    {
888        if x_test.is_empty() || feature_names.is_empty() {
889            return Err(MetricsError::InvalidInput(
890                "Empty input data or feature names".to_string(),
891            ));
892        }
893
894        if x_test.ncols() != feature_names.len() {
895            return Err(MetricsError::InvalidInput(
896                "Number of features doesn't match feature_names length".to_string(),
897            ));
898        }
899
900        let mut importance_scores = HashMap::new();
901
902        // Compute background mean for baseline prediction
903        let background_mean = self.compute_background_mean(x_test)?;
904
905        // Process each instance separately for local explanations
906        for instance in x_test.axis_iter(Axis(0)) {
907            let instance_importance =
908                self.compute_shap_for_instance(model, &instance, &background_mean, feature_names)?;
909
910            // Aggregate importance scores across instances
911            for (feature_name, importance) in instance_importance {
912                let current_score = importance_scores
913                    .get(&feature_name)
914                    .copied()
915                    .unwrap_or(F::zero());
916                importance_scores.insert(
917                    feature_name,
918                    current_score + importance / F::from(x_test.nrows()).unwrap(),
919                );
920            }
921        }
922
923        Ok(importance_scores)
924    }
925
926    /// Compute SHAP values for a single instance
927    fn compute_shap_for_instance<M>(
928        &self,
929        model: &M,
930        instance: &ArrayView1<F>,
931        background_mean: &Array1<F>,
932        feature_names: &[String],
933    ) -> Result<HashMap<String, F>>
934    where
935        M: Fn(&ArrayView2<F>) -> Array1<F>,
936    {
937        let n_features = instance.len();
938
939        // Use efficient approximation for SHAP values
940        // This implements a sampling-based approximation of Shapley values
941        let max_coalitions = std::cmp::min(
942            2_usize.pow(std::cmp::min(n_features, 10) as u32),
943            self.n_perturbations,
944        );
945
946        let shapley_values = self.compute_shapley_values_approximation(
947            model,
948            instance,
949            background_mean,
950            max_coalitions,
951        )?;
952
953        // Create importance map
954        let mut importance = HashMap::new();
955        for (i, name) in feature_names.iter().enumerate() {
956            if i < shapley_values.len() {
957                importance.insert(name.clone(), shapley_values[i].abs());
958            }
959        }
960
961        Ok(importance)
962    }
963
964    /// Compute background mean for SHAP baseline
965    fn compute_background_mean(&self, xdata: &Array2<F>) -> Result<Array1<F>> {
966        if xdata.is_empty() {
967            return Err(MetricsError::InvalidInput(
968                "Empty data for background computation".to_string(),
969            ));
970        }
971
972        let n_features = xdata.ncols();
973        let mut background = Array1::zeros(n_features);
974
975        for j in 0..n_features {
976            let column_sum: F = xdata.column(j).iter().cloned().sum();
977            background[j] = column_sum / F::from(xdata.nrows()).unwrap();
978        }
979
980        Ok(background)
981    }
982
983    /// Efficient approximation of Shapley values using sampling
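    ///
    /// Each Shapley value is approximated by averaging marginal contributions over
    /// sampled coalitions `S` of the remaining features,
    /// `phi_i ~ mean_S[f(S with i) - f(S without i)]`, with excluded features set to
    /// their background-mean values. The values are then rescaled so that
    /// `sum_i(phi_i) = f(x) - f(background)`, matching the Shapley efficiency property.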
984    fn compute_shapley_values_approximation<M>(
985        &self,
986        model: &M,
987        instance: &ArrayView1<F>,
988        background: &Array1<F>,
989        max_coalitions: usize,
990    ) -> Result<Vec<F>>
991    where
992        M: Fn(&ArrayView2<F>) -> Array1<F>,
993    {
994        let n_features = instance.len();
995        let mut shapley_values = vec![F::zero(); n_features];
996
997        // Get baseline prediction (no features)
998        let baseline_input =
999            Array2::from_shape_vec((1, n_features), background.to_vec()).map_err(|_| {
1000                MetricsError::InvalidInput("Failed to create baseline array".to_string())
1001            })?;
1002        let baseline_pred = model(&baseline_input.view())[0];
1003
1004        // Get full prediction (all features)
1005        let full_input = Array2::from_shape_vec((1, n_features), instance.to_vec())
1006            .map_err(|_| MetricsError::InvalidInput("Failed to create full array".to_string()))?;
1007        let full_pred = model(&full_input.view())[0];
1008
1009        // For efficiency, use sampling-based approximation
1010        let nsamples = std::cmp::min(max_coalitions, 1000);
1011
1012        for i in 0..n_features {
1013            let mut marginal_contributions = Vec::new();
1014
1015            // Sample different coalitions and compute marginal contribution of feature i
1016            for sample_idx in 0..nsamples {
1017                let coalition = self.generate_random_coalition(n_features, i, sample_idx);
1018
1019                // Compute prediction with coalition including feature i
1020                let with_i =
1021                    self.create_coalition_input(instance, background, &coalition, Some(i))?;
1022                let pred_with_i = model(&with_i.view())[0];
1023
1024                // Compute prediction with coalition excluding feature i
1025                let without_i =
1026                    self.create_coalition_input(instance, background, &coalition, None)?;
1027                let pred_without_i = model(&without_i.view())[0];
1028
1029                // Marginal contribution
1030                let marginal_contrib = pred_with_i - pred_without_i;
1031                marginal_contributions.push(marginal_contrib);
1032            }
1033
1034            // Average marginal contributions to get Shapley value
1035            if !marginal_contributions.is_empty() {
1036                let sum: F = marginal_contributions.iter().cloned().sum();
1037                shapley_values[i] = sum / F::from(marginal_contributions.len()).unwrap();
1038            }
1039        }
1040
1041        // Ensure Shapley values sum to difference between full and baseline predictions
1042        // (efficiency property of Shapley values)
1043        let total_difference = full_pred - baseline_pred;
1044        let shapley_sum: F = shapley_values.iter().cloned().sum();
1045
1046        if shapley_sum.abs() > F::epsilon() {
1047            let normalization_factor = total_difference / shapley_sum;
1048            for val in shapley_values.iter_mut() {
1049                *val = *val * normalization_factor;
1050            }
1051        }
1052
1053        Ok(shapley_values)
1054    }
1055
1056    /// Generate a random coalition (subset of features) for sampling
1057    fn generate_random_coalition(
1058        &self,
1059        n_features: usize,
1060        target_feature: usize,
1061        seed: usize,
1062    ) -> Vec<bool> {
1063        let mut coalition = vec![false; n_features];
1064
1065        // Use simple deterministic "random" based on seed for reproducibility
1066        let mut pseudo_random = seed;
1067
1068        for i in 0..n_features {
1069            if i != target_feature {
1070                pseudo_random = pseudo_random.wrapping_mul(1103515245).wrapping_add(12345);
1071                coalition[i] = pseudo_random % 2 == 0;
1072            }
1073        }
1074
1075        coalition
1076    }
1077
1078    /// Create input array for a specific coalition
1079    fn create_coalition_input(
1080        &self,
1081        instance: &ArrayView1<F>,
1082        background: &Array1<F>,
1083        coalition: &[bool],
1084        include_target: Option<usize>,
1085    ) -> Result<Array2<F>> {
1086        let n_features = instance.len();
1087        let mut coalition_input = background.clone();
1088
1089        // Include features in coalition
1090        for (i, &in_coalition) in coalition.iter().enumerate() {
1091            if in_coalition {
1092                coalition_input[i] = instance[i];
1093            }
1094        }
1095
1096        // Include or exclude the target feature
1097        if let Some(target_idx) = include_target {
1098            if target_idx < n_features {
1099                coalition_input[target_idx] = instance[target_idx];
1100            }
1101        }
1102
1103        // Convert to 2D array for model input
1104        Array2::from_shape_vec((1, n_features), coalition_input.to_vec()).map_err(|_| {
1105            MetricsError::InvalidInput("Failed to create coalition input array".to_string())
1106        })
1107    }
1108
1109    /// Gradient-based importance computation using numerical differentiation
1110    fn compute_gradient_importance<M>(
1111        &self,
1112        model: &M,
1113        x_test: &Array2<F>,
1114        feature_names: &[String],
1115    ) -> Result<HashMap<String, F>>
1116    where
1117        M: Fn(&ArrayView2<F>) -> Array1<F>,
1118    {
1119        if x_test.is_empty() || feature_names.is_empty() {
1120            return Err(MetricsError::InvalidInput(
1121                "Empty input data or feature names".to_string(),
1122            ));
1123        }
1124
1125        if x_test.ncols() != feature_names.len() {
1126            return Err(MetricsError::InvalidInput(
1127                "Number of features doesn't match feature_names length".to_string(),
1128            ));
1129        }
1130
1131        let mut importance_scores = HashMap::new();
1132
1133        // Process each instance separately for local explanations
1134        for instance in x_test.axis_iter(Axis(0)) {
1135            let instance_importance =
1136                self.compute_gradient_for_instance(model, &instance, feature_names)?;
1137
1138            // Aggregate importance scores across instances
1139            for (feature_name, importance) in instance_importance {
1140                let current_score = importance_scores
1141                    .get(&feature_name)
1142                    .copied()
1143                    .unwrap_or(F::zero());
1144                importance_scores.insert(
1145                    feature_name,
1146                    current_score + importance / F::from(x_test.nrows()).unwrap(),
1147                );
1148            }
1149        }
1150
1151        Ok(importance_scores)
1152    }
1153
1154    /// Compute gradient-based importance for a single instance
1155    fn compute_gradient_for_instance<M>(
1156        &self,
1157        model: &M,
1158        instance: &ArrayView1<F>,
1159        feature_names: &[String],
1160    ) -> Result<HashMap<String, F>>
1161    where
1162        M: Fn(&ArrayView2<F>) -> Array1<F>,
1163    {
1164        let n_features = instance.len();
1165
1166        // Compute numerical gradients using finite differences
1167        let gradients = self.compute_numerical_gradients(model, instance)?;
1168
1169        // Multiple gradient-based attribution methods
1170        let saliency_map = self.compute_saliency_map(&gradients, instance)?;
1171        let integrated_gradients = self.compute_integrated_gradients(model, instance)?;
1172        let gradient_times_input = self.compute_gradient_times_input(&gradients, instance)?;
1173
1174        // Combine different gradient methods with equal weighting
1175        let mut importance = HashMap::new();
1176        for (i, name) in feature_names.iter().enumerate() {
1177            if i < n_features {
1178                let combined_importance =
1179                    (saliency_map[i] + integrated_gradients[i] + gradient_times_input[i])
1180                        / F::from(3.0).unwrap();
1181                importance.insert(name.clone(), combined_importance.abs());
1182            }
1183        }
1184
1185        Ok(importance)
1186    }
1187
1188    /// Compute numerical gradients using finite differences
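    ///
    /// Partial derivatives are approximated with central differences,
    /// `df/dx_i ~ (f(x + eps*e_i) - f(x - eps*e_i)) / (2*eps)`, where the step `eps`
    /// is scaled by each feature's magnitude (`eps = 1e-5 * max(|x_i|, 1)`).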
1189    fn compute_numerical_gradients<M>(&self, model: &M, instance: &ArrayView1<F>) -> Result<Vec<F>>
1190    where
1191        M: Fn(&ArrayView2<F>) -> Array1<F>,
1192    {
1193        let n_features = instance.len();
1194        let mut gradients = vec![F::zero(); n_features];
1195
1196        // Use adaptive step size based on feature magnitude
1197        let epsilon_base = F::from(1e-5).unwrap();
1198
1199        // Get baseline prediction
1200        let baseline_input =
1201            Array2::from_shape_vec((1, n_features), instance.to_vec()).map_err(|_| {
1202                MetricsError::InvalidInput("Failed to create baseline array".to_string())
1203            })?;
1204        let _baseline_pred = model(&baseline_input.view())[0];
1205
1206        // Compute partial derivatives using central differences
1207        for i in 0..n_features {
1208            let feature_magnitude = instance[i].abs().max(F::from(1.0).unwrap());
1209            let epsilon = epsilon_base * feature_magnitude;
1210
1211            // Forward step
1212            let mut forward_instance = instance.to_owned();
1213            forward_instance[i] = forward_instance[i] + epsilon;
1214            let forward_input = Array2::from_shape_vec((1, n_features), forward_instance.to_vec())
1215                .map_err(|_| {
1216                    MetricsError::InvalidInput("Failed to create forward array".to_string())
1217                })?;
1218            let forward_pred = model(&forward_input.view())[0];
1219
1220            // Backward step
1221            let mut backward_instance = instance.to_owned();
1222            backward_instance[i] = backward_instance[i] - epsilon;
1223            let backward_input =
1224                Array2::from_shape_vec((1, n_features), backward_instance.to_vec()).map_err(
1225                    |_| MetricsError::InvalidInput("Failed to create backward array".to_string()),
1226                )?;
1227            let backward_pred = model(&backward_input.view())[0];
1228
1229            // Central difference approximation
1230            gradients[i] = (forward_pred - backward_pred) / (F::from(2.0).unwrap() * epsilon);
1231        }
1232
1233        Ok(gradients)
1234    }
1235
1236    /// Compute saliency map (simple gradient magnitude)
1237    fn compute_saliency_map(&self, gradients: &[F], _instance: &ArrayView1<F>) -> Result<Vec<F>> {
1238        // Saliency map is simply the absolute gradient values
1239        Ok(gradients.iter().map(|&g| g.abs()).collect())
1240    }
1241
1242    /// Compute integrated gradients approximation
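    ///
    /// Integrated gradients are approximated by a Riemann sum along the straight path
    /// from a zero baseline `x'` to the instance `x`:
    /// `IG_i ~ (x_i - x'_i) * mean_k(df/dx_i(x' + (k/m)*(x - x')))` with `m = 50` steps
    /// and numerical gradients evaluated at each interpolated point.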
1243    fn compute_integrated_gradients<M>(&self, model: &M, instance: &ArrayView1<F>) -> Result<Vec<F>>
1244    where
1245        M: Fn(&ArrayView2<F>) -> Array1<F>,
1246    {
1247        let n_features = instance.len();
1248        let mut integrated_grads = vec![F::zero(); n_features];
1249
1250        // Use zero baseline for integrated gradients
1251        let baseline = Array1::zeros(n_features);
1252        let n_steps = 50; // Number of integration steps
1253
1254        // Approximate integral using Riemann sum
1255        for step in 0..n_steps {
1256            let alpha = F::from(step as f64).unwrap() / F::from(n_steps as f64).unwrap();
1257
1258            // Interpolate between baseline and instance
1259            let mut interpolated = Array1::zeros(n_features);
1260            for i in 0..n_features {
1261                interpolated[i] = baseline[i] + alpha * (instance[i] - baseline[i]);
1262            }
1263
1264            // Compute gradients at interpolated point
1265            let step_gradients = self.compute_numerical_gradients(model, &interpolated.view())?;
1266
1267            // Accumulate gradients
1268            for i in 0..n_features {
1269                integrated_grads[i] =
1270                    integrated_grads[i] + step_gradients[i] * (instance[i] - baseline[i]);
1271            }
1272        }
1273
1274        // Average over steps
1275        for grad in integrated_grads.iter_mut() {
1276            *grad = *grad / F::from(n_steps).unwrap();
1277        }
1278
1279        Ok(integrated_grads)
1280    }
1281
1282    /// Compute gradient × input attribution
1283    fn compute_gradient_times_input(
1284        &self,
1285        gradients: &[F],
1286        instance: &ArrayView1<F>,
1287    ) -> Result<Vec<F>> {
1288        let mut grad_times_input = Vec::new();
1289
1290        for (i, &grad) in gradients.iter().enumerate() {
1291            if i < instance.len() {
1292                grad_times_input.push(grad * instance[i]);
1293            }
1294        }
1295
1296        Ok(grad_times_input)
1297    }
1298}
1299
1300/// Explanation method types
1301#[derive(Debug, Clone)]
1302pub enum ExplanationMethod {
1303    /// Permutation importance
1304    Permutation,
1305    /// LIME (Local Interpretable Model-agnostic Explanations)
1306    LIME,
1307    /// SHAP (SHapley Additive exPlanations)
1308    SHAP,
1309    /// Gradient-based explanations
1310    GradientBased,
1311}
1312
1313/// Compute model interpretability score
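///
/// The score is a fixed-weight combination of the component metrics: 0.25 × mean
/// feature importance, 0.2 × local consistency, 0.2 × global stability, 0.15 ×
/// faithfulness, 0.15 × completeness, and 0.05 × (1 - total uncertainty), assuming
/// each component lies roughly in [0, 1].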
1314#[allow(dead_code)]
1315pub fn compute_interpretability_score<F: Float + std::iter::Sum>(
1316    explainability_metrics: &ExplainabilityMetrics<F>,
1317) -> F {
1318    // Weighted combination of different explainability aspects
1319    let feature_importance_score = if explainability_metrics.feature_importance.is_empty() {
1320        F::zero()
1321    } else {
1322        explainability_metrics
1323            .feature_importance
1324            .values()
1325            .cloned()
1326            .sum::<F>()
1327            / F::from(explainability_metrics.feature_importance.len()).unwrap()
1328    };
1329
1330    let weights = [
1331        F::from(0.25).unwrap(), // feature importance
1332        F::from(0.2).unwrap(),  // local consistency
1333        F::from(0.2).unwrap(),  // global stability
1334        F::from(0.15).unwrap(), // faithfulness
1335        F::from(0.15).unwrap(), // completeness
1336        F::from(0.05).unwrap(), // uncertainty
1337    ];
1338
1339    let scores = [
1340        feature_importance_score,
1341        explainability_metrics.local_consistency,
1342        explainability_metrics.global_stability,
1343        explainability_metrics.faithfulness,
1344        explainability_metrics.completeness,
1345        F::one() - explainability_metrics.uncertainty_metrics.total_uncertainty, // Lower uncertainty is better
1346    ];
1347
1348    weights
1349        .iter()
1350        .zip(scores.iter())
1351        .map(|(&w, &s)| w * s)
1352        .sum()
1353}
1354
1355#[cfg(test)]
1356mod tests {
1357    use super::*;
1358    use scirs2_core::ndarray::array;
1359
1360    #[test]
1361    fn test_explainability_evaluator_creation() {
1362        let evaluator = ExplainabilityEvaluator::<f64>::new()
1363            .with_perturbations(50)
1364            .with_perturbation_strength(0.05)
1365            .with_importance_threshold(0.02);
1366
1367        assert_eq!(evaluator.n_perturbations, 50);
1368        assert_eq!(evaluator.perturbation_strength, 0.05);
1369        assert_eq!(evaluator.importance_threshold, 0.02);
1370    }
1371
1372    #[test]
1373    fn test_correlation_computation() {
1374        let evaluator = ExplainabilityEvaluator::<f64>::new();
1375
1376        let x = array![1.0, 2.0, 3.0, 4.0, 5.0];
1377        let y = array![2.0, 4.0, 6.0, 8.0, 10.0]; // Perfect correlation
1378
1379        let correlation = evaluator.compute_correlation(&x, &y).unwrap();
1380        assert!((correlation - 1.0).abs() < 1e-10);
1381    }
1382
1383    #[test]
1384    fn test_permutation_feature() {
1385        let evaluator = ExplainabilityEvaluator::<f64>::new();
1386        let mut data = array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]];
1387        let original_data = data.clone();
1388
1389        evaluator.permute_feature(&mut data, 1).unwrap();
1390
1391        // Feature 1 should be different, others should be the same
1392        assert_eq!(data.column(0), original_data.column(0));
1393        assert_eq!(data.column(2), original_data.column(2));
1394        // Column 1 should have the same values but potentially in different order
1395        assert_eq!(data.column(1).len(), original_data.column(1).len());
1396    }
1397
1398    #[test]
1399    fn test_interpretability_score() {
1400        let mut feature_importance = HashMap::new();
1401        feature_importance.insert("feature1".to_string(), 0.5);
1402        feature_importance.insert("feature2".to_string(), 0.3);
1403
1404        let metrics = ExplainabilityMetrics {
1405            feature_importance,
1406            local_consistency: 0.8,
1407            global_stability: 0.7,
1408            uncertainty_metrics: UncertaintyMetrics {
1409                epistemic_uncertainty: 0.1,
1410                aleatoric_uncertainty: 0.05,
1411                total_uncertainty: 0.15,
1412                coverage: 0.95,
1413                calibration_error: 0.02,
1414            },
1415            faithfulness: 0.9,
1416            completeness: 0.85,
1417        };
1418
1419        let score = compute_interpretability_score(&metrics);
1420        assert!(score > 0.0 && score <= 1.0);
1421    }
1422
1423    #[test]
1424    fn test_bootstrap_sampling() {
1425        let evaluator = ExplainabilityEvaluator::<f64>::new();
1426        let indices = evaluator.bootstrap_sample_indices(10).unwrap();
1427
1428        assert_eq!(indices.len(), 10);
1429        // All indices should be valid (0-9)
1430        assert!(indices.iter().all(|&i| i < 10));
1431    }
1432
1433    #[test]
1434    fn test_mask_important_features() {
1435        let evaluator = ExplainabilityEvaluator::<f64>::new();
1436        let sample = array![1.0, 2.0, 3.0, 4.0, 5.0];
1437        let explanation = array![0.1, 0.5, 0.2, 0.8, 0.3]; // Feature 3 most important, then 1
1438
1439        let masked = evaluator
1440            .mask_important_features(&sample.view(), &explanation, 2)
1441            .unwrap();
1442
1443        // Features 3 and 1 (most important) should be masked to 0
1444        assert_eq!(masked[3], 0.0);
1445        assert_eq!(masked[1], 0.0);
1446        // Other features should remain unchanged
1447        assert_eq!(masked[0], 1.0);
1448        assert_eq!(masked[2], 3.0);
1449        assert_eq!(masked[4], 5.0);
1450    }
1451}