Skip to main content

scry_learn/explain/
permutation.rs

// SPDX-License-Identifier: MIT OR Apache-2.0
//! Model-agnostic permutation importance (Breiman 2001).
//!
//! Measures the decrease in a scoring function when each feature is
//! randomly permuted, breaking the association between the feature and
//! the target.

use crate::rng::FastRng;

/// Output of a permutation importance analysis.
///
/// Importances are *score decreases*: `baseline_score - permuted_score`,
/// so larger values indicate features the model relies on more heavily.
#[derive(Clone, Debug)]
#[non_exhaustive]
pub struct PermutationImportance {
    /// Mean score decrease for each feature (higher = more important).
    pub importances_mean: Vec<f64>,
    /// Standard deviation of the score decrease for each feature.
    pub importances_std: Vec<f64>,
    /// All raw score decreases, indexed as `importances_raw[feature][repeat]`.
    pub importances_raw: Vec<Vec<f64>>,
}
21
22/// Compute permutation importance for any model.
23///
24/// # Arguments
25///
26/// * `features` - Column-major feature matrix: `features[feature_idx][sample_idx]`.
27/// * `target` - Target values, one per sample.
28/// * `predict` - Prediction function: given column-major features, returns predictions.
29/// * `scorer` - Scoring function: `scorer(y_true, y_pred) -> score`.
30///   Higher is better (e.g. accuracy, R2). The importance is
31///   `baseline_score - permuted_score`.
32/// * `n_repeats` - Number of times to permute each feature (default: 5).
33/// * `seed` - RNG seed for reproducibility.
34///
35/// # Returns
36///
37/// A `PermutationImportance` with mean, std, and raw importances per feature.
38///
39/// # Panics
40///
41/// Panics if `features` is empty or if feature columns have different lengths.
42pub fn permutation_importance(
43    features: &[Vec<f64>],
44    target: &[f64],
45    predict: &dyn Fn(&[Vec<f64>]) -> Vec<f64>,
46    scorer: fn(&[f64], &[f64]) -> f64,
47    n_repeats: usize,
48    seed: u64,
49) -> PermutationImportance {
50    assert!(!features.is_empty(), "features must not be empty");
51    let n_features = features.len();
52    let n_samples = features[0].len();
53    assert_eq!(
54        target.len(),
55        n_samples,
56        "target length must match number of samples"
57    );
58
59    // Compute baseline score on unperturbed data.
60    let baseline_preds = predict(features);
61    let baseline_score = scorer(target, &baseline_preds);
62
63    let mut rng = FastRng::new(seed);
64    let mut importances_raw = vec![Vec::with_capacity(n_repeats); n_features];
65
66    // Pre-allocate a mutable copy of the features for permutation.
67    let mut permuted = features.to_vec();
68
69    for feat_idx in 0..n_features {
70        for _ in 0..n_repeats {
71            // Save the original column.
72            let original_col = features[feat_idx].clone();
73
74            // Create a shuffled index array and apply it.
75            let mut indices: Vec<usize> = (0..n_samples).collect();
76            rng.shuffle(&mut indices);
77
78            for (i, &idx) in indices.iter().enumerate() {
79                permuted[feat_idx][i] = original_col[idx];
80            }
81
82            // Score with permuted feature.
83            let permuted_preds = predict(&permuted);
84            let permuted_score = scorer(target, &permuted_preds);
85
86            importances_raw[feat_idx].push(baseline_score - permuted_score);
87
88            // Restore original column.
89            permuted[feat_idx].clone_from(&features[feat_idx]);
90        }
91    }
92
93    let importances_mean: Vec<f64> = importances_raw
94        .iter()
95        .map(|raw| raw.iter().sum::<f64>() / raw.len() as f64)
96        .collect();
97
98    let importances_std: Vec<f64> = importances_raw
99        .iter()
100        .zip(importances_mean.iter())
101        .map(|(raw, &mean)| {
102            if raw.len() <= 1 {
103                return 0.0;
104            }
105            let variance =
106                raw.iter().map(|&x| (x - mean).powi(2)).sum::<f64>() / (raw.len() - 1) as f64;
107            variance.sqrt()
108        })
109        .collect();
110
111    PermutationImportance {
112        importances_mean,
113        importances_std,
114        importances_raw,
115    }
116}
117
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_permutation_importance_basic() {
        // y = x0 exactly, so feature 0 is fully predictive; feature 1 is
        // uncorrelated noise that the "model" never looks at.
        let n = 100;
        let mut rng = FastRng::new(42);
        let predictive: Vec<f64> = (0..n).map(|i| i as f64).collect();
        let noise: Vec<f64> = (0..n).map(|_| rng.f64() * 100.0).collect();
        let target: Vec<f64> = predictive.clone();
        let features = vec![predictive, noise];

        // Trivial model: echo feature 0 as the prediction.
        let predict = |feats: &[Vec<f64>]| -> Vec<f64> { feats[0].clone() };

        // Negative mean squared error, so that higher is better.
        let neg_mse = |y_true: &[f64], y_pred: &[f64]| -> f64 {
            let total: f64 = y_true
                .iter()
                .zip(y_pred.iter())
                .map(|(t, p)| (t - p).powi(2))
                .sum();
            -(total / y_true.len() as f64)
        };

        let result = permutation_importance(&features, &target, &predict, neg_mse, 5, 42);

        assert_eq!(result.importances_mean.len(), 2);
        // Permuting the predictive feature must hurt the score.
        assert!(
            result.importances_mean[0] > 0.0,
            "Feature 0 should be important: {}",
            result.importances_mean[0]
        );
        // Permuting the ignored noise feature should barely register.
        assert!(
            result.importances_mean[1].abs() < result.importances_mean[0].abs() * 0.1,
            "Feature 1 should be less important: {} vs {}",
            result.importances_mean[1],
            result.importances_mean[0]
        );
    }
}
165}