1use crate::types::DataMatrix;
8use rand::prelude::*;
9use rand::{Rng, SeedableRng, rng};
10use rustkernel_core::{domain::Domain, kernel::KernelMetadata, traits::GpuKernel};
11use serde::{Deserialize, Serialize};
12
/// Configuration for SHAP value estimation.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SHAPConfig {
    /// Number of coalition samples drawn when estimating SHAP values.
    pub n_samples: usize,
    /// Use the Kernel SHAP weighted-regression estimator; when `false`,
    /// fall back to permutation sampling.
    pub use_kernel_shap: bool,
    /// Ridge regularization added to the normal-equation diagonal
    /// (Kernel SHAP path only).
    pub regularization: f64,
    /// RNG seed for reproducible sampling; `None` seeds from the thread RNG.
    pub seed: Option<u64>,
}
29
30impl Default for SHAPConfig {
31 fn default() -> Self {
32 Self {
33 n_samples: 100,
34 use_kernel_shap: true,
35 regularization: 0.01,
36 seed: None,
37 }
38 }
39}
40
/// SHAP explanation of a single prediction.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SHAPExplanation {
    /// Mean model output over the background dataset.
    pub base_value: f64,
    /// Per-feature SHAP attributions, in the instance's feature order.
    pub shap_values: Vec<f64>,
    /// Optional feature names; `SHAPValues::explain` itself leaves this `None`.
    pub feature_names: Option<Vec<String>>,
    /// Model output for the explained instance.
    pub prediction: f64,
    /// Sum of `shap_values`, kept for additivity diagnostics.
    pub shap_sum: f64,
}
55
/// SHAP explanations for a batch of instances.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SHAPBatchResult {
    /// Mean model output over the background dataset.
    pub base_value: f64,
    /// Per-instance SHAP attributions (one inner vector per instance).
    pub shap_values: Vec<Vec<f64>>,
    /// Optional feature names, passed through from the caller.
    pub feature_names: Option<Vec<String>>,
    /// Mean absolute SHAP value per feature across the batch.
    pub feature_importance: Vec<f64>,
}
68
/// Model-agnostic SHAP explanation kernel.
///
/// All computation is exposed through associated functions; the struct
/// itself only carries registry metadata.
#[derive(Debug, Clone)]
pub struct SHAPValues {
    // Registry metadata served via the `GpuKernel` impl.
    metadata: KernelMetadata,
}
83
84impl Default for SHAPValues {
85 fn default() -> Self {
86 Self::new()
87 }
88}
89
90impl SHAPValues {
91 #[must_use]
93 pub fn new() -> Self {
94 Self {
95 metadata: KernelMetadata::batch("ml/shap-values", Domain::StatisticalML)
96 .with_description("Kernel SHAP for model-agnostic feature explanations")
97 .with_throughput(1_000)
98 .with_latency_us(500.0),
99 }
100 }
101
102 pub fn explain<F>(
110 instance: &[f64],
111 background: &DataMatrix,
112 predict_fn: F,
113 config: &SHAPConfig,
114 ) -> SHAPExplanation
115 where
116 F: Fn(&[f64]) -> f64,
117 {
118 let n_features = instance.len();
119
120 if n_features == 0 || background.n_samples == 0 {
121 return SHAPExplanation {
122 base_value: 0.0,
123 shap_values: Vec::new(),
124 feature_names: None,
125 prediction: 0.0,
126 shap_sum: 0.0,
127 };
128 }
129
130 let base_value: f64 = (0..background.n_samples)
132 .map(|i| predict_fn(background.row(i)))
133 .sum::<f64>()
134 / background.n_samples as f64;
135
136 let prediction = predict_fn(instance);
137
138 let shap_values = if config.use_kernel_shap {
140 Self::kernel_shap(instance, background, &predict_fn, config)
141 } else {
142 Self::sampling_shap(instance, background, &predict_fn, config)
143 };
144
145 let shap_sum: f64 = shap_values.iter().sum();
146
147 SHAPExplanation {
148 base_value,
149 shap_values,
150 feature_names: None,
151 prediction,
152 shap_sum,
153 }
154 }
155
156 fn kernel_shap<F>(
158 instance: &[f64],
159 background: &DataMatrix,
160 predict_fn: &F,
161 config: &SHAPConfig,
162 ) -> Vec<f64>
163 where
164 F: Fn(&[f64]) -> f64,
165 {
166 let n_features = instance.len();
167 let n_samples = config.n_samples;
168
169 let mut rng = match config.seed {
170 Some(seed) => StdRng::seed_from_u64(seed),
171 None => StdRng::from_rng(&mut rng()),
172 };
173
174 let mut coalitions: Vec<Vec<bool>> = Vec::with_capacity(n_samples);
176 let mut predictions: Vec<f64> = Vec::with_capacity(n_samples);
177 let mut weights: Vec<f64> = Vec::with_capacity(n_samples);
178
179 coalitions.push(vec![true; n_features]);
181 coalitions.push(vec![false; n_features]);
182
183 for coalition in &coalitions[..2] {
184 let masked = Self::create_masked_instance(instance, background, coalition, &mut rng);
185 predictions.push(predict_fn(&masked));
186 }
187
188 weights.push(1e6); weights.push(1e6); for _ in 2..n_samples {
193 let coalition: Vec<bool> = (0..n_features).map(|_| rng.random_bool(0.5)).collect();
194
195 let z: usize = coalition.iter().filter(|&&b| b).count();
196 let weight = Self::kernel_shap_weight(n_features, z);
197
198 let masked = Self::create_masked_instance(instance, background, &coalition, &mut rng);
199 let pred = predict_fn(&masked);
200
201 coalitions.push(coalition);
202 predictions.push(pred);
203 weights.push(weight);
204 }
205
206 Self::solve_weighted_regression(&coalitions, &predictions, &weights, config.regularization)
208 }
209
210 fn sampling_shap<F>(
212 instance: &[f64],
213 background: &DataMatrix,
214 predict_fn: &F,
215 config: &SHAPConfig,
216 ) -> Vec<f64>
217 where
218 F: Fn(&[f64]) -> f64,
219 {
220 let n_features = instance.len();
221 let mut shap_values = vec![0.0; n_features];
222 let samples_per_feature = config.n_samples / n_features;
223
224 let mut rng = match config.seed {
225 Some(seed) => StdRng::seed_from_u64(seed),
226 None => StdRng::from_rng(&mut rng()),
227 };
228
229 for feature_idx in 0..n_features {
230 let mut contributions = Vec::with_capacity(samples_per_feature);
231
232 for _ in 0..samples_per_feature {
233 let mut perm: Vec<usize> = (0..n_features).collect();
235 perm.shuffle(&mut rng);
236
237 let feature_pos = perm.iter().position(|&i| i == feature_idx).unwrap();
238
239 let before: Vec<bool> = (0..n_features)
241 .map(|i| {
242 let pos = perm.iter().position(|&p| p == i).unwrap();
243 pos < feature_pos
244 })
245 .collect();
246
247 let mut with_feature = before.clone();
249 with_feature[feature_idx] = true;
250
251 let bg_idx = rng.random_range(0..background.n_samples);
253 let bg = background.row(bg_idx);
254
255 let x_with: Vec<f64> = (0..n_features)
257 .map(|i| if with_feature[i] { instance[i] } else { bg[i] })
258 .collect();
259
260 let x_without: Vec<f64> = (0..n_features)
261 .map(|i| if before[i] { instance[i] } else { bg[i] })
262 .collect();
263
264 let contribution = predict_fn(&x_with) - predict_fn(&x_without);
265 contributions.push(contribution);
266 }
267
268 shap_values[feature_idx] =
269 contributions.iter().sum::<f64>() / contributions.len() as f64;
270 }
271
272 shap_values
273 }
274
275 fn kernel_shap_weight(n_features: usize, coalition_size: usize) -> f64 {
277 if coalition_size == 0 || coalition_size == n_features {
278 return 1e6; }
280
281 let m = n_features as f64;
282 let z = coalition_size as f64;
283
284 let binomial = Self::binomial(n_features, coalition_size);
286 if binomial == 0.0 {
287 return 0.0;
288 }
289
290 (m - 1.0) / (binomial * z * (m - z))
291 }
292
293 fn binomial(n: usize, k: usize) -> f64 {
295 if k > n {
296 return 0.0;
297 }
298 let k = k.min(n - k);
299 let mut result = 1.0;
300 for i in 0..k {
301 result *= (n - i) as f64 / (i + 1) as f64;
302 }
303 result
304 }
305
306 fn create_masked_instance(
308 instance: &[f64],
309 background: &DataMatrix,
310 coalition: &[bool],
311 rng: &mut StdRng,
312 ) -> Vec<f64> {
313 let bg_idx = rng.random_range(0..background.n_samples);
314 let bg = background.row(bg_idx);
315
316 coalition
317 .iter()
318 .enumerate()
319 .map(|(i, &included)| if included { instance[i] } else { bg[i] })
320 .collect()
321 }
322
323 #[allow(clippy::needless_range_loop)]
325 fn solve_weighted_regression(
326 coalitions: &[Vec<bool>],
327 predictions: &[f64],
328 weights: &[f64],
329 regularization: f64,
330 ) -> Vec<f64> {
331 if coalitions.is_empty() {
332 return Vec::new();
333 }
334
335 let n_features = coalitions[0].len();
336 let n_samples = coalitions.len();
337
338 let mut x: Vec<Vec<f64>> = Vec::with_capacity(n_samples);
340 for coalition in coalitions {
341 let row: Vec<f64> = coalition
342 .iter()
343 .map(|&b| if b { 1.0 } else { 0.0 })
344 .collect();
345 x.push(row);
346 }
347
348 let mut xtw_x = vec![vec![0.0; n_features]; n_features];
350 for i in 0..n_features {
351 for j in 0..n_features {
352 for k in 0..n_samples {
353 xtw_x[i][j] += x[k][i] * weights[k] * x[k][j];
354 }
355 }
356 }
357
358 for i in 0..n_features {
360 xtw_x[i][i] += regularization;
361 }
362
363 let mut xtw_y = vec![0.0; n_features];
365 for i in 0..n_features {
366 for k in 0..n_samples {
367 xtw_y[i] += x[k][i] * weights[k] * predictions[k];
368 }
369 }
370
371 Self::solve_linear_system(&xtw_x, &xtw_y)
373 }
374
375 #[allow(clippy::needless_range_loop)]
377 fn solve_linear_system(a: &[Vec<f64>], b: &[f64]) -> Vec<f64> {
378 let n = b.len();
379 if n == 0 {
380 return Vec::new();
381 }
382
383 let mut aug: Vec<Vec<f64>> = a
385 .iter()
386 .enumerate()
387 .map(|(i, row)| {
388 let mut new_row = row.clone();
389 new_row.push(b[i]);
390 new_row
391 })
392 .collect();
393
394 for i in 0..n {
396 let mut max_idx = i;
398 let mut max_val = aug[i][i].abs();
399 for k in (i + 1)..n {
400 if aug[k][i].abs() > max_val {
401 max_val = aug[k][i].abs();
402 max_idx = k;
403 }
404 }
405
406 aug.swap(i, max_idx);
407
408 if aug[i][i].abs() < 1e-10 {
409 continue;
410 }
411
412 for k in (i + 1)..n {
413 let factor = aug[k][i] / aug[i][i];
414 for j in i..=n {
415 aug[k][j] -= factor * aug[i][j];
416 }
417 }
418 }
419
420 let mut x = vec![0.0; n];
422 for i in (0..n).rev() {
423 if aug[i][i].abs() < 1e-10 {
424 x[i] = 0.0;
425 continue;
426 }
427 x[i] = aug[i][n];
428 for j in (i + 1)..n {
429 x[i] -= aug[i][j] * x[j];
430 }
431 x[i] /= aug[i][i];
432 }
433
434 x
435 }
436
437 pub fn explain_batch<F>(
439 instances: &DataMatrix,
440 background: &DataMatrix,
441 predict_fn: F,
442 config: &SHAPConfig,
443 feature_names: Option<Vec<String>>,
444 ) -> SHAPBatchResult
445 where
446 F: Fn(&[f64]) -> f64,
447 {
448 if instances.n_samples == 0 {
449 return SHAPBatchResult {
450 base_value: 0.0,
451 shap_values: Vec::new(),
452 feature_names: None,
453 feature_importance: Vec::new(),
454 };
455 }
456
457 let base_value: f64 = (0..background.n_samples)
459 .map(|i| predict_fn(background.row(i)))
460 .sum::<f64>()
461 / background.n_samples.max(1) as f64;
462
463 let mut shap_values: Vec<Vec<f64>> = Vec::with_capacity(instances.n_samples);
465
466 for i in 0..instances.n_samples {
467 let instance = instances.row(i);
468 let explanation = Self::explain(instance, background, &predict_fn, config);
469 shap_values.push(explanation.shap_values);
470 }
471
472 let n_features = instances.n_features;
474 let mut feature_importance = vec![0.0; n_features];
475
476 for values in &shap_values {
477 for (i, &v) in values.iter().enumerate() {
478 feature_importance[i] += v.abs();
479 }
480 }
481
482 for imp in &mut feature_importance {
483 *imp /= shap_values.len() as f64;
484 }
485
486 SHAPBatchResult {
487 base_value,
488 shap_values,
489 feature_names,
490 feature_importance,
491 }
492 }
493}
494
impl GpuKernel for SHAPValues {
    // Serves the registry metadata built in `SHAPValues::new`.
    fn metadata(&self) -> &KernelMetadata {
        &self.metadata
    }
}
500
/// Configuration for permutation feature importance.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FeatureImportanceConfig {
    /// Number of column shuffles evaluated per feature.
    pub n_permutations: usize,
    /// RNG seed for reproducible shuffles; `None` seeds from the thread RNG.
    pub seed: Option<u64>,
    /// Score used to measure degradation after permuting a feature.
    pub metric: ImportanceMetric,
}
515
516impl Default for FeatureImportanceConfig {
517 fn default() -> Self {
518 Self {
519 n_permutations: 10,
520 seed: None,
521 metric: ImportanceMetric::Accuracy,
522 }
523 }
524}
525
/// Scoring metric for permutation importance. Error metrics are negated
/// in `compute_score` so that a higher score is always better.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
pub enum ImportanceMetric {
    /// Fraction of predictions whose rounded value matches the rounded target.
    Accuracy,
    /// Negated mean squared error.
    MSE,
    /// Negated mean absolute error.
    MAE,
    /// Coefficient of determination (R-squared).
    R2,
}
538
/// Result of a permutation-importance run.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FeatureImportanceResult {
    /// Baseline score minus mean permuted score, per feature.
    pub importances: Vec<f64>,
    /// Standard deviation of the permuted scores, per feature.
    pub std_devs: Vec<f64>,
    /// Optional feature names, passed through from the caller.
    pub feature_names: Option<Vec<String>>,
    /// Score of the unpermuted predictions.
    pub baseline_score: f64,
    /// Feature indices sorted by importance, highest first.
    pub ranking: Vec<usize>,
}
553
/// Permutation-based feature importance kernel.
///
/// Computation is exposed through associated functions; the struct itself
/// only carries registry metadata.
#[derive(Debug, Clone)]
pub struct FeatureImportance {
    // Registry metadata served via the `GpuKernel` impl.
    metadata: KernelMetadata,
}
563
564impl Default for FeatureImportance {
565 fn default() -> Self {
566 Self::new()
567 }
568}
569
570impl FeatureImportance {
571 #[must_use]
573 pub fn new() -> Self {
574 Self {
575 metadata: KernelMetadata::batch("ml/feature-importance", Domain::StatisticalML)
576 .with_description("Permutation-based feature importance")
577 .with_throughput(5_000)
578 .with_latency_us(200.0),
579 }
580 }
581
582 pub fn compute<F>(
591 data: &DataMatrix,
592 targets: &[f64],
593 predict_fn: F,
594 config: &FeatureImportanceConfig,
595 feature_names: Option<Vec<String>>,
596 ) -> FeatureImportanceResult
597 where
598 F: Fn(&[f64]) -> f64,
599 {
600 if data.n_samples == 0 || data.n_features == 0 {
601 return FeatureImportanceResult {
602 importances: Vec::new(),
603 std_devs: Vec::new(),
604 feature_names: None,
605 baseline_score: 0.0,
606 ranking: Vec::new(),
607 };
608 }
609
610 let mut rng = match config.seed {
611 Some(seed) => StdRng::seed_from_u64(seed),
612 None => StdRng::from_rng(&mut rng()),
613 };
614
615 let predictions: Vec<f64> = (0..data.n_samples)
617 .map(|i| predict_fn(data.row(i)))
618 .collect();
619 let baseline_score = Self::compute_score(&predictions, targets, config.metric);
620
621 let mut importances = Vec::with_capacity(data.n_features);
623 let mut std_devs = Vec::with_capacity(data.n_features);
624
625 for feature_idx in 0..data.n_features {
626 let mut scores = Vec::with_capacity(config.n_permutations);
627
628 for _ in 0..config.n_permutations {
629 let mut perm_data = data.data.clone();
631 let mut perm_indices: Vec<usize> = (0..data.n_samples).collect();
632 perm_indices.shuffle(&mut rng);
633
634 for (i, &perm_idx) in perm_indices.iter().enumerate() {
636 perm_data[i * data.n_features + feature_idx] =
637 data.data[perm_idx * data.n_features + feature_idx];
638 }
639
640 let perm_matrix = DataMatrix::new(perm_data, data.n_samples, data.n_features);
641
642 let perm_predictions: Vec<f64> = (0..perm_matrix.n_samples)
644 .map(|i| predict_fn(perm_matrix.row(i)))
645 .collect();
646
647 let score = Self::compute_score(&perm_predictions, targets, config.metric);
648 scores.push(score);
649 }
650
651 let mean_score: f64 = scores.iter().sum::<f64>() / scores.len() as f64;
653 let importance = baseline_score - mean_score;
654
655 let variance: f64 =
656 scores.iter().map(|s| (s - mean_score).powi(2)).sum::<f64>() / scores.len() as f64;
657 let std_dev = variance.sqrt();
658
659 importances.push(importance);
660 std_devs.push(std_dev);
661 }
662
663 let mut ranking: Vec<usize> = (0..data.n_features).collect();
665 ranking.sort_by(|&a, &b| {
666 importances[b]
667 .partial_cmp(&importances[a])
668 .unwrap_or(std::cmp::Ordering::Equal)
669 });
670
671 FeatureImportanceResult {
672 importances,
673 std_devs,
674 feature_names,
675 baseline_score,
676 ranking,
677 }
678 }
679
680 fn compute_score(predictions: &[f64], targets: &[f64], metric: ImportanceMetric) -> f64 {
682 if predictions.is_empty() || targets.is_empty() {
683 return 0.0;
684 }
685
686 match metric {
687 ImportanceMetric::Accuracy => {
688 let correct: usize = predictions
689 .iter()
690 .zip(targets.iter())
691 .filter(|&(p, t)| (p.round() - t.round()).abs() < 0.5)
692 .count();
693 correct as f64 / predictions.len() as f64
694 }
695 ImportanceMetric::MSE => {
696 let mse: f64 = predictions
697 .iter()
698 .zip(targets.iter())
699 .map(|(p, t)| (p - t).powi(2))
700 .sum::<f64>()
701 / predictions.len() as f64;
702 -mse }
704 ImportanceMetric::MAE => {
705 let mae: f64 = predictions
706 .iter()
707 .zip(targets.iter())
708 .map(|(p, t)| (p - t).abs())
709 .sum::<f64>()
710 / predictions.len() as f64;
711 -mae }
713 ImportanceMetric::R2 => {
714 let mean_target: f64 = targets.iter().sum::<f64>() / targets.len() as f64;
715 let ss_res: f64 = predictions
716 .iter()
717 .zip(targets.iter())
718 .map(|(p, t)| (t - p).powi(2))
719 .sum();
720 let ss_tot: f64 = targets.iter().map(|t| (t - mean_target).powi(2)).sum();
721 if ss_tot.abs() < 1e-10 {
722 0.0
723 } else {
724 1.0 - ss_res / ss_tot
725 }
726 }
727 }
728 }
729}
730
impl GpuKernel for FeatureImportance {
    // Serves the registry metadata built in `FeatureImportance::new`.
    fn metadata(&self) -> &KernelMetadata {
        &self.metadata
    }
}
736
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_shap_values_metadata() {
        let kernel = SHAPValues::new();
        assert_eq!(kernel.metadata().id, "ml/shap-values");
    }

    #[test]
    fn test_shap_basic() {
        // Linear model over two features; expect one attribution per feature.
        let predict_fn = |x: &[f64]| x[0] + 2.0 * x[1];

        let background = DataMatrix::new(vec![0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 1.0, 1.0], 4, 2);

        let config = SHAPConfig {
            n_samples: 50,
            use_kernel_shap: true,
            regularization: 0.1,
            seed: Some(42),
        };

        let instance = vec![1.0, 1.0];
        let explanation = SHAPValues::explain(&instance, &background, predict_fn, &config);

        assert!(explanation.shap_values.len() == 2);
        assert!(explanation.prediction > 0.0);
    }

    #[test]
    fn test_shap_batch() {
        let predict_fn = |x: &[f64]| x[0] * 2.0;

        let background = DataMatrix::new(vec![0.0, 0.5, 1.0, 1.5], 4, 1);
        let instances = DataMatrix::new(vec![0.5, 1.0, 2.0], 3, 1);

        let config = SHAPConfig {
            n_samples: 20,
            seed: Some(42),
            ..Default::default()
        };

        let result = SHAPValues::explain_batch(&instances, &background, predict_fn, &config, None);

        assert_eq!(result.shap_values.len(), 3);
        assert_eq!(result.feature_importance.len(), 1);
    }

    #[test]
    fn test_shap_empty() {
        // Empty instance and background produce an empty explanation.
        let predict_fn = |x: &[f64]| x.iter().sum();
        let background = DataMatrix::new(vec![], 0, 0);
        let config = SHAPConfig::default();

        let explanation = SHAPValues::explain(&[], &background, predict_fn, &config);
        assert!(explanation.shap_values.is_empty());
    }

    #[test]
    fn test_kernel_shap_weight() {
        // Boundary coalitions get the large surrogate weight.
        assert!(SHAPValues::kernel_shap_weight(5, 0) > 1000.0);
        assert!(SHAPValues::kernel_shap_weight(5, 5) > 1000.0);

        // Interior coalitions get a finite positive kernel weight.
        let w = SHAPValues::kernel_shap_weight(5, 2);
        assert!(w > 0.0 && w < 1000.0);
    }

    #[test]
    fn test_feature_importance_metadata() {
        let kernel = FeatureImportance::new();
        assert_eq!(kernel.metadata().id, "ml/feature-importance");
    }

    #[test]
    fn test_feature_importance_basic() {
        // Model depends only on feature 0, so it should rank first.
        let predict_fn = |x: &[f64]| x[0];

        let data = DataMatrix::new(
            vec![1.0, 0.0, 0.0, 2.0, 0.0, 0.0, 3.0, 0.0, 0.0, 4.0, 0.0, 0.0],
            4,
            3,
        );
        let targets = vec![1.0, 2.0, 3.0, 4.0];

        let config = FeatureImportanceConfig {
            n_permutations: 5,
            seed: Some(42),
            metric: ImportanceMetric::MSE,
        };

        let result = FeatureImportance::compute(&data, &targets, predict_fn, &config, None);

        assert_eq!(result.importances.len(), 3);
        assert!(result.importances[0].abs() > result.importances[1].abs());
        assert!(result.importances[0].abs() > result.importances[2].abs());
        assert_eq!(result.ranking[0], 0);
    }

    #[test]
    fn test_feature_importance_empty() {
        let predict_fn = |_: &[f64]| 0.0;
        let data = DataMatrix::new(vec![], 0, 0);
        let targets: Vec<f64> = vec![];
        let config = FeatureImportanceConfig::default();

        let result = FeatureImportance::compute(&data, &targets, predict_fn, &config, None);
        assert!(result.importances.is_empty());
    }

    #[test]
    fn test_metrics() {
        // Perfect predictions: accuracy 1, MSE 0 (negated), R2 1.
        let preds = vec![1.0, 2.0, 3.0];
        let targets = vec![1.0, 2.0, 3.0];

        let acc = FeatureImportance::compute_score(&preds, &targets, ImportanceMetric::Accuracy);
        assert!((acc - 1.0).abs() < 0.01);

        let mse = FeatureImportance::compute_score(&preds, &targets, ImportanceMetric::MSE);
        assert!((mse - 0.0).abs() < 0.01);

        let r2 = FeatureImportance::compute_score(&preds, &targets, ImportanceMetric::R2);
        assert!((r2 - 1.0).abs() < 0.01);
    }

    #[test]
    fn test_binomial() {
        assert!((SHAPValues::binomial(5, 2) - 10.0).abs() < 0.01);
        assert!((SHAPValues::binomial(10, 3) - 120.0).abs() < 0.01);
        assert!((SHAPValues::binomial(5, 0) - 1.0).abs() < 0.01);
        assert!((SHAPValues::binomial(5, 5) - 1.0).abs() < 0.01);
    }
}
876}