tensorlogic_train/metrics/calibration.rs

//! Calibration metrics.
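//!
//! Provides [`ExpectedCalibrationError`] (ECE) and [`MaximumCalibrationError`]
//! (MCE), both computed over equal-width confidence bins.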

use crate::{TrainError, TrainResult};
use scirs2_core::ndarray::{ArrayView, Ix2};

use super::Metric;

/// Expected Calibration Error (ECE) metric.
///
/// Measures the difference between predicted probabilities and actual accuracy.
/// ECE divides predictions into bins and computes the average difference between
/// confidence and accuracy across bins, weighted by bin frequency.
///
/// Lower ECE indicates better calibration.
///
/// Reference: Guo et al. "On Calibration of Modern Neural Networks" (ICML 2017)
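///
/// Concretely, with `M` equal-width bins `B_1..B_M` over the confidence range
/// `[0, 1]` and `n` samples:
///
/// ```text
/// ECE = sum_m (|B_m| / n) * |acc(B_m) - conf(B_m)|
/// ```
///
/// # Example
///
/// A minimal usage sketch, marked `ignore` because the exact public import
/// path is an assumption about this crate's layout:
///
/// ```ignore
/// use tensorlogic_train::metrics::{ExpectedCalibrationError, Metric};
/// use scirs2_core::ndarray::array;
///
/// let metric = ExpectedCalibrationError::new(15);
/// // Rows are samples, columns are per-class probabilities.
/// let predictions = array![[0.9, 0.1], [0.2, 0.8]];
/// let targets = array![[1.0, 0.0], [0.0, 1.0]];
/// let ece = metric.compute(&predictions.view(), &targets.view()).unwrap();
/// assert!((0.0..=1.0).contains(&ece));
/// ```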
#[derive(Debug, Clone)]
pub struct ExpectedCalibrationError {
    /// Number of bins for calibration
    pub num_bins: usize,
}

impl Default for ExpectedCalibrationError {
    fn default() -> Self {
        Self { num_bins: 10 }
    }
}

impl ExpectedCalibrationError {
    /// Create with custom number of bins.
    pub fn new(num_bins: usize) -> Self {
        Self { num_bins }
    }
}

impl Metric for ExpectedCalibrationError {
    fn compute(
        &self,
        predictions: &ArrayView<f64, Ix2>,
        targets: &ArrayView<f64, Ix2>,
    ) -> TrainResult<f64> {
        if predictions.shape() != targets.shape() {
            return Err(TrainError::MetricsError(format!(
                "Shape mismatch: predictions {:?} vs targets {:?}",
                predictions.shape(),
                targets.shape()
            )));
        }

        let n_samples = predictions.nrows();
        if n_samples == 0 {
            return Ok(0.0);
        }

        // Initialize bins
        let mut bin_counts = vec![0usize; self.num_bins];
        let mut bin_confidences = vec![0.0; self.num_bins];
        let mut bin_accuracies = vec![0.0; self.num_bins];

        for i in 0..n_samples {
            // Get predicted class and confidence
            let mut pred_class = 0;
            let mut max_confidence = predictions[[i, 0]];
            for j in 1..predictions.ncols() {
                if predictions[[i, j]] > max_confidence {
                    max_confidence = predictions[[i, j]];
                    pred_class = j;
                }
            }

            // Get true class
            let mut true_class = 0;
            let mut max_target = targets[[i, 0]];
            for j in 1..targets.ncols() {
                if targets[[i, j]] > max_target {
                    max_target = targets[[i, j]];
                    true_class = j;
                }
            }

            // Determine bin index; the `min` clamp keeps a confidence of
            // exactly 1.0 in the last bin.
            let bin_idx =
                ((max_confidence * self.num_bins as f64).floor() as usize).min(self.num_bins - 1);

            // Update bin statistics
            bin_counts[bin_idx] += 1;
            bin_confidences[bin_idx] += max_confidence;
            if pred_class == true_class {
                bin_accuracies[bin_idx] += 1.0;
            }
        }

        // Compute ECE as the bin-frequency-weighted average |confidence - accuracy|
        let mut ece = 0.0;
        for i in 0..self.num_bins {
            if bin_counts[i] > 0 {
                let bin_confidence = bin_confidences[i] / bin_counts[i] as f64;
                let bin_accuracy = bin_accuracies[i] / bin_counts[i] as f64;
                let weight = bin_counts[i] as f64 / n_samples as f64;

                ece += weight * (bin_confidence - bin_accuracy).abs();
            }
        }

        Ok(ece)
    }

    fn name(&self) -> &str {
        "expected_calibration_error"
    }
}

/// Maximum Calibration Error (MCE) metric.
///
/// Measures the worst-case calibration error across all bins.
/// MCE is the maximum absolute difference between confidence and accuracy
/// in any bin.
///
/// Lower MCE indicates better calibration.
///
/// Reference: Guo et al. "On Calibration of Modern Neural Networks" (ICML 2017)
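///
/// With the same binning as ECE:
///
/// ```text
/// MCE = max_m |acc(B_m) - conf(B_m)|
/// ```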
#[derive(Debug, Clone)]
pub struct MaximumCalibrationError {
    /// Number of bins for calibration
    pub num_bins: usize,
}

impl Default for MaximumCalibrationError {
    fn default() -> Self {
        Self { num_bins: 10 }
    }
}

impl MaximumCalibrationError {
    /// Create with custom number of bins.
    pub fn new(num_bins: usize) -> Self {
        Self { num_bins }
    }
}

impl Metric for MaximumCalibrationError {
    fn compute(
        &self,
        predictions: &ArrayView<f64, Ix2>,
        targets: &ArrayView<f64, Ix2>,
    ) -> TrainResult<f64> {
        if predictions.shape() != targets.shape() {
            return Err(TrainError::MetricsError(format!(
                "Shape mismatch: predictions {:?} vs targets {:?}",
                predictions.shape(),
                targets.shape()
            )));
        }

        let n_samples = predictions.nrows();
        if n_samples == 0 {
            return Ok(0.0);
        }

        // Initialize bins
        let mut bin_counts = vec![0usize; self.num_bins];
        let mut bin_confidences = vec![0.0; self.num_bins];
        let mut bin_accuracies = vec![0.0; self.num_bins];

        for i in 0..n_samples {
            // Get predicted class and confidence
            let mut pred_class = 0;
            let mut max_confidence = predictions[[i, 0]];
            for j in 1..predictions.ncols() {
                if predictions[[i, j]] > max_confidence {
                    max_confidence = predictions[[i, j]];
                    pred_class = j;
                }
            }

            // Get true class
            let mut true_class = 0;
            let mut max_target = targets[[i, 0]];
            for j in 1..targets.ncols() {
                if targets[[i, j]] > max_target {
                    max_target = targets[[i, j]];
                    true_class = j;
                }
            }

            // Determine bin index; the `min` clamp keeps a confidence of
            // exactly 1.0 in the last bin.
            let bin_idx =
                ((max_confidence * self.num_bins as f64).floor() as usize).min(self.num_bins - 1);

            // Update bin statistics
            bin_counts[bin_idx] += 1;
            bin_confidences[bin_idx] += max_confidence;
            if pred_class == true_class {
                bin_accuracies[bin_idx] += 1.0;
            }
        }

        // Compute MCE as the worst per-bin |confidence - accuracy|
        let mut mce: f64 = 0.0;
        for i in 0..self.num_bins {
            if bin_counts[i] > 0 {
                let bin_confidence = bin_confidences[i] / bin_counts[i] as f64;
                let bin_accuracy = bin_accuracies[i] / bin_counts[i] as f64;
                let calibration_error = (bin_confidence - bin_accuracy).abs();

                mce = mce.max(calibration_error);
            }
        }

        Ok(mce)
    }

    fn name(&self) -> &str {
        "maximum_calibration_error"
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    #[test]
    fn test_expected_calibration_error_perfect() {
        let metric = ExpectedCalibrationError::default();

        // Near-perfect calibration: all predictions at 95% confidence and all
        // correct, so each occupied bin's |confidence - accuracy| gap is 0.05
        let predictions = array![[0.95, 0.05], [0.05, 0.95], [0.95, 0.05], [0.05, 0.95]];
        let targets = array![[1.0, 0.0], [0.0, 1.0], [1.0, 0.0], [0.0, 1.0]];

        let ece = metric
            .compute(&predictions.view(), &targets.view())
            .unwrap();

        // Should be very small for well-calibrated predictions
        assert!(ece < 0.1);
    }

    #[test]
    fn test_expected_calibration_error_poor() {
        let metric = ExpectedCalibrationError::default();

        // Poor calibration: high confidence but wrong predictions
        let predictions = array![[0.9, 0.1], [0.9, 0.1], [0.9, 0.1], [0.9, 0.1]];
        let targets = array![[0.0, 1.0], [0.0, 1.0], [0.0, 1.0], [0.0, 1.0]];

        let ece = metric
            .compute(&predictions.view(), &targets.view())
            .unwrap();

        // Should be high for poorly calibrated predictions
        assert!(ece > 0.5);
    }

    #[test]
    fn test_expected_calibration_error_custom_bins() {
        let metric = ExpectedCalibrationError::new(5); // Use 5 bins instead of 10

        let predictions = array![[0.9, 0.1], [0.8, 0.2], [0.7, 0.3], [0.6, 0.4]];
        let targets = array![[1.0, 0.0], [1.0, 0.0], [1.0, 0.0], [1.0, 0.0]];

        let ece = metric
            .compute(&predictions.view(), &targets.view())
            .unwrap();

        assert!((0.0..=1.0).contains(&ece));
    }

    #[test]
    fn test_maximum_calibration_error_perfect() {
        let metric = MaximumCalibrationError::default();

        // Near-perfect calibration: 95% confidence, all correct
        let predictions = array![[0.95, 0.05], [0.05, 0.95], [0.95, 0.05], [0.05, 0.95]];
        let targets = array![[1.0, 0.0], [0.0, 1.0], [1.0, 0.0], [0.0, 1.0]];

        let mce = metric
            .compute(&predictions.view(), &targets.view())
            .unwrap();

        // Should be small for well-calibrated predictions
        assert!(mce < 0.15);
    }

    #[test]
    fn test_maximum_calibration_error_poor() {
        let metric = MaximumCalibrationError::default();

        // One bin with very poor calibration
        let predictions = array![[0.9, 0.1], [0.9, 0.1], [0.9, 0.1], [0.9, 0.1]];
        let targets = array![[0.0, 1.0], [0.0, 1.0], [0.0, 1.0], [0.0, 1.0]];

        let mce = metric
            .compute(&predictions.view(), &targets.view())
            .unwrap();

        // MCE should capture the worst bin
        assert!(mce > 0.5);
    }

    #[test]
    fn test_calibration_metrics_empty() {
        let ece_metric = ExpectedCalibrationError::default();
        let mce_metric = MaximumCalibrationError::default();

        use scirs2_core::ndarray::Array;
        let empty_predictions: Array<f64, _> = Array::zeros((0, 2));
        let empty_targets: Array<f64, _> = Array::zeros((0, 2));

        let ece = ece_metric
            .compute(&empty_predictions.view(), &empty_targets.view())
            .unwrap();
        let mce = mce_metric
            .compute(&empty_predictions.view(), &empty_targets.view())
            .unwrap();

        assert_eq!(ece, 0.0);
        assert_eq!(mce, 0.0);
    }

    #[test]
    fn test_calibration_metrics_shape_mismatch() {
        let metric = ExpectedCalibrationError::default();

        let predictions = array![[0.9, 0.1], [0.8, 0.2]];
        let targets = array![[1.0, 0.0, 0.0]]; // Wrong shape

        let result = metric.compute(&predictions.view(), &targets.view());
        assert!(result.is_err());
    }
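
    // An illustrative property-style check added here: ECE is a bin-weighted
    // average of per-bin |confidence - accuracy| gaps while MCE is their
    // maximum, so ECE can never exceed MCE (up to floating-point rounding).
    #[test]
    fn test_ece_never_exceeds_mce() {
        let ece_metric = ExpectedCalibrationError::default();
        let mce_metric = MaximumCalibrationError::default();

        // Mixed correct/incorrect predictions spread across several bins
        let predictions = array![[0.9, 0.1], [0.6, 0.4], [0.3, 0.7], [0.55, 0.45]];
        let targets = array![[1.0, 0.0], [0.0, 1.0], [0.0, 1.0], [1.0, 0.0]];

        let ece = ece_metric
            .compute(&predictions.view(), &targets.view())
            .unwrap();
        let mce = mce_metric
            .compute(&predictions.view(), &targets.view())
            .unwrap();

        assert!(ece <= mce + 1e-12);
    }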
}