tensorlogic_train/metrics/calibration.rs

use crate::{TrainError, TrainResult};

use scirs2_core::ndarray::{ArrayView, Ix2};

use super::Metric;
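
/// Expected Calibration Error (ECE).
///
/// Predictions are bucketed into `num_bins` equal-width confidence bins, and
/// ECE is the sample-weighted average of the absolute gap between mean
/// confidence and accuracy within each bin:
///
/// `ECE = sum_m (|B_m| / n) * |acc(B_m) - conf(B_m)|`
///
/// A perfectly calibrated model scores 0. A minimal usage sketch follows; the
/// import paths are illustrative and assume the crate re-exports these items:
///
/// ```ignore
/// use tensorlogic_train::metrics::{ExpectedCalibrationError, Metric};
/// use scirs2_core::ndarray::array;
///
/// let ece = ExpectedCalibrationError::new(10);
/// let predictions = array![[0.8, 0.2], [0.3, 0.7]];
/// let targets = array![[1.0, 0.0], [0.0, 1.0]];
/// let score = ece.compute(&predictions.view(), &targets.view()).unwrap();
/// assert!((0.0..=1.0).contains(&score));
/// ```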
#[derive(Debug, Clone)]
pub struct ExpectedCalibrationError {
    /// Number of equal-width confidence bins used for bucketing.
    pub num_bins: usize,
}

impl Default for ExpectedCalibrationError {
    fn default() -> Self {
        Self { num_bins: 10 }
    }
}

impl ExpectedCalibrationError {
    /// Creates the metric with a custom number of confidence bins.
    pub fn new(num_bins: usize) -> Self {
        Self { num_bins }
    }
}

impl Metric for ExpectedCalibrationError {
    fn compute(
        &self,
        predictions: &ArrayView<f64, Ix2>,
        targets: &ArrayView<f64, Ix2>,
    ) -> TrainResult<f64> {
        if predictions.shape() != targets.shape() {
            return Err(TrainError::MetricsError(format!(
                "Shape mismatch: predictions {:?} vs targets {:?}",
                predictions.shape(),
                targets.shape()
            )));
        }

        let n_samples = predictions.nrows();
        if n_samples == 0 {
            return Ok(0.0);
        }

        // Per-bin accumulators: sample count, summed top confidence, and
        // number of correct predictions.
        let mut bin_counts = vec![0usize; self.num_bins];
        let mut bin_confidences = vec![0.0; self.num_bins];
        let mut bin_accuracies = vec![0.0; self.num_bins];

        for i in 0..n_samples {
            // Predicted class = argmax over the prediction row; the winning
            // probability is the model's confidence for this sample.
            let mut pred_class = 0;
            let mut max_confidence = predictions[[i, 0]];
            for j in 1..predictions.ncols() {
                if predictions[[i, j]] > max_confidence {
                    max_confidence = predictions[[i, j]];
                    pred_class = j;
                }
            }

            // True class = argmax over the (typically one-hot) target row.
            let mut true_class = 0;
            let mut max_target = targets[[i, 0]];
            for j in 1..targets.ncols() {
                if targets[[i, j]] > max_target {
                    max_target = targets[[i, j]];
                    true_class = j;
                }
            }

            // Map the confidence to a bin; `min` clamps a confidence of
            // exactly 1.0 into the top bin.
            let bin_idx =
                ((max_confidence * self.num_bins as f64).floor() as usize).min(self.num_bins - 1);

            bin_counts[bin_idx] += 1;
            bin_confidences[bin_idx] += max_confidence;
            if pred_class == true_class {
                bin_accuracies[bin_idx] += 1.0;
            }
        }

        // ECE = sum over non-empty bins of (bin weight) * |avg confidence - accuracy|.
        let mut ece = 0.0;
        for i in 0..self.num_bins {
            if bin_counts[i] > 0 {
                let bin_confidence = bin_confidences[i] / bin_counts[i] as f64;
                let bin_accuracy = bin_accuracies[i] / bin_counts[i] as f64;
                let weight = bin_counts[i] as f64 / n_samples as f64;

                ece += weight * (bin_confidence - bin_accuracy).abs();
            }
        }

        Ok(ece)
    }

    fn name(&self) -> &str {
        "expected_calibration_error"
    }
}
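
/// Maximum Calibration Error (MCE).
///
/// Uses the same confidence binning as [`ExpectedCalibrationError`] but
/// reports the worst-case gap, `max_m |acc(B_m) - conf(B_m)|`, over the
/// non-empty bins instead of the weighted average. Because a maximum
/// dominates any weighted average of the same per-bin errors, MCE is always
/// at least as large as ECE on identical inputs. Usage mirrors the ECE
/// sketch above (illustrative import paths assumed):
///
/// ```ignore
/// use tensorlogic_train::metrics::{MaximumCalibrationError, Metric};
/// use scirs2_core::ndarray::array;
///
/// let mce = MaximumCalibrationError::new(10);
/// let predictions = array![[0.8, 0.2], [0.3, 0.7]];
/// let targets = array![[1.0, 0.0], [0.0, 1.0]];
/// let worst_gap = mce.compute(&predictions.view(), &targets.view()).unwrap();
/// assert!((0.0..=1.0).contains(&worst_gap));
/// ```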
#[derive(Debug, Clone)]
pub struct MaximumCalibrationError {
    /// Number of equal-width confidence bins used for bucketing.
    pub num_bins: usize,
}

impl Default for MaximumCalibrationError {
    fn default() -> Self {
        Self { num_bins: 10 }
    }
}

impl MaximumCalibrationError {
    /// Creates the metric with a custom number of confidence bins.
    pub fn new(num_bins: usize) -> Self {
        Self { num_bins }
    }
}

impl Metric for MaximumCalibrationError {
    fn compute(
        &self,
        predictions: &ArrayView<f64, Ix2>,
        targets: &ArrayView<f64, Ix2>,
    ) -> TrainResult<f64> {
        if predictions.shape() != targets.shape() {
            return Err(TrainError::MetricsError(format!(
                "Shape mismatch: predictions {:?} vs targets {:?}",
                predictions.shape(),
                targets.shape()
            )));
        }

        let n_samples = predictions.nrows();
        if n_samples == 0 {
            return Ok(0.0);
        }

        // Per-bin accumulators, identical to the ECE implementation.
        let mut bin_counts = vec![0usize; self.num_bins];
        let mut bin_confidences = vec![0.0; self.num_bins];
        let mut bin_accuracies = vec![0.0; self.num_bins];

        for i in 0..n_samples {
            // Predicted class and confidence via argmax over the row.
            let mut pred_class = 0;
            let mut max_confidence = predictions[[i, 0]];
            for j in 1..predictions.ncols() {
                if predictions[[i, j]] > max_confidence {
                    max_confidence = predictions[[i, j]];
                    pred_class = j;
                }
            }

            // True class via argmax over the (typically one-hot) target row.
            let mut true_class = 0;
            let mut max_target = targets[[i, 0]];
            for j in 1..targets.ncols() {
                if targets[[i, j]] > max_target {
                    max_target = targets[[i, j]];
                    true_class = j;
                }
            }

            // Bin the confidence, clamping exactly 1.0 into the top bin.
            let bin_idx =
                ((max_confidence * self.num_bins as f64).floor() as usize).min(self.num_bins - 1);

            bin_counts[bin_idx] += 1;
            bin_confidences[bin_idx] += max_confidence;
            if pred_class == true_class {
                bin_accuracies[bin_idx] += 1.0;
            }
        }

        // MCE = max over non-empty bins of |avg confidence - accuracy|.
        let mut mce: f64 = 0.0;
        for i in 0..self.num_bins {
            if bin_counts[i] > 0 {
                let bin_confidence = bin_confidences[i] / bin_counts[i] as f64;
                let bin_accuracy = bin_accuracies[i] / bin_counts[i] as f64;
                let calibration_error = (bin_confidence - bin_accuracy).abs();

                mce = mce.max(calibration_error);
            }
        }

        Ok(mce)
    }

    fn name(&self) -> &str {
        "maximum_calibration_error"
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    #[test]
    fn test_expected_calibration_error_perfect() {
        let metric = ExpectedCalibrationError::default();

        // Confident and always correct: confidence (0.95) closely tracks
        // accuracy (1.0), so the calibration error should be small.
        let predictions = array![[0.95, 0.05], [0.05, 0.95], [0.95, 0.05], [0.05, 0.95]];
        let targets = array![[1.0, 0.0], [0.0, 1.0], [1.0, 0.0], [0.0, 1.0]];

        let ece = metric
            .compute(&predictions.view(), &targets.view())
            .unwrap();

        assert!(ece < 0.1);
    }

    #[test]
    fn test_expected_calibration_error_poor() {
        let metric = ExpectedCalibrationError::default();

        // Confident (0.9) but always wrong: confidence and accuracy diverge,
        // so the calibration error should be large.
        let predictions = array![[0.9, 0.1], [0.9, 0.1], [0.9, 0.1], [0.9, 0.1]];
        let targets = array![[0.0, 1.0], [0.0, 1.0], [0.0, 1.0], [0.0, 1.0]];

        let ece = metric
            .compute(&predictions.view(), &targets.view())
            .unwrap();

        assert!(ece > 0.5);
    }

    #[test]
    fn test_expected_calibration_error_custom_bins() {
        let metric = ExpectedCalibrationError::new(5);

        let predictions = array![[0.9, 0.1], [0.8, 0.2], [0.7, 0.3], [0.6, 0.4]];
        let targets = array![[1.0, 0.0], [1.0, 0.0], [1.0, 0.0], [1.0, 0.0]];

        let ece = metric
            .compute(&predictions.view(), &targets.view())
            .unwrap();

        assert!((0.0..=1.0).contains(&ece));
    }

    #[test]
    fn test_maximum_calibration_error_perfect() {
        let metric = MaximumCalibrationError::default();

        let predictions = array![[0.95, 0.05], [0.05, 0.95], [0.95, 0.05], [0.05, 0.95]];
        let targets = array![[1.0, 0.0], [0.0, 1.0], [1.0, 0.0], [0.0, 1.0]];

        let mce = metric
            .compute(&predictions.view(), &targets.view())
            .unwrap();

        assert!(mce < 0.15);
    }

    #[test]
    fn test_maximum_calibration_error_poor() {
        let metric = MaximumCalibrationError::default();

        let predictions = array![[0.9, 0.1], [0.9, 0.1], [0.9, 0.1], [0.9, 0.1]];
        let targets = array![[0.0, 1.0], [0.0, 1.0], [0.0, 1.0], [0.0, 1.0]];

        let mce = metric
            .compute(&predictions.view(), &targets.view())
            .unwrap();

        assert!(mce > 0.5);
    }

    #[test]
    fn test_calibration_metrics_empty() {
        let ece_metric = ExpectedCalibrationError::default();
        let mce_metric = MaximumCalibrationError::default();

        // Zero-sample inputs are defined to have zero calibration error.
        use scirs2_core::ndarray::Array;
        let empty_predictions: Array<f64, _> = Array::zeros((0, 2));
        let empty_targets: Array<f64, _> = Array::zeros((0, 2));

        let ece = ece_metric
            .compute(&empty_predictions.view(), &empty_targets.view())
            .unwrap();
        let mce = mce_metric
            .compute(&empty_predictions.view(), &empty_targets.view())
            .unwrap();

        assert_eq!(ece, 0.0);
        assert_eq!(mce, 0.0);
    }

    #[test]
    fn test_calibration_metrics_shape_mismatch() {
        let metric = ExpectedCalibrationError::default();

        let predictions = array![[0.9, 0.1], [0.8, 0.2]];
        let targets = array![[1.0, 0.0, 0.0]];

        let result = metric.compute(&predictions.view(), &targets.view());
        assert!(result.is_err());
    }
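
    // Property check: MCE must upper-bound ECE on the same inputs, because
    // the maximum per-bin error dominates any weighted average of those
    // errors. Input values here are arbitrary illustrative probabilities.
    #[test]
    fn test_mce_upper_bounds_ece() {
        let ece_metric = ExpectedCalibrationError::default();
        let mce_metric = MaximumCalibrationError::default();

        let predictions = array![[0.9, 0.1], [0.6, 0.4], [0.3, 0.7], [0.55, 0.45]];
        let targets = array![[1.0, 0.0], [0.0, 1.0], [0.0, 1.0], [1.0, 0.0]];

        let ece = ece_metric
            .compute(&predictions.view(), &targets.view())
            .unwrap();
        let mce = mce_metric
            .compute(&predictions.view(), &targets.view())
            .unwrap();

        assert!(mce >= ece);
    }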
}