1use crate::types::DataMatrix;
8use rand::prelude::*;
9use rand::{Rng, SeedableRng, rng};
10use rustkernel_core::{domain::Domain, kernel::KernelMetadata, traits::GpuKernel};
11use serde::{Deserialize, Serialize};
12
/// Configuration for SHAP value estimation.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SHAPConfig {
    /// Number of coalition samples drawn when estimating SHAP values.
    pub n_samples: usize,
    /// Use Kernel SHAP (weighted linear regression) when true;
    /// permutation-sampling SHAP otherwise.
    pub use_kernel_shap: bool,
    /// Ridge term added to the diagonal of the Kernel SHAP regression.
    pub regularization: f64,
    /// Optional RNG seed for reproducible estimates.
    pub seed: Option<u64>,
}
29
30impl Default for SHAPConfig {
31 fn default() -> Self {
32 Self {
33 n_samples: 100,
34 use_kernel_shap: true,
35 regularization: 0.01,
36 seed: None,
37 }
38 }
39}
40
/// SHAP explanation for a single prediction.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SHAPExplanation {
    /// Mean model output over the background data set.
    pub base_value: f64,
    /// One SHAP value per input feature.
    pub shap_values: Vec<f64>,
    /// Optional human-readable feature names.
    pub feature_names: Option<Vec<String>>,
    /// Model prediction for the explained instance.
    pub prediction: f64,
    /// Sum of `shap_values` (for checking additivity against
    /// `prediction - base_value`).
    pub shap_sum: f64,
}
55
/// SHAP explanations for a batch of instances.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SHAPBatchResult {
    /// Mean model output over the background data set.
    pub base_value: f64,
    /// SHAP values per instance (outer) and feature (inner).
    pub shap_values: Vec<Vec<f64>>,
    /// Optional feature names, passed through from the caller.
    pub feature_names: Option<Vec<String>>,
    /// Mean absolute SHAP value per feature across the batch.
    pub feature_importance: Vec<f64>,
}
68
/// Model-agnostic SHAP explainer kernel (Kernel SHAP / sampling SHAP).
#[derive(Debug, Clone)]
pub struct SHAPValues {
    /// Registry metadata (id, domain, throughput/latency hints).
    metadata: KernelMetadata,
}
83
84impl Default for SHAPValues {
85 fn default() -> Self {
86 Self::new()
87 }
88}
89
90impl SHAPValues {
91 #[must_use]
93 pub fn new() -> Self {
94 Self {
95 metadata: KernelMetadata::batch("ml/shap-values", Domain::StatisticalML)
96 .with_description("Kernel SHAP for model-agnostic feature explanations")
97 .with_throughput(1_000)
98 .with_latency_us(500.0),
99 }
100 }
101
102 pub fn explain<F>(
110 instance: &[f64],
111 background: &DataMatrix,
112 predict_fn: F,
113 config: &SHAPConfig,
114 ) -> SHAPExplanation
115 where
116 F: Fn(&[f64]) -> f64,
117 {
118 let n_features = instance.len();
119
120 if n_features == 0 || background.n_samples == 0 {
121 return SHAPExplanation {
122 base_value: 0.0,
123 shap_values: Vec::new(),
124 feature_names: None,
125 prediction: 0.0,
126 shap_sum: 0.0,
127 };
128 }
129
130 let base_value: f64 = (0..background.n_samples)
132 .map(|i| predict_fn(background.row(i)))
133 .sum::<f64>()
134 / background.n_samples as f64;
135
136 let prediction = predict_fn(instance);
137
138 let shap_values = if config.use_kernel_shap {
140 Self::kernel_shap(instance, background, &predict_fn, config)
141 } else {
142 Self::sampling_shap(instance, background, &predict_fn, config)
143 };
144
145 let shap_sum: f64 = shap_values.iter().sum();
146
147 SHAPExplanation {
148 base_value,
149 shap_values,
150 feature_names: None,
151 prediction,
152 shap_sum,
153 }
154 }
155
    /// Kernel SHAP: estimate SHAP values by solving a weighted least-squares
    /// regression over randomly sampled feature coalitions.
    ///
    /// NOTE(review): for a fixed seed, results depend on the exact RNG call
    /// order (coalition bits first, then one background row per masked
    /// instance) — keep that order stable when editing.
    fn kernel_shap<F>(
        instance: &[f64],
        background: &DataMatrix,
        predict_fn: &F,
        config: &SHAPConfig,
    ) -> Vec<f64>
    where
        F: Fn(&[f64]) -> f64,
    {
        let n_features = instance.len();
        let n_samples = config.n_samples;

        // Seeded RNG for reproducibility; otherwise seed from the thread RNG.
        let mut rng = match config.seed {
            Some(seed) => StdRng::seed_from_u64(seed),
            None => StdRng::from_rng(&mut rng()),
        };

        let mut coalitions: Vec<Vec<bool>> = Vec::with_capacity(n_samples);
        let mut predictions: Vec<f64> = Vec::with_capacity(n_samples);
        let mut weights: Vec<f64> = Vec::with_capacity(n_samples);

        // Always include the two extreme coalitions (all features on / off);
        // they anchor the regression endpoints.
        coalitions.push(vec![true; n_features]);
        coalitions.push(vec![false; n_features]);

        for coalition in &coalitions[..2] {
            let masked = Self::create_masked_instance(instance, background, coalition, &mut rng);
            predictions.push(predict_fn(&masked));
        }

        // Huge finite weight stands in for the +infinity the SHAP kernel
        // assigns to the extreme coalitions.
        weights.push(1e6);
        weights.push(1e6);
        // Remaining coalitions: uniform random feature subsets.
        for _ in 2..n_samples {
            let coalition: Vec<bool> = (0..n_features).map(|_| rng.random_bool(0.5)).collect();

            let z: usize = coalition.iter().filter(|&&b| b).count();
            let weight = Self::kernel_shap_weight(n_features, z);

            let masked = Self::create_masked_instance(instance, background, &coalition, &mut rng);
            let pred = predict_fn(&masked);

            coalitions.push(coalition);
            predictions.push(pred);
            weights.push(weight);
        }

        Self::solve_weighted_regression(&coalitions, &predictions, &weights, config.regularization)
    }
209
210 fn sampling_shap<F>(
212 instance: &[f64],
213 background: &DataMatrix,
214 predict_fn: &F,
215 config: &SHAPConfig,
216 ) -> Vec<f64>
217 where
218 F: Fn(&[f64]) -> f64,
219 {
220 let n_features = instance.len();
221 let mut shap_values = vec![0.0; n_features];
222 let samples_per_feature = config.n_samples / n_features;
223
224 let mut rng = match config.seed {
225 Some(seed) => StdRng::seed_from_u64(seed),
226 None => StdRng::from_rng(&mut rng()),
227 };
228
229 for feature_idx in 0..n_features {
230 let mut contributions = Vec::with_capacity(samples_per_feature);
231
232 for _ in 0..samples_per_feature {
233 let mut perm: Vec<usize> = (0..n_features).collect();
235 perm.shuffle(&mut rng);
236
237 let feature_pos = perm.iter().position(|&i| i == feature_idx).unwrap();
238
239 let before: Vec<bool> = (0..n_features)
241 .map(|i| {
242 let pos = perm.iter().position(|&p| p == i).unwrap();
243 pos < feature_pos
244 })
245 .collect();
246
247 let mut with_feature = before.clone();
249 with_feature[feature_idx] = true;
250
251 let bg_idx = rng.random_range(0..background.n_samples);
253 let bg = background.row(bg_idx);
254
255 let x_with: Vec<f64> = (0..n_features)
257 .map(|i| if with_feature[i] { instance[i] } else { bg[i] })
258 .collect();
259
260 let x_without: Vec<f64> = (0..n_features)
261 .map(|i| if before[i] { instance[i] } else { bg[i] })
262 .collect();
263
264 let contribution = predict_fn(&x_with) - predict_fn(&x_without);
265 contributions.push(contribution);
266 }
267
268 shap_values[feature_idx] =
269 contributions.iter().sum::<f64>() / contributions.len() as f64;
270 }
271
272 shap_values
273 }
274
275 fn kernel_shap_weight(n_features: usize, coalition_size: usize) -> f64 {
277 if coalition_size == 0 || coalition_size == n_features {
278 return 1e6; }
280
281 let m = n_features as f64;
282 let z = coalition_size as f64;
283
284 let binomial = Self::binomial(n_features, coalition_size);
286 if binomial == 0.0 {
287 return 0.0;
288 }
289
290 (m - 1.0) / (binomial * z * (m - z))
291 }
292
293 fn binomial(n: usize, k: usize) -> f64 {
295 if k > n {
296 return 0.0;
297 }
298 let k = k.min(n - k);
299 let mut result = 1.0;
300 for i in 0..k {
301 result *= (n - i) as f64 / (i + 1) as f64;
302 }
303 result
304 }
305
306 fn create_masked_instance(
308 instance: &[f64],
309 background: &DataMatrix,
310 coalition: &[bool],
311 rng: &mut StdRng,
312 ) -> Vec<f64> {
313 let bg_idx = rng.random_range(0..background.n_samples);
314 let bg = background.row(bg_idx);
315
316 coalition
317 .iter()
318 .enumerate()
319 .map(|(i, &included)| if included { instance[i] } else { bg[i] })
320 .collect()
321 }
322
323 fn solve_weighted_regression(
325 coalitions: &[Vec<bool>],
326 predictions: &[f64],
327 weights: &[f64],
328 regularization: f64,
329 ) -> Vec<f64> {
330 if coalitions.is_empty() {
331 return Vec::new();
332 }
333
334 let n_features = coalitions[0].len();
335 let n_samples = coalitions.len();
336
337 let mut x: Vec<Vec<f64>> = Vec::with_capacity(n_samples);
339 for coalition in coalitions {
340 let row: Vec<f64> = coalition
341 .iter()
342 .map(|&b| if b { 1.0 } else { 0.0 })
343 .collect();
344 x.push(row);
345 }
346
347 let mut xtw_x = vec![vec![0.0; n_features]; n_features];
349 for i in 0..n_features {
350 for j in 0..n_features {
351 for k in 0..n_samples {
352 xtw_x[i][j] += x[k][i] * weights[k] * x[k][j];
353 }
354 }
355 }
356
357 for i in 0..n_features {
359 xtw_x[i][i] += regularization;
360 }
361
362 let mut xtw_y = vec![0.0; n_features];
364 for i in 0..n_features {
365 for k in 0..n_samples {
366 xtw_y[i] += x[k][i] * weights[k] * predictions[k];
367 }
368 }
369
370 Self::solve_linear_system(&xtw_x, &xtw_y)
372 }
373
374 fn solve_linear_system(a: &[Vec<f64>], b: &[f64]) -> Vec<f64> {
376 let n = b.len();
377 if n == 0 {
378 return Vec::new();
379 }
380
381 let mut aug: Vec<Vec<f64>> = a
383 .iter()
384 .enumerate()
385 .map(|(i, row)| {
386 let mut new_row = row.clone();
387 new_row.push(b[i]);
388 new_row
389 })
390 .collect();
391
392 for i in 0..n {
394 let mut max_idx = i;
396 let mut max_val = aug[i][i].abs();
397 for k in (i + 1)..n {
398 if aug[k][i].abs() > max_val {
399 max_val = aug[k][i].abs();
400 max_idx = k;
401 }
402 }
403
404 aug.swap(i, max_idx);
405
406 if aug[i][i].abs() < 1e-10 {
407 continue;
408 }
409
410 for k in (i + 1)..n {
411 let factor = aug[k][i] / aug[i][i];
412 for j in i..=n {
413 aug[k][j] -= factor * aug[i][j];
414 }
415 }
416 }
417
418 let mut x = vec![0.0; n];
420 for i in (0..n).rev() {
421 if aug[i][i].abs() < 1e-10 {
422 x[i] = 0.0;
423 continue;
424 }
425 x[i] = aug[i][n];
426 for j in (i + 1)..n {
427 x[i] -= aug[i][j] * x[j];
428 }
429 x[i] /= aug[i][i];
430 }
431
432 x
433 }
434
435 pub fn explain_batch<F>(
437 instances: &DataMatrix,
438 background: &DataMatrix,
439 predict_fn: F,
440 config: &SHAPConfig,
441 feature_names: Option<Vec<String>>,
442 ) -> SHAPBatchResult
443 where
444 F: Fn(&[f64]) -> f64,
445 {
446 if instances.n_samples == 0 {
447 return SHAPBatchResult {
448 base_value: 0.0,
449 shap_values: Vec::new(),
450 feature_names: None,
451 feature_importance: Vec::new(),
452 };
453 }
454
455 let base_value: f64 = (0..background.n_samples)
457 .map(|i| predict_fn(background.row(i)))
458 .sum::<f64>()
459 / background.n_samples.max(1) as f64;
460
461 let mut shap_values: Vec<Vec<f64>> = Vec::with_capacity(instances.n_samples);
463
464 for i in 0..instances.n_samples {
465 let instance = instances.row(i);
466 let explanation = Self::explain(instance, background, &predict_fn, config);
467 shap_values.push(explanation.shap_values);
468 }
469
470 let n_features = instances.n_features;
472 let mut feature_importance = vec![0.0; n_features];
473
474 for values in &shap_values {
475 for (i, &v) in values.iter().enumerate() {
476 feature_importance[i] += v.abs();
477 }
478 }
479
480 for imp in &mut feature_importance {
481 *imp /= shap_values.len() as f64;
482 }
483
484 SHAPBatchResult {
485 base_value,
486 shap_values,
487 feature_names,
488 feature_importance,
489 }
490 }
491}
492
impl GpuKernel for SHAPValues {
    // Registry hook: expose this kernel's static metadata.
    fn metadata(&self) -> &KernelMetadata {
        &self.metadata
    }
}
498
/// Configuration for permutation feature importance.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FeatureImportanceConfig {
    /// Number of independent column shuffles per feature.
    pub n_permutations: usize,
    /// Optional RNG seed for reproducible shuffles.
    pub seed: Option<u64>,
    /// Metric used to score the model before and after permutation.
    pub metric: ImportanceMetric,
}
513
impl Default for FeatureImportanceConfig {
    /// Defaults: 10 permutations, accuracy metric, nondeterministic seed.
    fn default() -> Self {
        Self {
            n_permutations: 10,
            seed: None,
            metric: ImportanceMetric::Accuracy,
        }
    }
}
523
/// Scoring metric for permutation importance. `compute_score` orients every
/// variant so that larger is better (MSE and MAE are negated).
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
pub enum ImportanceMetric {
    /// Fraction of predictions that round to the same integer as the target.
    Accuracy,
    /// Negated mean squared error.
    MSE,
    /// Negated mean absolute error.
    MAE,
    /// Coefficient of determination (R²).
    R2,
}
536
/// Result of a permutation feature importance run.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FeatureImportanceResult {
    /// Mean score drop per feature (baseline minus mean permuted score).
    pub importances: Vec<f64>,
    /// Standard deviation of the permuted scores per feature.
    pub std_devs: Vec<f64>,
    /// Optional feature names, passed through from the caller.
    pub feature_names: Option<Vec<String>>,
    /// Model score on the unpermuted data.
    pub baseline_score: f64,
    /// Feature indices sorted by importance, descending.
    pub ranking: Vec<usize>,
}
551
/// Permutation-based feature importance kernel.
#[derive(Debug, Clone)]
pub struct FeatureImportance {
    /// Registry metadata (id, domain, throughput/latency hints).
    metadata: KernelMetadata,
}
561
562impl Default for FeatureImportance {
563 fn default() -> Self {
564 Self::new()
565 }
566}
567
568impl FeatureImportance {
569 #[must_use]
571 pub fn new() -> Self {
572 Self {
573 metadata: KernelMetadata::batch("ml/feature-importance", Domain::StatisticalML)
574 .with_description("Permutation-based feature importance")
575 .with_throughput(5_000)
576 .with_latency_us(200.0),
577 }
578 }
579
580 pub fn compute<F>(
589 data: &DataMatrix,
590 targets: &[f64],
591 predict_fn: F,
592 config: &FeatureImportanceConfig,
593 feature_names: Option<Vec<String>>,
594 ) -> FeatureImportanceResult
595 where
596 F: Fn(&[f64]) -> f64,
597 {
598 if data.n_samples == 0 || data.n_features == 0 {
599 return FeatureImportanceResult {
600 importances: Vec::new(),
601 std_devs: Vec::new(),
602 feature_names: None,
603 baseline_score: 0.0,
604 ranking: Vec::new(),
605 };
606 }
607
608 let mut rng = match config.seed {
609 Some(seed) => StdRng::seed_from_u64(seed),
610 None => StdRng::from_rng(&mut rng()),
611 };
612
613 let predictions: Vec<f64> = (0..data.n_samples)
615 .map(|i| predict_fn(data.row(i)))
616 .collect();
617 let baseline_score = Self::compute_score(&predictions, targets, config.metric);
618
619 let mut importances = Vec::with_capacity(data.n_features);
621 let mut std_devs = Vec::with_capacity(data.n_features);
622
623 for feature_idx in 0..data.n_features {
624 let mut scores = Vec::with_capacity(config.n_permutations);
625
626 for _ in 0..config.n_permutations {
627 let mut perm_data = data.data.clone();
629 let mut perm_indices: Vec<usize> = (0..data.n_samples).collect();
630 perm_indices.shuffle(&mut rng);
631
632 for (i, &perm_idx) in perm_indices.iter().enumerate() {
634 perm_data[i * data.n_features + feature_idx] =
635 data.data[perm_idx * data.n_features + feature_idx];
636 }
637
638 let perm_matrix = DataMatrix::new(perm_data, data.n_samples, data.n_features);
639
640 let perm_predictions: Vec<f64> = (0..perm_matrix.n_samples)
642 .map(|i| predict_fn(perm_matrix.row(i)))
643 .collect();
644
645 let score = Self::compute_score(&perm_predictions, targets, config.metric);
646 scores.push(score);
647 }
648
649 let mean_score: f64 = scores.iter().sum::<f64>() / scores.len() as f64;
651 let importance = baseline_score - mean_score;
652
653 let variance: f64 =
654 scores.iter().map(|s| (s - mean_score).powi(2)).sum::<f64>() / scores.len() as f64;
655 let std_dev = variance.sqrt();
656
657 importances.push(importance);
658 std_devs.push(std_dev);
659 }
660
661 let mut ranking: Vec<usize> = (0..data.n_features).collect();
663 ranking.sort_by(|&a, &b| {
664 importances[b]
665 .partial_cmp(&importances[a])
666 .unwrap_or(std::cmp::Ordering::Equal)
667 });
668
669 FeatureImportanceResult {
670 importances,
671 std_devs,
672 feature_names,
673 baseline_score,
674 ranking,
675 }
676 }
677
678 fn compute_score(predictions: &[f64], targets: &[f64], metric: ImportanceMetric) -> f64 {
680 if predictions.is_empty() || targets.is_empty() {
681 return 0.0;
682 }
683
684 match metric {
685 ImportanceMetric::Accuracy => {
686 let correct: usize = predictions
687 .iter()
688 .zip(targets.iter())
689 .filter(|&(p, t)| (p.round() - t.round()).abs() < 0.5)
690 .count();
691 correct as f64 / predictions.len() as f64
692 }
693 ImportanceMetric::MSE => {
694 let mse: f64 = predictions
695 .iter()
696 .zip(targets.iter())
697 .map(|(p, t)| (p - t).powi(2))
698 .sum::<f64>()
699 / predictions.len() as f64;
700 -mse }
702 ImportanceMetric::MAE => {
703 let mae: f64 = predictions
704 .iter()
705 .zip(targets.iter())
706 .map(|(p, t)| (p - t).abs())
707 .sum::<f64>()
708 / predictions.len() as f64;
709 -mae }
711 ImportanceMetric::R2 => {
712 let mean_target: f64 = targets.iter().sum::<f64>() / targets.len() as f64;
713 let ss_res: f64 = predictions
714 .iter()
715 .zip(targets.iter())
716 .map(|(p, t)| (t - p).powi(2))
717 .sum();
718 let ss_tot: f64 = targets.iter().map(|t| (t - mean_target).powi(2)).sum();
719 if ss_tot.abs() < 1e-10 {
720 0.0
721 } else {
722 1.0 - ss_res / ss_tot
723 }
724 }
725 }
726 }
727}
728
impl GpuKernel for FeatureImportance {
    // Registry hook: expose this kernel's static metadata.
    fn metadata(&self) -> &KernelMetadata {
        &self.metadata
    }
}
734
#[cfg(test)]
mod tests {
    use super::*;

    // Kernel registers under the expected id.
    #[test]
    fn test_shap_values_metadata() {
        let kernel = SHAPValues::new();
        assert_eq!(kernel.metadata().id, "ml/shap-values");
    }

    // Kernel SHAP on a simple linear model: one SHAP value per feature and
    // a positive prediction at (1, 1).
    #[test]
    fn test_shap_basic() {
        let predict_fn = |x: &[f64]| x[0] + 2.0 * x[1];

        // 4x2 background covering all corners of the unit square.
        let background = DataMatrix::new(vec![0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 1.0, 1.0], 4, 2);

        let config = SHAPConfig {
            n_samples: 50,
            use_kernel_shap: true,
            regularization: 0.1,
            seed: Some(42), // fixed seed for determinism
        };

        let instance = vec![1.0, 1.0];
        let explanation = SHAPValues::explain(&instance, &background, predict_fn, &config);

        assert!(explanation.shap_values.len() == 2);
        assert!(explanation.prediction > 0.0);
    }

    // Batch explanation: one SHAP row per instance, one importance per feature.
    #[test]
    fn test_shap_batch() {
        let predict_fn = |x: &[f64]| x[0] * 2.0;

        let background = DataMatrix::new(vec![0.0, 0.5, 1.0, 1.5], 4, 1);
        let instances = DataMatrix::new(vec![0.5, 1.0, 2.0], 3, 1);

        let config = SHAPConfig {
            n_samples: 20,
            seed: Some(42),
            ..Default::default()
        };

        let result = SHAPValues::explain_batch(&instances, &background, predict_fn, &config, None);

        assert_eq!(result.shap_values.len(), 3);
        assert_eq!(result.feature_importance.len(), 1);
    }

    // Empty instance / empty background yields an empty explanation.
    #[test]
    fn test_shap_empty() {
        let predict_fn = |x: &[f64]| x.iter().sum();
        let background = DataMatrix::new(vec![], 0, 0);
        let config = SHAPConfig::default();

        let explanation = SHAPValues::explain(&[], &background, predict_fn, &config);
        assert!(explanation.shap_values.is_empty());
    }

    // Extreme coalitions get the huge anchor weight; interior sizes are finite.
    #[test]
    fn test_kernel_shap_weight() {
        assert!(SHAPValues::kernel_shap_weight(5, 0) > 1000.0);
        assert!(SHAPValues::kernel_shap_weight(5, 5) > 1000.0);

        let w = SHAPValues::kernel_shap_weight(5, 2);
        assert!(w > 0.0 && w < 1000.0);
    }

    // Kernel registers under the expected id.
    #[test]
    fn test_feature_importance_metadata() {
        let kernel = FeatureImportance::new();
        assert_eq!(kernel.metadata().id, "ml/feature-importance");
    }

    // Model only reads feature 0, so permuting it must hurt the score more
    // than permuting the ignored features, and it must rank first.
    #[test]
    fn test_feature_importance_basic() {
        let predict_fn = |x: &[f64]| x[0];

        let data = DataMatrix::new(
            vec![1.0, 0.0, 0.0, 2.0, 0.0, 0.0, 3.0, 0.0, 0.0, 4.0, 0.0, 0.0],
            4,
            3,
        );
        let targets = vec![1.0, 2.0, 3.0, 4.0];

        let config = FeatureImportanceConfig {
            n_permutations: 5,
            seed: Some(42),
            metric: ImportanceMetric::MSE,
        };

        let result = FeatureImportance::compute(&data, &targets, predict_fn, &config, None);

        assert_eq!(result.importances.len(), 3);
        assert!(result.importances[0].abs() > result.importances[1].abs());
        assert!(result.importances[0].abs() > result.importances[2].abs());
        assert_eq!(result.ranking[0], 0);
    }

    // Empty data yields an empty result.
    #[test]
    fn test_feature_importance_empty() {
        let predict_fn = |_: &[f64]| 0.0;
        let data = DataMatrix::new(vec![], 0, 0);
        let targets: Vec<f64> = vec![];
        let config = FeatureImportanceConfig::default();

        let result = FeatureImportance::compute(&data, &targets, predict_fn, &config, None);
        assert!(result.importances.is_empty());
    }

    // Perfect predictions: accuracy 1, MSE 0 (negated), R2 = 1.
    #[test]
    fn test_metrics() {
        let preds = vec![1.0, 2.0, 3.0];
        let targets = vec![1.0, 2.0, 3.0];

        let acc = FeatureImportance::compute_score(&preds, &targets, ImportanceMetric::Accuracy);
        assert!((acc - 1.0).abs() < 0.01);

        let mse = FeatureImportance::compute_score(&preds, &targets, ImportanceMetric::MSE);
        assert!((mse - 0.0).abs() < 0.01);

        let r2 = FeatureImportance::compute_score(&preds, &targets, ImportanceMetric::R2);
        assert!((r2 - 1.0).abs() < 0.01);
    }

    // Spot checks including the k=0 and k=n edges.
    #[test]
    fn test_binomial() {
        assert!((SHAPValues::binomial(5, 2) - 10.0).abs() < 0.01);
        assert!((SHAPValues::binomial(10, 3) - 120.0).abs() < 0.01);
        assert!((SHAPValues::binomial(5, 0) - 1.0).abs() < 0.01);
        assert!((SHAPValues::binomial(5, 5) - 1.0).abs() < 0.01);
    }
}