rustkernel_ml/explainability.rs

//! Explainability kernels for model interpretation.
//!
//! This module provides GPU-accelerated explainability algorithms:
//! - SHAPValues - Kernel SHAP approximation for feature importance
//! - FeatureImportance - Permutation-based feature importance

use crate::types::DataMatrix;
use rand::prelude::*;
use rand::{Rng, SeedableRng, rng};
use rustkernel_core::{domain::Domain, kernel::KernelMetadata, traits::GpuKernel};
use serde::{Deserialize, Serialize};

// ============================================================================
// SHAP Values Kernel
// ============================================================================

/// Configuration for SHAP computation.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SHAPConfig {
    /// Number of samples for approximation.
    pub n_samples: usize,
    /// Whether to use kernel SHAP (vs sampling SHAP).
    pub use_kernel_shap: bool,
    /// Regularization for weighted least squares.
    pub regularization: f64,
    /// Random seed for reproducibility.
    pub seed: Option<u64>,
}

impl Default for SHAPConfig {
    fn default() -> Self {
        Self {
            n_samples: 100,
            use_kernel_shap: true,
            regularization: 0.01,
            seed: None,
        }
    }
}

/// SHAP explanation for a single prediction.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SHAPExplanation {
    /// Base value (expected prediction over the background data).
    pub base_value: f64,
    /// SHAP values for each feature.
    pub shap_values: Vec<f64>,
    /// Feature names if provided.
    pub feature_names: Option<Vec<String>>,
    /// The prediction being explained.
    pub prediction: f64,
    /// Sum of SHAP values (should equal prediction - base_value).
    pub shap_sum: f64,
}

/// Batch SHAP results.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SHAPBatchResult {
    /// Base value.
    pub base_value: f64,
    /// SHAP values matrix (samples x features).
    pub shap_values: Vec<Vec<f64>>,
    /// Feature names.
    pub feature_names: Option<Vec<String>>,
    /// Mean absolute SHAP values per feature.
    pub feature_importance: Vec<f64>,
}

/// SHAP Values kernel.
///
/// Computes SHAP (SHapley Additive exPlanations) values for model predictions.
/// Uses the Kernel SHAP approximation, which is model-agnostic and works with
/// any prediction function.
///
/// SHAP values satisfy:
/// - Local accuracy: f(x) = base_value + sum(shap_values)
/// - Missingness: missing features receive a SHAP value of 0
/// - Consistency: if a model changes so that a feature's marginal contribution
///   does not decrease, its SHAP value does not decrease
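///
/// # Example
///
/// A minimal sketch of explaining one prediction of a toy linear model
/// (marked `ignore` because it relies on crate-local types):
///
/// ```ignore
/// let background = DataMatrix::new(vec![0.0, 0.0, 1.0, 1.0], 2, 2);
/// let config = SHAPConfig { n_samples: 50, seed: Some(42), ..Default::default() };
/// let explanation = SHAPValues::explain(
///     &[1.0, 0.5],
///     &background,
///     |x: &[f64]| x[0] + 2.0 * x[1],
///     &config,
/// );
/// // Local accuracy: prediction ≈ base_value + shap_sum (up to sampling error).
/// assert_eq!(explanation.shap_values.len(), 2);
/// ```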
#[derive(Debug, Clone)]
pub struct SHAPValues {
    metadata: KernelMetadata,
}

impl Default for SHAPValues {
    fn default() -> Self {
        Self::new()
    }
}

impl SHAPValues {
    /// Create a new SHAP Values kernel.
    #[must_use]
    pub fn new() -> Self {
        Self {
            metadata: KernelMetadata::batch("ml/shap-values", Domain::StatisticalML)
                .with_description("Kernel SHAP for model-agnostic feature explanations")
                .with_throughput(1_000)
                .with_latency_us(500.0),
        }
    }

    /// Compute SHAP values for a single instance.
    ///
    /// # Arguments
    /// * `instance` - The instance to explain
    /// * `background` - Background dataset for baseline
    /// * `predict_fn` - Model prediction function
    /// * `config` - SHAP configuration
    pub fn explain<F>(
        instance: &[f64],
        background: &DataMatrix,
        predict_fn: F,
        config: &SHAPConfig,
    ) -> SHAPExplanation
    where
        F: Fn(&[f64]) -> f64,
    {
        let n_features = instance.len();

        if n_features == 0 || background.n_samples == 0 {
            return SHAPExplanation {
                base_value: 0.0,
                shap_values: Vec::new(),
                feature_names: None,
                prediction: 0.0,
                shap_sum: 0.0,
            };
        }

        // Compute base value as expected prediction over background
        let base_value: f64 = (0..background.n_samples)
            .map(|i| predict_fn(background.row(i)))
            .sum::<f64>()
            / background.n_samples as f64;

        let prediction = predict_fn(instance);

        // Choose estimator: regression-based Kernel SHAP or permutation sampling
        let shap_values = if config.use_kernel_shap {
            Self::kernel_shap(instance, background, &predict_fn, config)
        } else {
            Self::sampling_shap(instance, background, &predict_fn, config)
        };

        let shap_sum: f64 = shap_values.iter().sum();

        SHAPExplanation {
            base_value,
            shap_values,
            feature_names: None,
            prediction,
            shap_sum,
        }
    }

    /// Kernel SHAP implementation using weighted linear regression.
    fn kernel_shap<F>(
        instance: &[f64],
        background: &DataMatrix,
        predict_fn: &F,
        config: &SHAPConfig,
    ) -> Vec<f64>
    where
        F: Fn(&[f64]) -> f64,
    {
        let n_features = instance.len();
        let n_samples = config.n_samples;

        let mut rng = match config.seed {
            Some(seed) => StdRng::seed_from_u64(seed),
            None => StdRng::from_rng(&mut rng()),
        };

        // Generate coalition samples
        let mut coalitions: Vec<Vec<bool>> = Vec::with_capacity(n_samples);
        let mut predictions: Vec<f64> = Vec::with_capacity(n_samples);
        let mut weights: Vec<f64> = Vec::with_capacity(n_samples);

        // Always include full and empty coalitions
        coalitions.push(vec![true; n_features]);
        coalitions.push(vec![false; n_features]);

        for coalition in &coalitions[..2] {
            let masked = Self::create_masked_instance(instance, background, coalition, &mut rng);
            predictions.push(predict_fn(&masked));
        }

        weights.push(1e6); // High weight for full coalition
        weights.push(1e6); // High weight for empty coalition

        // Sample random coalitions
        for _ in 2..n_samples {
            let coalition: Vec<bool> = (0..n_features).map(|_| rng.random_bool(0.5)).collect();

            let z: usize = coalition.iter().filter(|&&b| b).count();
            let weight = Self::kernel_shap_weight(n_features, z);

            let masked = Self::create_masked_instance(instance, background, &coalition, &mut rng);
            let pred = predict_fn(&masked);

            coalitions.push(coalition);
            predictions.push(pred);
            weights.push(weight);
        }

        // Solve weighted least squares: (X^T W X + λI)^-1 X^T W y
        Self::solve_weighted_regression(&coalitions, &predictions, &weights, config.regularization)
    }

    /// Sampling SHAP implementation (simpler, faster, less accurate).
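    ///
    /// For each feature `i`, the Shapley value is estimated as the average
    /// marginal contribution over `K` random permutations, with absent
    /// features filled in from a randomly drawn background row:
    ///
    /// ```text
    /// φ_i ≈ (1/K) Σ_k [ f(x restricted to pred(π_k, i) ∪ {i})
    ///                 - f(x restricted to pred(π_k, i)) ]
    /// ```
    ///
    /// where `pred(π_k, i)` are the features preceding `i` in permutation `π_k`.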
    fn sampling_shap<F>(
        instance: &[f64],
        background: &DataMatrix,
        predict_fn: &F,
        config: &SHAPConfig,
    ) -> Vec<f64>
    where
        F: Fn(&[f64]) -> f64,
    {
        let n_features = instance.len();
        let mut shap_values = vec![0.0; n_features];
        // At least one sample per feature, to avoid dividing by zero below.
        let samples_per_feature = (config.n_samples / n_features).max(1);

        let mut rng = match config.seed {
            Some(seed) => StdRng::seed_from_u64(seed),
            None => StdRng::from_rng(&mut rng()),
        };

        for feature_idx in 0..n_features {
            let mut contributions = Vec::with_capacity(samples_per_feature);

            for _ in 0..samples_per_feature {
                // Random permutation
                let mut perm: Vec<usize> = (0..n_features).collect();
                perm.shuffle(&mut rng);

                let feature_pos = perm.iter().position(|&i| i == feature_idx).unwrap();

                // Features before this one in permutation
                let before: Vec<bool> = (0..n_features)
                    .map(|i| {
                        let pos = perm.iter().position(|&p| p == i).unwrap();
                        pos < feature_pos
                    })
                    .collect();

                // Include current feature
                let mut with_feature = before.clone();
                with_feature[feature_idx] = true;

                // Sample background
                let bg_idx = rng.random_range(0..background.n_samples);
                let bg = background.row(bg_idx);

                // Create masked instances
                let x_with: Vec<f64> = (0..n_features)
                    .map(|i| if with_feature[i] { instance[i] } else { bg[i] })
                    .collect();

                let x_without: Vec<f64> = (0..n_features)
                    .map(|i| if before[i] { instance[i] } else { bg[i] })
                    .collect();

                let contribution = predict_fn(&x_with) - predict_fn(&x_without);
                contributions.push(contribution);
            }

            shap_values[feature_idx] =
                contributions.iter().sum::<f64>() / contributions.len() as f64;
        }

        shap_values
    }

    /// Kernel SHAP weight function.
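    ///
    /// For `0 < z < M` the weight is `w(z) = (M - 1) / (C(M, z) · z · (M - z))`;
    /// e.g. for `M = 4`, `z = 2`: `3 / (6 · 2 · 2) = 0.125`. The full and empty
    /// coalitions get a large finite weight (`1e6`) as a stand-in for the
    /// infinite weight the exact kernel assigns them.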
    fn kernel_shap_weight(n_features: usize, coalition_size: usize) -> f64 {
        if coalition_size == 0 || coalition_size == n_features {
            return 1e6; // Very high weight for full/empty coalitions
        }

        let m = n_features as f64;
        let z = coalition_size as f64;

        // SHAP kernel weight: (M-1) / (C(M,z) * z * (M-z))
        let binomial = Self::binomial(n_features, coalition_size);
        if binomial == 0.0 {
            return 0.0;
        }

        (m - 1.0) / (binomial * z * (m - z))
    }

    /// Binomial coefficient.
    fn binomial(n: usize, k: usize) -> f64 {
        if k > n {
            return 0.0;
        }
        let k = k.min(n - k);
        let mut result = 1.0;
        for i in 0..k {
            result *= (n - i) as f64 / (i + 1) as f64;
        }
        result
    }

    /// Create masked instance using background data.
    fn create_masked_instance(
        instance: &[f64],
        background: &DataMatrix,
        coalition: &[bool],
        rng: &mut StdRng,
    ) -> Vec<f64> {
        let bg_idx = rng.random_range(0..background.n_samples);
        let bg = background.row(bg_idx);

        coalition
            .iter()
            .enumerate()
            .map(|(i, &included)| if included { instance[i] } else { bg[i] })
            .collect()
    }

    /// Solve weighted least squares regression.
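    ///
    /// Solves the regularized normal equations `β = (XᵀWX + λI)⁻¹ XᵀWy`,
    /// where `X` holds the 0/1 coalition vectors, `W` the kernel weights,
    /// and `y` the predictions on the masked instances.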
    #[allow(clippy::needless_range_loop)]
    fn solve_weighted_regression(
        coalitions: &[Vec<bool>],
        predictions: &[f64],
        weights: &[f64],
        regularization: f64,
    ) -> Vec<f64> {
        if coalitions.is_empty() {
            return Vec::new();
        }

        let n_features = coalitions[0].len();
        let n_samples = coalitions.len();

        // Build design matrix X (coalitions as 0/1)
        let mut x: Vec<Vec<f64>> = Vec::with_capacity(n_samples);
        for coalition in coalitions {
            let row: Vec<f64> = coalition
                .iter()
                .map(|&b| if b { 1.0 } else { 0.0 })
                .collect();
            x.push(row);
        }

        // Compute X^T W X
        let mut xtw_x = vec![vec![0.0; n_features]; n_features];
        for i in 0..n_features {
            for j in 0..n_features {
                for k in 0..n_samples {
                    xtw_x[i][j] += x[k][i] * weights[k] * x[k][j];
                }
            }
        }

        // Add regularization
        for i in 0..n_features {
            xtw_x[i][i] += regularization;
        }

        // Compute X^T W y
        let mut xtw_y = vec![0.0; n_features];
        for i in 0..n_features {
            for k in 0..n_samples {
                xtw_y[i] += x[k][i] * weights[k] * predictions[k];
            }
        }

        // Solve via Gaussian elimination with partial pivoting
        Self::solve_linear_system(&xtw_x, &xtw_y)
    }

    /// Simple linear system solver.
    #[allow(clippy::needless_range_loop)]
    fn solve_linear_system(a: &[Vec<f64>], b: &[f64]) -> Vec<f64> {
        let n = b.len();
        if n == 0 {
            return Vec::new();
        }

        // Gaussian elimination with partial pivoting
        let mut aug: Vec<Vec<f64>> = a
            .iter()
            .enumerate()
            .map(|(i, row)| {
                let mut new_row = row.clone();
                new_row.push(b[i]);
                new_row
            })
            .collect();

        // Forward elimination
        for i in 0..n {
            // Find pivot
            let mut max_idx = i;
            let mut max_val = aug[i][i].abs();
            for k in (i + 1)..n {
                if aug[k][i].abs() > max_val {
                    max_val = aug[k][i].abs();
                    max_idx = k;
                }
            }

            aug.swap(i, max_idx);

            if aug[i][i].abs() < 1e-10 {
                continue;
            }

            for k in (i + 1)..n {
                let factor = aug[k][i] / aug[i][i];
                for j in i..=n {
                    aug[k][j] -= factor * aug[i][j];
                }
            }
        }

        // Back substitution
        let mut x = vec![0.0; n];
        for i in (0..n).rev() {
            if aug[i][i].abs() < 1e-10 {
                x[i] = 0.0;
                continue;
            }
            x[i] = aug[i][n];
            for j in (i + 1)..n {
                x[i] -= aug[i][j] * x[j];
            }
            x[i] /= aug[i][i];
        }

        x
    }

    /// Explain multiple instances.
    pub fn explain_batch<F>(
        instances: &DataMatrix,
        background: &DataMatrix,
        predict_fn: F,
        config: &SHAPConfig,
        feature_names: Option<Vec<String>>,
    ) -> SHAPBatchResult
    where
        F: Fn(&[f64]) -> f64,
    {
        if instances.n_samples == 0 {
            return SHAPBatchResult {
                base_value: 0.0,
                shap_values: Vec::new(),
                feature_names: None,
                feature_importance: Vec::new(),
            };
        }

        // Compute base value
        let base_value: f64 = (0..background.n_samples)
            .map(|i| predict_fn(background.row(i)))
            .sum::<f64>()
            / background.n_samples.max(1) as f64;

        // Compute SHAP values for each instance
        let mut shap_values: Vec<Vec<f64>> = Vec::with_capacity(instances.n_samples);

        for i in 0..instances.n_samples {
            let instance = instances.row(i);
            let explanation = Self::explain(instance, background, &predict_fn, config);
            shap_values.push(explanation.shap_values);
        }

        // Compute feature importance as mean absolute SHAP values
        let n_features = instances.n_features;
        let mut feature_importance = vec![0.0; n_features];

        for values in &shap_values {
            for (i, &v) in values.iter().enumerate() {
                feature_importance[i] += v.abs();
            }
        }

        for imp in &mut feature_importance {
            *imp /= shap_values.len() as f64;
        }

        SHAPBatchResult {
            base_value,
            shap_values,
            feature_names,
            feature_importance,
        }
    }
}

impl GpuKernel for SHAPValues {
    fn metadata(&self) -> &KernelMetadata {
        &self.metadata
    }
}

// ============================================================================
// Feature Importance Kernel
// ============================================================================

/// Configuration for permutation feature importance.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FeatureImportanceConfig {
    /// Number of permutations per feature.
    pub n_permutations: usize,
    /// Random seed.
    pub seed: Option<u64>,
    /// Metric to use (oriented so that higher is better; MSE/MAE are negated).
    pub metric: ImportanceMetric,
}

impl Default for FeatureImportanceConfig {
    fn default() -> Self {
        Self {
            n_permutations: 10,
            seed: None,
            metric: ImportanceMetric::Accuracy,
        }
    }
}

/// Metric for measuring importance.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
pub enum ImportanceMetric {
    /// Classification accuracy.
    Accuracy,
    /// Mean squared error (for regression).
    MSE,
    /// Mean absolute error.
    MAE,
    /// R-squared score.
    R2,
}

/// Feature importance result.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FeatureImportanceResult {
    /// Importance scores per feature.
    pub importances: Vec<f64>,
    /// Standard deviations of importance scores.
    pub std_devs: Vec<f64>,
    /// Feature names if provided.
    pub feature_names: Option<Vec<String>>,
    /// Baseline score (without permutation).
    pub baseline_score: f64,
    /// Ranked feature indices (most important first).
    pub ranking: Vec<usize>,
}

/// Permutation Feature Importance kernel.
///
/// Computes feature importance by measuring how much model performance
/// degrades when each feature is randomly shuffled. Features that cause
/// larger degradation are more important.
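///
/// # Example
///
/// A minimal sketch (marked `ignore` because it relies on crate-local types):
///
/// ```ignore
/// // Two features; only the first drives the model.
/// let data = DataMatrix::new(vec![1.0, 0.0, 2.0, 0.0, 3.0, 0.0], 3, 2);
/// let targets = vec![1.0, 2.0, 3.0];
/// let config = FeatureImportanceConfig {
///     n_permutations: 5,
///     seed: Some(42),
///     metric: ImportanceMetric::MSE,
/// };
/// let result = FeatureImportance::compute(&data, &targets, |x| x[0], &config, None);
/// // Permuting feature 0 degrades the (negated) MSE score, so it ranks first.
/// assert_eq!(result.ranking[0], 0);
/// ```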
#[derive(Debug, Clone)]
pub struct FeatureImportance {
    metadata: KernelMetadata,
}

impl Default for FeatureImportance {
    fn default() -> Self {
        Self::new()
    }
}

impl FeatureImportance {
    /// Create a new Feature Importance kernel.
    #[must_use]
    pub fn new() -> Self {
        Self {
            metadata: KernelMetadata::batch("ml/feature-importance", Domain::StatisticalML)
                .with_description("Permutation-based feature importance")
                .with_throughput(5_000)
                .with_latency_us(200.0),
        }
    }

    /// Compute permutation feature importance.
    ///
    /// # Arguments
    /// * `data` - Input features
    /// * `targets` - True labels/values
    /// * `predict_fn` - Model prediction function
    /// * `config` - Configuration
    /// * `feature_names` - Optional feature names
    pub fn compute<F>(
        data: &DataMatrix,
        targets: &[f64],
        predict_fn: F,
        config: &FeatureImportanceConfig,
        feature_names: Option<Vec<String>>,
    ) -> FeatureImportanceResult
    where
        F: Fn(&[f64]) -> f64,
    {
        if data.n_samples == 0 || data.n_features == 0 {
            return FeatureImportanceResult {
                importances: Vec::new(),
                std_devs: Vec::new(),
                feature_names: None,
                baseline_score: 0.0,
                ranking: Vec::new(),
            };
        }

        let mut rng = match config.seed {
            Some(seed) => StdRng::seed_from_u64(seed),
            None => StdRng::from_rng(&mut rng()),
        };

        // Compute baseline score
        let predictions: Vec<f64> = (0..data.n_samples)
            .map(|i| predict_fn(data.row(i)))
            .collect();
        let baseline_score = Self::compute_score(&predictions, targets, config.metric);

        // Compute importance for each feature
        let mut importances = Vec::with_capacity(data.n_features);
        let mut std_devs = Vec::with_capacity(data.n_features);

        for feature_idx in 0..data.n_features {
            let mut scores = Vec::with_capacity(config.n_permutations);

            for _ in 0..config.n_permutations {
                // Create permuted data
                let mut perm_data = data.data.clone();
                let mut perm_indices: Vec<usize> = (0..data.n_samples).collect();
                perm_indices.shuffle(&mut rng);

                // Shuffle feature values
                for (i, &perm_idx) in perm_indices.iter().enumerate() {
                    perm_data[i * data.n_features + feature_idx] =
                        data.data[perm_idx * data.n_features + feature_idx];
                }

                let perm_matrix = DataMatrix::new(perm_data, data.n_samples, data.n_features);

                // Compute predictions with permuted feature
                let perm_predictions: Vec<f64> = (0..perm_matrix.n_samples)
                    .map(|i| predict_fn(perm_matrix.row(i)))
                    .collect();

                let score = Self::compute_score(&perm_predictions, targets, config.metric);
                scores.push(score);
            }

            // Importance = baseline - mean(permuted scores)
            let mean_score: f64 = scores.iter().sum::<f64>() / scores.len() as f64;
            let importance = baseline_score - mean_score;

            let variance: f64 =
                scores.iter().map(|s| (s - mean_score).powi(2)).sum::<f64>() / scores.len() as f64;
            let std_dev = variance.sqrt();

            importances.push(importance);
            std_devs.push(std_dev);
        }

        // Compute ranking
        let mut ranking: Vec<usize> = (0..data.n_features).collect();
        ranking.sort_by(|&a, &b| {
            importances[b]
                .partial_cmp(&importances[a])
                .unwrap_or(std::cmp::Ordering::Equal)
        });

        FeatureImportanceResult {
            importances,
            std_devs,
            feature_names,
            baseline_score,
            ranking,
        }
    }

    /// Compute score based on metric.
    fn compute_score(predictions: &[f64], targets: &[f64], metric: ImportanceMetric) -> f64 {
        if predictions.is_empty() || targets.is_empty() {
            return 0.0;
        }

        match metric {
            ImportanceMetric::Accuracy => {
                let correct: usize = predictions
                    .iter()
                    .zip(targets.iter())
                    .filter(|&(p, t)| (p.round() - t.round()).abs() < 0.5)
                    .count();
                correct as f64 / predictions.len() as f64
            }
            ImportanceMetric::MSE => {
                let mse: f64 = predictions
                    .iter()
                    .zip(targets.iter())
                    .map(|(p, t)| (p - t).powi(2))
                    .sum::<f64>()
                    / predictions.len() as f64;
                -mse // Negative because higher is better
            }
            ImportanceMetric::MAE => {
                let mae: f64 = predictions
                    .iter()
                    .zip(targets.iter())
                    .map(|(p, t)| (p - t).abs())
                    .sum::<f64>()
                    / predictions.len() as f64;
                -mae // Negative because higher is better
            }
            ImportanceMetric::R2 => {
                let mean_target: f64 = targets.iter().sum::<f64>() / targets.len() as f64;
                let ss_res: f64 = predictions
                    .iter()
                    .zip(targets.iter())
                    .map(|(p, t)| (t - p).powi(2))
                    .sum();
                let ss_tot: f64 = targets.iter().map(|t| (t - mean_target).powi(2)).sum();
                if ss_tot.abs() < 1e-10 {
                    0.0
                } else {
                    1.0 - ss_res / ss_tot
                }
            }
        }
    }
}

impl GpuKernel for FeatureImportance {
    fn metadata(&self) -> &KernelMetadata {
        &self.metadata
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_shap_values_metadata() {
        let kernel = SHAPValues::new();
        assert_eq!(kernel.metadata().id, "ml/shap-values");
    }

    #[test]
    fn test_shap_basic() {
        // Simple linear model: f(x) = x[0] + 2*x[1]
        let predict_fn = |x: &[f64]| x[0] + 2.0 * x[1];

        let background = DataMatrix::new(vec![0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 1.0, 1.0], 4, 2);

        let config = SHAPConfig {
            n_samples: 50,
            use_kernel_shap: true,
            regularization: 0.1,
            seed: Some(42),
        };

        let instance = vec![1.0, 1.0];
        let explanation = SHAPValues::explain(&instance, &background, predict_fn, &config);

        // For a linear model, each SHAP value ≈ coefficient_i * (x_i - background mean_i)
        assert_eq!(explanation.shap_values.len(), 2);
        assert!(explanation.prediction > 0.0);
    }

    #[test]
    fn test_shap_batch() {
        let predict_fn = |x: &[f64]| x[0] * 2.0;

        let background = DataMatrix::new(vec![0.0, 0.5, 1.0, 1.5], 4, 1);
        let instances = DataMatrix::new(vec![0.5, 1.0, 2.0], 3, 1);

        let config = SHAPConfig {
            n_samples: 20,
            seed: Some(42),
            ..Default::default()
        };

        let result = SHAPValues::explain_batch(&instances, &background, predict_fn, &config, None);

        assert_eq!(result.shap_values.len(), 3);
        assert_eq!(result.feature_importance.len(), 1);
    }

    #[test]
    fn test_shap_empty() {
        let predict_fn = |x: &[f64]| x.iter().sum();
        let background = DataMatrix::new(vec![], 0, 0);
        let config = SHAPConfig::default();

        let explanation = SHAPValues::explain(&[], &background, predict_fn, &config);
        assert!(explanation.shap_values.is_empty());
    }

    #[test]
    fn test_kernel_shap_weight() {
        // Edge cases
        assert!(SHAPValues::kernel_shap_weight(5, 0) > 1000.0);
        assert!(SHAPValues::kernel_shap_weight(5, 5) > 1000.0);

        // Middle values should have finite weights
        let w = SHAPValues::kernel_shap_weight(5, 2);
        assert!(w > 0.0 && w < 1000.0);
    }

    #[test]
    fn test_feature_importance_metadata() {
        let kernel = FeatureImportance::new();
        assert_eq!(kernel.metadata().id, "ml/feature-importance");
    }

    #[test]
    fn test_feature_importance_basic() {
        // Model that only uses first feature
        let predict_fn = |x: &[f64]| x[0];

        let data = DataMatrix::new(
            vec![1.0, 0.0, 0.0, 2.0, 0.0, 0.0, 3.0, 0.0, 0.0, 4.0, 0.0, 0.0],
            4,
            3,
        );
        let targets = vec![1.0, 2.0, 3.0, 4.0];

        let config = FeatureImportanceConfig {
            n_permutations: 5,
            seed: Some(42),
            metric: ImportanceMetric::MSE,
        };

        let result = FeatureImportance::compute(&data, &targets, predict_fn, &config, None);

        // First feature should be most important
        assert_eq!(result.importances.len(), 3);
        assert!(result.importances[0].abs() > result.importances[1].abs());
        assert!(result.importances[0].abs() > result.importances[2].abs());
        assert_eq!(result.ranking[0], 0);
    }

    #[test]
    fn test_feature_importance_empty() {
        let predict_fn = |_: &[f64]| 0.0;
        let data = DataMatrix::new(vec![], 0, 0);
        let targets: Vec<f64> = vec![];
        let config = FeatureImportanceConfig::default();

        let result = FeatureImportance::compute(&data, &targets, predict_fn, &config, None);
        assert!(result.importances.is_empty());
    }

    #[test]
    fn test_metrics() {
        let preds = vec![1.0, 2.0, 3.0];
        let targets = vec![1.0, 2.0, 3.0];

        // Perfect predictions
        let acc = FeatureImportance::compute_score(&preds, &targets, ImportanceMetric::Accuracy);
        assert!((acc - 1.0).abs() < 0.01);

        let mse = FeatureImportance::compute_score(&preds, &targets, ImportanceMetric::MSE);
        assert!((mse - 0.0).abs() < 0.01);

        let r2 = FeatureImportance::compute_score(&preds, &targets, ImportanceMetric::R2);
        assert!((r2 - 1.0).abs() < 0.01);
    }

    #[test]
    fn test_binomial() {
        assert!((SHAPValues::binomial(5, 2) - 10.0).abs() < 0.01);
        assert!((SHAPValues::binomial(10, 3) - 120.0).abs() < 0.01);
        assert!((SHAPValues::binomial(5, 0) - 1.0).abs() < 0.01);
        assert!((SHAPValues::binomial(5, 5) - 1.0).abs() < 0.01);
    }
}