//! Probability-calibration utilities
//! (`sklears_model_selection/epistemic_uncertainty/calibration.rs`).

use scirs2_core::ndarray::Array1;
/// Selects a probability-calibration technique.
///
/// NOTE(review): the code that dispatches on this enum is not visible in this file section;
/// confirm per-variant behaviour against that consumer.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum CalibrationMethod {
    /// Platt scaling: a sigmoid `sigmoid(a * score + b)` fitted to raw scores
    /// (see [`platt_scaling`]).
    PlattScaling,
    /// Isotonic-regression calibration.
    IsotonicRegression,
    /// Temperature scaling: logits are divided by a scalar temperature
    /// (see [`apply_temperature_scaling`]).
    TemperatureScaling,
    /// Histogram binning with a fixed number of bins.
    HistogramBinning {
        /// Number of bins.
        n_bins: usize,
    },
    /// No calibration.
    None,
}

20pub fn apply_temperature_scaling(logits: &Array1<f64>, temperature: f64) -> Array1<f64> {
21 logits.mapv(|x| x / temperature)
22}
23
24pub fn compute_calibration_error(
25 confidences: &Array1<f64>,
26 accuracies: &Array1<f64>,
27 n_bins: usize,
28) -> f64 {
29 let bin_boundaries = Array1::<f64>::linspace(0.0, 1.0, n_bins + 1);
30 let mut calibration_error = 0.0;
31 let n_samples = confidences.len();
32
33 for i in 0..n_bins {
34 let lower_bound = bin_boundaries[i];
35 let upper_bound = bin_boundaries[i + 1];
36
37 let mask: Vec<bool> = confidences
38 .iter()
39 .map(|&conf| conf > lower_bound && conf <= upper_bound)
40 .collect();
41
42 let bin_size = mask.iter().filter(|&&m| m).count();
43 if bin_size > 0 {
44 let bin_accuracy: f64 = mask
45 .iter()
46 .zip(accuracies.iter())
47 .filter(|(&m, _)| m)
48 .map(|(_, &acc)| acc)
49 .sum::<f64>()
50 / bin_size as f64;
51
52 let bin_confidence: f64 = mask
53 .iter()
54 .zip(confidences.iter())
55 .filter(|(&m, _)| m)
56 .map(|(_, &conf)| conf)
57 .sum::<f64>()
58 / bin_size as f64;
59
60 calibration_error +=
61 (bin_size as f64 / n_samples as f64) * (bin_confidence - bin_accuracy).abs();
62 }
63 }
64
65 calibration_error
66}
67
68pub fn platt_scaling(
69 scores: &Array1<f64>,
70 labels: &Array1<f64>,
71) -> Result<(f64, f64), Box<dyn std::error::Error>> {
72 let mut a = 1.0;
73 let mut b = 0.0;
74
75 for _ in 0..100 {
76 let mut gradient_a = 0.0;
77 let mut gradient_b = 0.0;
78 let mut hessian_aa = 0.0;
79 let mut hessian_ab = 0.0;
80 let mut hessian_bb = 0.0;
81
82 for (&score, &label) in scores.iter().zip(labels.iter()) {
83 let z = a * score + b;
84 let p = 1.0 / (1.0 + (-z).exp());
85
86 let error = p - label;
87 gradient_a += error * score;
88 gradient_b += error;
89
90 let weight = p * (1.0 - p);
91 hessian_aa += weight * score * score;
92 hessian_ab += weight * score;
93 hessian_bb += weight;
94 }
95
96 let det = hessian_aa * hessian_bb - hessian_ab * hessian_ab;
97 if det.abs() < 1e-10 {
98 break;
99 }
100
101 let delta_a = -(hessian_bb * gradient_a - hessian_ab * gradient_b) / det;
102 let delta_b = -(hessian_aa * gradient_b - hessian_ab * gradient_a) / det;
103
104 a += delta_a;
105 b += delta_b;
106
107 if (delta_a.abs() + delta_b.abs()) < 1e-6 {
108 break;
109 }
110 }
111
112 Ok((a, b))
113}