// yscv_eval/metrics.rs
1//! Advanced evaluation metrics: top-k accuracy, ROC/AUC, IoU, SSIM, PSNR.
2
3use crate::EvalError;
4
5/// Top-k accuracy: fraction of samples where the correct label is in the top-k predictions.
6///
7/// `scores`: `[N, C]` matrix (N samples, C classes) — raw scores or probabilities.
8/// `targets`: `[N]` — true class indices.
9pub fn top_k_accuracy(
10    scores: &[f32],
11    num_classes: usize,
12    targets: &[usize],
13    k: usize,
14) -> Result<f32, EvalError> {
15    if scores.is_empty() || num_classes == 0 {
16        return Ok(0.0);
17    }
18    let n = scores.len() / num_classes;
19    if n != targets.len() {
20        return Err(EvalError::CountLengthMismatch {
21            ground_truth: targets.len(),
22            predictions: n,
23        });
24    }
25
26    let mut correct = 0;
27    for i in 0..n {
28        let row = &scores[i * num_classes..(i + 1) * num_classes];
29        let mut indices: Vec<usize> = (0..num_classes).collect();
30        indices.sort_unstable_by(|&a, &b| {
31            row[b]
32                .partial_cmp(&row[a])
33                .unwrap_or(std::cmp::Ordering::Equal)
34        });
35        if indices[..k.min(num_classes)].contains(&targets[i]) {
36            correct += 1;
37        }
38    }
39    Ok(correct as f32 / n as f32)
40}
41
42/// Compute ROC curve from binary classification scores and labels.
43///
44/// Returns `(fpr, tpr, thresholds)` — sorted by decreasing threshold.
45pub fn roc_curve(
46    scores: &[f32],
47    labels: &[bool],
48) -> Result<(Vec<f32>, Vec<f32>, Vec<f32>), EvalError> {
49    if scores.len() != labels.len() {
50        return Err(EvalError::CountLengthMismatch {
51            ground_truth: labels.len(),
52            predictions: scores.len(),
53        });
54    }
55
56    let n = scores.len();
57    let total_pos = labels.iter().filter(|&&l| l).count() as f32;
58    let total_neg = n as f32 - total_pos;
59
60    if total_pos == 0.0 || total_neg == 0.0 {
61        return Ok((
62            vec![0.0, 1.0],
63            vec![0.0, 1.0],
64            vec![f32::INFINITY, f32::NEG_INFINITY],
65        ));
66    }
67
68    // Sort by score descending
69    let mut indices: Vec<usize> = (0..n).collect();
70    indices.sort_unstable_by(|&a, &b| {
71        scores[b]
72            .partial_cmp(&scores[a])
73            .unwrap_or(std::cmp::Ordering::Equal)
74    });
75
76    let mut fpr_list = vec![0.0f32];
77    let mut tpr_list = vec![0.0f32];
78    let mut thresholds = vec![f32::INFINITY];
79
80    let mut tp = 0.0f32;
81    let mut fp = 0.0f32;
82
83    for &i in &indices {
84        if labels[i] {
85            tp += 1.0;
86        } else {
87            fp += 1.0;
88        }
89        fpr_list.push(fp / total_neg);
90        tpr_list.push(tp / total_pos);
91        thresholds.push(scores[i]);
92    }
93
94    Ok((fpr_list, tpr_list, thresholds))
95}
96
97/// Area under the curve using the trapezoidal rule.
98///
99/// `x` and `y` must have the same length and be sorted by x.
100pub fn auc(x: &[f32], y: &[f32]) -> Result<f32, EvalError> {
101    if x.len() != y.len() {
102        return Err(EvalError::CountLengthMismatch {
103            ground_truth: x.len(),
104            predictions: y.len(),
105        });
106    }
107    if x.len() < 2 {
108        return Ok(0.0);
109    }
110
111    let mut area = 0.0f32;
112    for i in 1..x.len() {
113        area += (x[i] - x[i - 1]) * (y[i] + y[i - 1]) / 2.0;
114    }
115    Ok(area.abs())
116}
117
118/// Mean Intersection over Union for semantic segmentation.
119///
120/// `predictions` and `targets` are flat label maps (same length), `num_classes` classes.
121pub fn mean_iou(
122    predictions: &[usize],
123    targets: &[usize],
124    num_classes: usize,
125) -> Result<f32, EvalError> {
126    if predictions.len() != targets.len() {
127        return Err(EvalError::CountLengthMismatch {
128            ground_truth: targets.len(),
129            predictions: predictions.len(),
130        });
131    }
132
133    let mut intersection = vec![0usize; num_classes];
134    let mut union = vec![0usize; num_classes];
135
136    for (&p, &t) in predictions.iter().zip(targets.iter()) {
137        if t < num_classes {
138            if p == t {
139                intersection[t] += 1;
140            }
141            union[t] += 1;
142        }
143        if p < num_classes && p != t {
144            union[p] += 1;
145        }
146    }
147
148    let mut sum_iou = 0.0f32;
149    let mut valid_classes = 0;
150    for c in 0..num_classes {
151        if union[c] > 0 {
152            sum_iou += intersection[c] as f32 / union[c] as f32;
153            valid_classes += 1;
154        }
155    }
156
157    if valid_classes == 0 {
158        return Ok(0.0);
159    }
160    Ok(sum_iou / valid_classes as f32)
161}
162
163/// Per-class Dice coefficient: 2 * |pred ∩ target| / (|pred| + |target|).
164///
165/// Returns a `Vec` of Dice scores, one per class.  Classes that appear in
166/// neither predictions nor targets receive a score of 0.0.
/// Per-class Dice coefficient: 2 * |pred ∩ target| / (|pred| + |target|).
///
/// Returns a `Vec` of Dice scores, one per class.  Classes that appear in
/// neither predictions nor targets receive a score of 0.0.  Labels outside
/// `0..num_classes` are skipped.
pub fn dice_score(predictions: &[usize], targets: &[usize], num_classes: usize) -> Vec<f32> {
    let mut overlap = vec![0usize; num_classes]; // true positives
    let mut pred_only = vec![0usize; num_classes]; // false positives
    let mut target_only = vec![0usize; num_classes]; // false negatives

    for (&p, &t) in predictions.iter().zip(targets.iter()) {
        match (p == t, p < num_classes, t < num_classes) {
            (true, true, _) => overlap[p] += 1,
            (true, false, _) => {} // matching but out-of-range label: ignored
            (false, in_p, in_t) => {
                if in_p {
                    pred_only[p] += 1;
                }
                if in_t {
                    target_only[t] += 1;
                }
            }
        }
    }

    let mut out = Vec::with_capacity(num_classes);
    for c in 0..num_classes {
        let numerator = 2 * overlap[c];
        let denominator = numerator + pred_only[c] + target_only[c];
        out.push(if denominator == 0 {
            0.0 // class absent from both maps
        } else {
            numerator as f32 / denominator as f32
        });
    }
    out
}
198
199/// Per-class Intersection over Union: |pred ∩ target| / |pred ∪ target|.
200///
201/// Returns a `Vec` of IoU scores, one per class.  Classes that appear in
202/// neither predictions nor targets receive a score of 0.0.
/// Per-class Intersection over Union: |pred ∩ target| / |pred ∪ target|.
///
/// Returns a `Vec` of IoU scores, one per class.  Classes that appear in
/// neither predictions nor targets receive a score of 0.0.  Labels outside
/// `0..num_classes` are skipped.
pub fn per_class_iou(predictions: &[usize], targets: &[usize], num_classes: usize) -> Vec<f32> {
    // counts[c] = (true positives, false positives, false negatives)
    let mut counts = vec![(0usize, 0usize, 0usize); num_classes];

    for (&pred, &gt) in predictions.iter().zip(targets.iter()) {
        if pred == gt {
            // get_mut doubles as the bounds check for out-of-range labels.
            if let Some(entry) = counts.get_mut(pred) {
                entry.0 += 1;
            }
        } else {
            if let Some(entry) = counts.get_mut(pred) {
                entry.1 += 1;
            }
            if let Some(entry) = counts.get_mut(gt) {
                entry.2 += 1;
            }
        }
    }

    counts
        .into_iter()
        .map(|(tp, fp, fn_)| {
            let union = tp + fp + fn_;
            if union == 0 {
                0.0 // class absent from both maps
            } else {
                tp as f32 / union as f32
            }
        })
        .collect()
}
234
235/// Structural Similarity Index (SSIM) between two grayscale images.
236///
237/// Both inputs are flat f32 slices of the same length (H*W).
238pub fn ssim(img1: &[f32], img2: &[f32]) -> Result<f32, EvalError> {
239    if img1.len() != img2.len() {
240        return Err(EvalError::CountLengthMismatch {
241            ground_truth: img1.len(),
242            predictions: img2.len(),
243        });
244    }
245    let n = img1.len() as f32;
246    if n == 0.0 {
247        return Ok(1.0);
248    }
249
250    let c1 = (0.01f32 * 1.0).powi(2); // L=1.0 for [0,1] range
251    let c2 = (0.03f32 * 1.0).powi(2);
252
253    let mu1: f32 = img1.iter().sum::<f32>() / n;
254    let mu2: f32 = img2.iter().sum::<f32>() / n;
255
256    let sigma1_sq: f32 = img1.iter().map(|&v| (v - mu1).powi(2)).sum::<f32>() / n;
257    let sigma2_sq: f32 = img2.iter().map(|&v| (v - mu2).powi(2)).sum::<f32>() / n;
258    let sigma12: f32 = img1
259        .iter()
260        .zip(img2.iter())
261        .map(|(&a, &b)| (a - mu1) * (b - mu2))
262        .sum::<f32>()
263        / n;
264
265    let numerator = (2.0 * mu1 * mu2 + c1) * (2.0 * sigma12 + c2);
266    let denominator = (mu1.powi(2) + mu2.powi(2) + c1) * (sigma1_sq + sigma2_sq + c2);
267
268    Ok(numerator / denominator)
269}
270
271/// Peak Signal-to-Noise Ratio between two images.
272///
273/// Both inputs are flat f32 slices (same length). `max_val` is the maximum pixel value.
274pub fn psnr(img1: &[f32], img2: &[f32], max_val: f32) -> Result<f32, EvalError> {
275    if img1.len() != img2.len() {
276        return Err(EvalError::CountLengthMismatch {
277            ground_truth: img1.len(),
278            predictions: img2.len(),
279        });
280    }
281    let mse: f32 = img1
282        .iter()
283        .zip(img2.iter())
284        .map(|(&a, &b)| (a - b).powi(2))
285        .sum::<f32>()
286        / img1.len() as f32;
287
288    if mse == 0.0 {
289        return Ok(f32::INFINITY);
290    }
291    Ok(10.0 * (max_val.powi(2) / mse).log10())
292}
293
#[cfg(test)]
mod tests {
    // Unit tests for the metrics in this module; expected values are worked
    // out by hand in the per-test comments.
    use super::*;

    #[test]
    fn test_top_k_accuracy() {
        // 2 samples, 3 classes
        // sample 0: scores=[0.1, 0.8, 0.5], target=2 → top-1=class1 → wrong, top-2=[1,2] → correct
        // sample 1: scores=[0.7, 0.2, 0.1], target=0 → top-1=class0 → correct
        let scores = vec![0.1, 0.8, 0.5, 0.7, 0.2, 0.1];
        let targets = vec![2, 0];
        let acc = top_k_accuracy(&scores, 3, &targets, 1).unwrap();
        assert!((acc - 0.5).abs() < 1e-6); // only second sample correct at k=1
        let acc_k2 = top_k_accuracy(&scores, 3, &targets, 2).unwrap();
        assert!((acc_k2 - 1.0).abs() < 1e-6); // both correct at k=2
    }

    #[test]
    fn test_roc_curve_and_auc() {
        // All positives rank above all negatives (distinct scores),
        // so the curve hugs the top-left corner and AUC is 1.0.
        let scores = vec![0.9, 0.8, 0.4, 0.3, 0.1];
        let labels = vec![true, true, false, false, false];
        let (fpr, tpr, _) = roc_curve(&scores, &labels).unwrap();
        let area = auc(&fpr, &tpr).unwrap();
        assert!(area > 0.9, "AUC should be high: {area}");
    }

    #[test]
    fn test_auc_perfect() {
        // Ideal ROC: straight up to (0, 1), then across to (1, 1) → area 1.
        let fpr = vec![0.0, 0.0, 1.0];
        let tpr = vec![0.0, 1.0, 1.0];
        let area = auc(&fpr, &tpr).unwrap();
        assert!((area - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_mean_iou() {
        // Perfect prediction → IoU of 1.0 for every class, mean 1.0.
        let preds = vec![0, 0, 1, 1, 2, 2];
        let targets = vec![0, 0, 1, 1, 2, 2];
        let miou = mean_iou(&preds, &targets, 3).unwrap();
        assert!((miou - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_mean_iou_partial() {
        let preds = vec![0, 1, 1, 0];
        let targets = vec![0, 0, 1, 1];
        let miou = mean_iou(&preds, &targets, 2).unwrap();
        // class 0: intersection=1, union=3 → 1/3
        // class 1: intersection=1, union=3 → 1/3
        // mean = 1/3
        assert!((miou - 1.0 / 3.0).abs() < 0.01);
    }

    #[test]
    fn test_ssim_identical() {
        // Identical inputs → SSIM 1.0 (up to stabilisation-constant rounding).
        let img = vec![0.5f32; 100];
        let val = ssim(&img, &img).unwrap();
        assert!((val - 1.0).abs() < 1e-4);
    }

    #[test]
    fn test_psnr_identical() {
        // Zero MSE → PSNR is +infinity.
        let img = vec![0.5f32; 100];
        let val = psnr(&img, &img, 1.0).unwrap();
        assert!(val.is_infinite() && val > 0.0);
    }

    #[test]
    fn dice_score_perfect() {
        // Perfect prediction → Dice 1.0 for every class.
        let preds = vec![0, 0, 1, 1, 2, 2];
        let targets = vec![0, 0, 1, 1, 2, 2];
        let scores = dice_score(&preds, &targets, 3);
        for &s in &scores {
            assert!((s - 1.0).abs() < 1e-6, "expected 1.0, got {s}");
        }
    }

    #[test]
    fn dice_score_partial() {
        // preds:   [0, 1, 1, 0]
        // targets: [0, 0, 1, 1]
        // class 0: tp=1, fp=1, fn=1 → dice = 2/4 = 0.5
        // class 1: tp=1, fp=1, fn=1 → dice = 2/4 = 0.5
        let preds = vec![0, 1, 1, 0];
        let targets = vec![0, 0, 1, 1];
        let scores = dice_score(&preds, &targets, 2);
        assert!((scores[0] - 0.5).abs() < 1e-6, "class 0: {}", scores[0]);
        assert!((scores[1] - 0.5).abs() < 1e-6, "class 1: {}", scores[1]);
    }

    #[test]
    fn per_class_iou_known_values() {
        // preds:   [0, 1, 1, 0]
        // targets: [0, 0, 1, 1]
        // class 0: tp=1, fp=1, fn=1 → iou = 1/3
        // class 1: tp=1, fp=1, fn=1 → iou = 1/3
        let preds = vec![0, 1, 1, 0];
        let targets = vec![0, 0, 1, 1];
        let ious = per_class_iou(&preds, &targets, 2);
        assert!((ious[0] - 1.0 / 3.0).abs() < 1e-6, "class 0: {}", ious[0]);
        assert!((ious[1] - 1.0 / 3.0).abs() < 1e-6, "class 1: {}", ious[1]);
    }

    #[test]
    fn per_class_iou_no_overlap() {
        // preds are all class 0, targets are all class 1
        let preds = vec![0, 0, 0, 0];
        let targets = vec![1, 1, 1, 1];
        let ious = per_class_iou(&preds, &targets, 2);
        assert!((ious[0]).abs() < 1e-6, "class 0 should be 0: {}", ious[0]);
        assert!((ious[1]).abs() < 1e-6, "class 1 should be 0: {}", ious[1]);
    }

    #[test]
    fn test_psnr_different() {
        // Constant difference of 1.0 with max_val 1.0 → MSE = 1 → PSNR = 0 dB.
        let img1 = vec![0.0f32; 100];
        let img2 = vec![1.0f32; 100];
        let val = psnr(&img1, &img2, 1.0).unwrap();
        assert!((val - 0.0).abs() < 1e-6); // MSE=1, PSNR=0
    }
}