sklears_model_selection/epistemic_uncertainty/calibration.rs

1use scirs2_core::ndarray::Array1;
2// use scirs2_core::numeric::Float;
3
/// Selects which probability-calibration method to use.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum CalibrationMethod {
    /// Platt scaling: a sigmoid `1 / (1 + exp(-(a * score + b)))` fitted to scores
    /// (see [`platt_scaling`]).
    PlattScaling,
    /// Isotonic-regression calibration.
    IsotonicRegression,
    /// Temperature scaling: logits divided by a single scalar temperature
    /// (see [`apply_temperature_scaling`]).
    TemperatureScaling,
    /// Histogram-binning calibration.
    HistogramBinning {
        /// Number of bins to use.
        n_bins: usize,
    },
    /// No calibration.
    None,
}
19
20pub fn apply_temperature_scaling(logits: &Array1<f64>, temperature: f64) -> Array1<f64> {
21    logits.mapv(|x| x / temperature)
22}
23
24pub fn compute_calibration_error(
25    confidences: &Array1<f64>,
26    accuracies: &Array1<f64>,
27    n_bins: usize,
28) -> f64 {
29    let bin_boundaries = Array1::<f64>::linspace(0.0, 1.0, n_bins + 1);
30    let mut calibration_error = 0.0;
31    let n_samples = confidences.len();
32
33    for i in 0..n_bins {
34        let lower_bound = bin_boundaries[i];
35        let upper_bound = bin_boundaries[i + 1];
36
37        let mask: Vec<bool> = confidences
38            .iter()
39            .map(|&conf| conf > lower_bound && conf <= upper_bound)
40            .collect();
41
42        let bin_size = mask.iter().filter(|&&m| m).count();
43        if bin_size > 0 {
44            let bin_accuracy: f64 = mask
45                .iter()
46                .zip(accuracies.iter())
47                .filter(|(&m, _)| m)
48                .map(|(_, &acc)| acc)
49                .sum::<f64>()
50                / bin_size as f64;
51
52            let bin_confidence: f64 = mask
53                .iter()
54                .zip(confidences.iter())
55                .filter(|(&m, _)| m)
56                .map(|(_, &conf)| conf)
57                .sum::<f64>()
58                / bin_size as f64;
59
60            calibration_error +=
61                (bin_size as f64 / n_samples as f64) * (bin_confidence - bin_accuracy).abs();
62        }
63    }
64
65    calibration_error
66}
67
68pub fn platt_scaling(
69    scores: &Array1<f64>,
70    labels: &Array1<f64>,
71) -> Result<(f64, f64), Box<dyn std::error::Error>> {
72    let mut a = 1.0;
73    let mut b = 0.0;
74
75    for _ in 0..100 {
76        let mut gradient_a = 0.0;
77        let mut gradient_b = 0.0;
78        let mut hessian_aa = 0.0;
79        let mut hessian_ab = 0.0;
80        let mut hessian_bb = 0.0;
81
82        for (&score, &label) in scores.iter().zip(labels.iter()) {
83            let z = a * score + b;
84            let p = 1.0 / (1.0 + (-z).exp());
85
86            let error = p - label;
87            gradient_a += error * score;
88            gradient_b += error;
89
90            let weight = p * (1.0 - p);
91            hessian_aa += weight * score * score;
92            hessian_ab += weight * score;
93            hessian_bb += weight;
94        }
95
96        let det = hessian_aa * hessian_bb - hessian_ab * hessian_ab;
97        if det.abs() < 1e-10 {
98            break;
99        }
100
101        let delta_a = -(hessian_bb * gradient_a - hessian_ab * gradient_b) / det;
102        let delta_b = -(hessian_aa * gradient_b - hessian_ab * gradient_a) / det;
103
104        a += delta_a;
105        b += delta_b;
106
107        if (delta_a.abs() + delta_b.abs()) < 1e-6 {
108            break;
109        }
110    }
111
112    Ok((a, b))
113}