use crate::{TrainError, TrainResult};

use scirs2_core::ndarray::{ArrayView, Ix2};

use super::Metric;

/// Confusion matrix for multi-class classification.
#[derive(Debug, Clone)]
pub struct ConfusionMatrix {
    /// Number of classes.
    pub(crate) num_classes: usize,
    /// Sample counts, indexed as `matrix[true_class][pred_class]`.
    pub(crate) matrix: Vec<Vec<usize>>,
}

impl ConfusionMatrix {
    /// Creates an empty `num_classes` x `num_classes` confusion matrix.
    pub fn new(num_classes: usize) -> Self {
        Self {
            num_classes,
            matrix: vec![vec![0; num_classes]; num_classes],
        }
    }

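    /// Builds a confusion matrix from predictions and targets given as 2-D
    /// arrays of per-class scores (one-hot targets work as well); each row's
    /// class is taken as the argmax of that row.
    ///
    /// A minimal usage sketch mirroring the tests below; the doctest is
    /// ignored because it assumes this crate's `scirs2_core` re-export:
    ///
    /// ```ignore
    /// use scirs2_core::ndarray::array;
    ///
    /// let predictions = array![[0.9, 0.1], [0.2, 0.8]];
    /// let targets = array![[1.0, 0.0], [0.0, 1.0]];
    /// let cm = ConfusionMatrix::compute(&predictions.view(), &targets.view())?;
    /// assert_eq!(cm.accuracy(), 1.0);
    /// ```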
    pub fn compute(
        predictions: &ArrayView<f64, Ix2>,
        targets: &ArrayView<f64, Ix2>,
    ) -> TrainResult<Self> {
        if predictions.shape() != targets.shape() {
            return Err(TrainError::MetricsError(format!(
                "Shape mismatch: predictions {:?} vs targets {:?}",
                predictions.shape(),
                targets.shape()
            )));
        }

        let num_classes = predictions.ncols();
        if num_classes == 0 {
            return Err(TrainError::MetricsError(
                "Cannot compute a confusion matrix with zero classes".to_string(),
            ));
        }
        let mut matrix = vec![vec![0; num_classes]; num_classes];

        for i in 0..predictions.nrows() {
            // Predicted class: argmax over the prediction row.
            let mut pred_class = 0;
            let mut max_pred = predictions[[i, 0]];
            for j in 1..num_classes {
                if predictions[[i, j]] > max_pred {
                    max_pred = predictions[[i, j]];
                    pred_class = j;
                }
            }

            // True class: argmax over the target row.
            let mut true_class = 0;
            let mut max_true = targets[[i, 0]];
            for j in 1..num_classes {
                if targets[[i, j]] > max_true {
                    max_true = targets[[i, j]];
                    true_class = j;
                }
            }

            matrix[true_class][pred_class] += 1;
        }

        Ok(Self {
            num_classes,
            matrix,
        })
    }

    /// Returns the raw count matrix, indexed as `[true_class][pred_class]`.
    pub fn matrix(&self) -> &[Vec<usize>] {
        &self.matrix
    }

    /// Returns the number of samples with the given true and predicted classes.
    pub fn get(&self, true_class: usize, pred_class: usize) -> usize {
        self.matrix[true_class][pred_class]
    }

    /// Per-class precision, `TP / (TP + FP)`; a class that is never
    /// predicted gets a precision of 0.0.
    pub fn precision_per_class(&self) -> Vec<f64> {
        let mut precisions = Vec::with_capacity(self.num_classes);

        for pred_class in 0..self.num_classes {
            let mut predicted_positive = 0;
            let mut true_positive = 0;

            for true_class in 0..self.num_classes {
                predicted_positive += self.matrix[true_class][pred_class];
                if true_class == pred_class {
                    true_positive += self.matrix[true_class][pred_class];
                }
            }

            let precision = if predicted_positive == 0 {
                0.0
            } else {
                true_positive as f64 / predicted_positive as f64
            };
            precisions.push(precision);
        }

        precisions
    }

    /// Per-class recall, `TP / (TP + FN)`; a class with no true samples
    /// gets a recall of 0.0.
    pub fn recall_per_class(&self) -> Vec<f64> {
        let mut recalls = Vec::with_capacity(self.num_classes);

        for true_class in 0..self.num_classes {
            let mut actual_positive = 0;
            let mut true_positive = 0;

            for pred_class in 0..self.num_classes {
                actual_positive += self.matrix[true_class][pred_class];
                if true_class == pred_class {
                    true_positive += self.matrix[true_class][pred_class];
                }
            }

            let recall = if actual_positive == 0 {
                0.0
            } else {
                true_positive as f64 / actual_positive as f64
            };
            recalls.push(recall);
        }

        recalls
    }

    /// Per-class F1 score, the harmonic mean of precision and recall;
    /// 0.0 when precision and recall are both zero.
    pub fn f1_per_class(&self) -> Vec<f64> {
        let precisions = self.precision_per_class();
        let recalls = self.recall_per_class();

        precisions
            .iter()
            .zip(recalls.iter())
            .map(|(p, r)| {
                if p + r == 0.0 {
                    0.0
                } else {
                    2.0 * p * r / (p + r)
                }
            })
            .collect()
    }

    /// Overall accuracy: the fraction of samples on the diagonal.
    /// Returns 0.0 if the matrix contains no samples.
    pub fn accuracy(&self) -> f64 {
        let mut correct = 0;
        let mut total = 0;

        for i in 0..self.num_classes {
            for j in 0..self.num_classes {
                total += self.matrix[i][j];
                if i == j {
                    correct += self.matrix[i][j];
                }
            }
        }

        if total == 0 {
            0.0
        } else {
            correct as f64 / total as f64
        }
    }

    /// Total number of samples recorded in the matrix.
    pub fn total_predictions(&self) -> usize {
        let mut total = 0;
        for i in 0..self.num_classes {
            for j in 0..self.num_classes {
                total += self.matrix[i][j];
            }
        }
        total
    }
}

impl std::fmt::Display for ConfusionMatrix {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        writeln!(f, "Confusion Matrix:")?;
        // Header row: pad past the "{:3}| " row labels (5 characters) so the
        // column indices line up with the entries below them.
        write!(f, "     ")?;

        for j in 0..self.num_classes {
            write!(f, "{:5}", j)?;
        }
        writeln!(f)?;

        // One row per true class, one column per predicted class.
        for i in 0..self.num_classes {
            write!(f, "{:3}| ", i)?;
            for j in 0..self.num_classes {
                write!(f, "{:5}", self.matrix[i][j])?;
            }
            writeln!(f)?;
        }

        Ok(())
    }
}

/// Receiver operating characteristic (ROC) curve for binary classification.
#[derive(Debug, Clone)]
pub struct RocCurve {
    /// False positive rate at each threshold.
    pub fpr: Vec<f64>,
    /// True positive rate at each threshold.
    pub tpr: Vec<f64>,
    /// Decision thresholds, in decreasing order (the first is `f64::INFINITY`).
    pub thresholds: Vec<f64>,
}

impl RocCurve {
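    /// Computes the ROC curve from raw scores and boolean labels by sweeping
    /// a decision threshold from the highest score downward; the curve starts
    /// at `(0, 0)` with threshold `f64::INFINITY`.
    ///
    /// A minimal usage sketch mirroring the tests below (doctest ignored):
    ///
    /// ```ignore
    /// let scores = vec![0.9, 0.8, 0.3, 0.1];
    /// let labels = vec![true, true, false, false];
    /// let roc = RocCurve::compute(&scores, &labels)?;
    /// assert!((roc.auc() - 1.0).abs() < 1e-6);
    /// ```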
    pub fn compute(predictions: &[f64], targets: &[bool]) -> TrainResult<Self> {
        if predictions.len() != targets.len() {
            return Err(TrainError::MetricsError(format!(
                "Length mismatch: predictions {} vs targets {}",
                predictions.len(),
                targets.len()
            )));
        }

        // Visit samples in order of decreasing score. Ties are not grouped,
        // so tied scores each contribute their own point on the curve.
        let mut indices: Vec<usize> = (0..predictions.len()).collect();
        indices.sort_by(|&a, &b| {
            predictions[b]
                .partial_cmp(&predictions[a])
                .unwrap_or(std::cmp::Ordering::Equal)
        });

        let mut fpr = Vec::new();
        let mut tpr = Vec::new();
        let mut thresholds = Vec::new();

        let num_positive = targets.iter().filter(|&&x| x).count();
        let num_negative = targets.len() - num_positive;

        let mut true_positives = 0;
        let mut false_positives = 0;

        // Start at (0, 0): a threshold above every score classifies nothing
        // as positive.
        fpr.push(0.0);
        tpr.push(0.0);
        thresholds.push(f64::INFINITY);

        for &idx in &indices {
            if targets[idx] {
                true_positives += 1;
            } else {
                false_positives += 1;
            }

            let fpr_val = if num_negative == 0 {
                0.0
            } else {
                false_positives as f64 / num_negative as f64
            };
            let tpr_val = if num_positive == 0 {
                0.0
            } else {
                true_positives as f64 / num_positive as f64
            };

            fpr.push(fpr_val);
            tpr.push(tpr_val);
            thresholds.push(predictions[idx]);
        }

        Ok(Self {
            fpr,
            tpr,
            thresholds,
        })
    }

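    /// Area under the ROC curve via the trapezoidal rule: each segment
    /// contributes `(fpr[i] - fpr[i-1]) * (tpr[i] + tpr[i-1]) / 2`, so a
    /// perfect curve through `(0, 1)` integrates to 1.0 and the chance
    /// diagonal to 0.5.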
    pub fn auc(&self) -> f64 {
        let mut auc = 0.0;

        for i in 1..self.fpr.len() {
            let width = self.fpr[i] - self.fpr[i - 1];
            let height = (self.tpr[i] + self.tpr[i - 1]) / 2.0;
            auc += width * height;
        }

        auc
    }
}

/// Precision, recall, F1 score, and support, reported per class.
#[derive(Debug, Clone)]
pub struct PerClassMetrics {
    /// Per-class precision.
    pub precision: Vec<f64>,
    /// Per-class recall.
    pub recall: Vec<f64>,
    /// Per-class F1 score.
    pub f1_score: Vec<f64>,
    /// Number of true samples per class.
    pub support: Vec<usize>,
}

impl PerClassMetrics {
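    /// Computes per-class precision, recall, F1 score, and support from
    /// score-encoded predictions and targets, via a `ConfusionMatrix`.
    ///
    /// A minimal usage sketch (doctest ignored; `predictions` and `targets`
    /// are assumed to be 2-D arrays as in the tests below):
    ///
    /// ```ignore
    /// let metrics = PerClassMetrics::compute(&predictions.view(), &targets.view())?;
    /// println!("{metrics}"); // per-class table plus macro averages
    /// ```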
    pub fn compute(
        predictions: &ArrayView<f64, Ix2>,
        targets: &ArrayView<f64, Ix2>,
    ) -> TrainResult<Self> {
        let confusion_matrix = ConfusionMatrix::compute(predictions, targets)?;

        let precision = confusion_matrix.precision_per_class();
        let recall = confusion_matrix.recall_per_class();
        let f1_score = confusion_matrix.f1_per_class();

        // Support: the number of true samples per class, again taking the
        // argmax of each target row as the true class.
        let num_classes = targets.ncols();
        let mut support = vec![0; num_classes];

        for i in 0..targets.nrows() {
            let mut true_class = 0;
            let mut max_true = targets[[i, 0]];
            for j in 1..num_classes {
                if targets[[i, j]] > max_true {
                    max_true = targets[[i, j]];
                    true_class = j;
                }
            }
            support[true_class] += 1;
        }

        Ok(Self {
            precision,
            recall,
            f1_score,
            support,
        })
    }
}

impl std::fmt::Display for PerClassMetrics {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        writeln!(f, "Per-Class Metrics:")?;
        writeln!(f, "Class Precision Recall F1-Score Support")?;
        writeln!(f, "----- --------- ------ -------- -------")?;

        for i in 0..self.precision.len() {
            writeln!(
                f,
                "{:5} {:9.4} {:6.4} {:8.4} {:7}",
                i, self.precision[i], self.recall[i], self.f1_score[i], self.support[i]
            )?;
        }

        // Macro averages: unweighted means over classes.
        let macro_precision: f64 = self.precision.iter().sum::<f64>() / self.precision.len() as f64;
        let macro_recall: f64 = self.recall.iter().sum::<f64>() / self.recall.len() as f64;
        let macro_f1: f64 = self.f1_score.iter().sum::<f64>() / self.f1_score.len() as f64;
        let total_support: usize = self.support.iter().sum();

        writeln!(f, "----- --------- ------ -------- -------")?;
        writeln!(
            f,
            "Macro {:9.4} {:6.4} {:8.4} {:7}",
            macro_precision, macro_recall, macro_f1, total_support
        )?;

        Ok(())
    }
}

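/// Matthews correlation coefficient (MCC).
///
/// For binary problems this is the classic definition,
///
/// ```text
/// MCC = (TP*TN - FP*FN) / sqrt((TP+FP)(TP+FN)(TN+FP)(TN+FN))
/// ```
///
/// and for more than two classes the multi-class generalization
/// (Gorodkin's R_K statistic) is used. Values range from -1 (total
/// disagreement) through 0 (chance level) to +1 (perfect prediction).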
#[derive(Debug, Clone, Default)]
pub struct MatthewsCorrelationCoefficient;

impl Metric for MatthewsCorrelationCoefficient {
    fn compute(
        &self,
        predictions: &ArrayView<f64, Ix2>,
        targets: &ArrayView<f64, Ix2>,
    ) -> TrainResult<f64> {
        let confusion_matrix = ConfusionMatrix::compute(predictions, targets)?;
        let num_classes = confusion_matrix.num_classes;

        if num_classes == 2 {
            // Binary case: class 1 is the positive class, and the matrix is
            // indexed as [true_class][pred_class].
            let tp = confusion_matrix.matrix[1][1] as f64;
            let tn = confusion_matrix.matrix[0][0] as f64;
            let fp = confusion_matrix.matrix[0][1] as f64;
            let fn_val = confusion_matrix.matrix[1][0] as f64;

            let numerator = (tp * tn) - (fp * fn_val);
            let denominator = ((tp + fp) * (tp + fn_val) * (tn + fp) * (tn + fn_val)).sqrt();

            if denominator == 0.0 {
                Ok(0.0)
            } else {
                Ok(numerator / denominator)
            }
        } else {
            // Multi-class MCC (Gorodkin's R_K):
            //
            //   (t*c - sum_k p_k*t_k)
            //   / (sqrt(t^2 - sum_k p_k^2) * sqrt(t^2 - sum_k t_k^2))
            //
            // where t is the total sample count, c the number of correct
            // predictions, p_k the number of samples predicted as class k,
            // and t_k the number of samples truly in class k.
            let t = confusion_matrix.total_predictions() as f64;

            // Column sums (predicted counts) and row sums (true counts).
            let mut p_k = vec![0.0; num_classes];
            let mut t_k = vec![0.0; num_classes];

            for k in 0..num_classes {
                for l in 0..num_classes {
                    p_k[k] += confusion_matrix.matrix[l][k] as f64;
                    t_k[k] += confusion_matrix.matrix[k][l] as f64;
                }
            }

            // c: correct predictions (diagonal sum).
            let mut c = 0.0;
            for k in 0..num_classes {
                c += confusion_matrix.matrix[k][k] as f64;
            }

            // s: dot product of predicted and true class counts.
            let mut s = 0.0;
            for k in 0..num_classes {
                s += p_k[k] * t_k[k];
            }

            let numerator = (t * c) - s;

            let mut sum_p_sq = 0.0;
            let mut sum_t_sq = 0.0;
            for k in 0..num_classes {
                sum_p_sq += p_k[k] * p_k[k];
                sum_t_sq += t_k[k] * t_k[k];
            }
            let denominator = ((t * t) - sum_p_sq).sqrt() * ((t * t) - sum_t_sq).sqrt();

            if denominator == 0.0 {
                Ok(0.0)
            } else {
                Ok(numerator / denominator)
            }
        }
    }

    fn name(&self) -> &str {
        "mcc"
    }
}

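/// Cohen's kappa: agreement between predictions and targets, corrected for
/// the agreement expected by chance,
///
/// ```text
/// kappa = (p_o - p_e) / (1 - p_e)
/// ```
///
/// where `p_o` is the observed accuracy and `p_e` the accuracy expected
/// from the row and column marginals alone.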
#[derive(Debug, Clone, Default)]
pub struct CohensKappa;

impl Metric for CohensKappa {
    fn compute(
        &self,
        predictions: &ArrayView<f64, Ix2>,
        targets: &ArrayView<f64, Ix2>,
    ) -> TrainResult<f64> {
        let confusion_matrix = ConfusionMatrix::compute(predictions, targets)?;
        let num_classes = confusion_matrix.num_classes;
        let total = confusion_matrix.total_predictions() as f64;
        if total == 0.0 {
            // No samples: agreement is undefined, so report chance level.
            return Ok(0.0);
        }

        // Observed agreement: the fraction of samples on the diagonal.
        let mut observed = 0.0;
        for i in 0..num_classes {
            observed += confusion_matrix.matrix[i][i] as f64;
        }
        observed /= total;

        // Expected agreement from the row and column marginals.
        let mut expected = 0.0;
        for i in 0..num_classes {
            let row_sum: f64 = (0..num_classes)
                .map(|j| confusion_matrix.matrix[i][j] as f64)
                .sum();
            let col_sum: f64 = (0..num_classes)
                .map(|j| confusion_matrix.matrix[j][i] as f64)
                .sum();
            expected += (row_sum / total) * (col_sum / total);
        }

        // Guard the degenerate case where expected agreement is total.
        if expected >= 1.0 {
            Ok(0.0)
        } else {
            Ok((observed - expected) / (1.0 - expected))
        }
    }

    fn name(&self) -> &str {
        "cohens_kappa"
    }
}

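/// Balanced accuracy: the unweighted mean of per-class recall. Unlike plain
/// accuracy it is not dominated by the majority class: a classifier that is
/// right on 9 of 9 majority samples but misses the single minority sample
/// scores 0.9 accuracy yet only 0.5 balanced accuracy.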
#[derive(Debug, Clone, Default)]
pub struct BalancedAccuracy;

impl Metric for BalancedAccuracy {
    fn compute(
        &self,
        predictions: &ArrayView<f64, Ix2>,
        targets: &ArrayView<f64, Ix2>,
    ) -> TrainResult<f64> {
        let confusion_matrix = ConfusionMatrix::compute(predictions, targets)?;
        let recalls = confusion_matrix.recall_per_class();

        // Unweighted mean of the per-class recalls.
        let sum: f64 = recalls.iter().sum();
        Ok(sum / recalls.len() as f64)
    }

    fn name(&self) -> &str {
        "balanced_accuracy"
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    #[test]
    fn test_confusion_matrix() {
        let predictions = array![
            [0.9, 0.1, 0.0],
            [0.1, 0.8, 0.1],
            [0.2, 0.1, 0.7],
            [0.8, 0.1, 0.1]
        ];
        let targets = array![
            [1.0, 0.0, 0.0],
            [0.0, 1.0, 0.0],
            [0.0, 0.0, 1.0],
            [1.0, 0.0, 0.0]
        ];

        let cm = ConfusionMatrix::compute(&predictions.view(), &targets.view()).unwrap();

        assert_eq!(cm.get(0, 0), 2);
        assert_eq!(cm.get(1, 1), 1);
        assert_eq!(cm.get(2, 2), 1);
        assert_eq!(cm.accuracy(), 1.0);
    }

    #[test]
    fn test_confusion_matrix_per_class_metrics() {
        let predictions = array![[0.9, 0.1], [0.2, 0.8], [0.7, 0.3], [0.1, 0.9]];
        let targets = array![[1.0, 0.0], [0.0, 1.0], [1.0, 0.0], [0.0, 1.0]];

        let cm = ConfusionMatrix::compute(&predictions.view(), &targets.view()).unwrap();

        let precision = cm.precision_per_class();
        let recall = cm.recall_per_class();
        let f1 = cm.f1_per_class();

        assert_eq!(precision.len(), 2);
        assert_eq!(recall.len(), 2);
        assert_eq!(f1.len(), 2);

        // Every prediction is correct, so both classes have perfect
        // precision and recall.
        assert_eq!(precision[0], 1.0);
        assert_eq!(precision[1], 1.0);
        assert_eq!(recall[0], 1.0);
        assert_eq!(recall[1], 1.0);
    }

    #[test]
    fn test_roc_curve() {
        let predictions = vec![0.9, 0.8, 0.4, 0.3, 0.1];
        let targets = vec![true, true, false, true, false];

        let roc = RocCurve::compute(&predictions, &targets).unwrap();

        assert!(!roc.fpr.is_empty());
        assert!(!roc.tpr.is_empty());
        assert!(!roc.thresholds.is_empty());
        assert_eq!(roc.fpr.len(), roc.tpr.len());

        let auc = roc.auc();
        assert!((0.0..=1.0).contains(&auc));
    }

    #[test]
    fn test_roc_auc_perfect() {
        let predictions = vec![0.9, 0.8, 0.3, 0.1];
        let targets = vec![true, true, false, false];

        let roc = RocCurve::compute(&predictions, &targets).unwrap();
        let auc = roc.auc();

        // Every positive is scored above every negative, so AUC is exactly 1.
        assert!((auc - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_per_class_metrics() {
        let predictions = array![
            [0.9, 0.1, 0.0],
            [0.1, 0.8, 0.1],
            [0.2, 0.1, 0.7],
            [0.8, 0.1, 0.1]
        ];
        let targets = array![
            [1.0, 0.0, 0.0],
            [0.0, 1.0, 0.0],
            [0.0, 0.0, 1.0],
            [1.0, 0.0, 0.0]
        ];

        let metrics = PerClassMetrics::compute(&predictions.view(), &targets.view()).unwrap();

        assert_eq!(metrics.precision.len(), 3);
        assert_eq!(metrics.recall.len(), 3);
        assert_eq!(metrics.f1_score.len(), 3);
        assert_eq!(metrics.support.len(), 3);

        // Class 0 appears twice in the targets; classes 1 and 2 once each.
        assert_eq!(metrics.support[0], 2);
        assert_eq!(metrics.support[1], 1);
        assert_eq!(metrics.support[2], 1);
    }

    #[test]
    fn test_matthews_correlation_coefficient() {
        let metric = MatthewsCorrelationCoefficient;

        // Perfect predictions: MCC should be exactly 1.
        let predictions = array![[0.9, 0.1], [0.1, 0.9], [0.9, 0.1], [0.1, 0.9]];
        let targets = array![[1.0, 0.0], [0.0, 1.0], [1.0, 0.0], [0.0, 1.0]];

        let mcc = metric
            .compute(&predictions.view(), &targets.view())
            .unwrap();
        assert!((mcc - 1.0).abs() < 1e-6);

        // Uninformative uniform scores (argmax ties resolve to class 0):
        // MCC should sit at chance level, near 0.
        let predictions = array![[0.5, 0.5], [0.5, 0.5], [0.5, 0.5], [0.5, 0.5]];
        let targets = array![[1.0, 0.0], [0.0, 1.0], [1.0, 0.0], [0.0, 1.0]];

        let mcc = metric
            .compute(&predictions.view(), &targets.view())
            .unwrap();
        assert!(mcc.abs() < 0.1);
    }

    #[test]
    fn test_cohens_kappa() {
        let metric = CohensKappa;

        // Perfect agreement: kappa should be exactly 1.
        let predictions = array![[0.9, 0.1], [0.1, 0.9], [0.9, 0.1], [0.1, 0.9]];
        let targets = array![[1.0, 0.0], [0.0, 1.0], [1.0, 0.0], [0.0, 1.0]];

        let kappa = metric
            .compute(&predictions.view(), &targets.view())
            .unwrap();
        assert!((kappa - 1.0).abs() < 1e-6);

        // Constant predictions carry no information; kappa must still be a
        // valid value in [-1, 1].
        let predictions = array![[0.9, 0.1], [0.9, 0.1], [0.9, 0.1], [0.9, 0.1]];
        let targets = array![[1.0, 0.0], [0.0, 1.0], [1.0, 0.0], [0.0, 1.0]];

        let kappa = metric
            .compute(&predictions.view(), &targets.view())
            .unwrap();
        assert!((-1.0..=1.0).contains(&kappa));
    }

    #[test]
    fn test_balanced_accuracy() {
        let metric = BalancedAccuracy;

        // Balanced classes, perfect predictions.
        let predictions = array![[0.9, 0.1], [0.1, 0.9], [0.9, 0.1], [0.1, 0.9]];
        let targets = array![[1.0, 0.0], [0.0, 1.0], [1.0, 0.0], [0.0, 1.0]];

        let balanced_acc = metric
            .compute(&predictions.view(), &targets.view())
            .unwrap();
        assert!((balanced_acc - 1.0).abs() < 1e-6);

        // Imbalanced classes (3:1) but still perfect per-class recall, so
        // balanced accuracy remains 1.
        let predictions = array![[0.9, 0.1], [0.9, 0.1], [0.9, 0.1], [0.1, 0.9]];
        let targets = array![[1.0, 0.0], [1.0, 0.0], [1.0, 0.0], [0.0, 1.0]];

        let balanced_acc = metric
            .compute(&predictions.view(), &targets.view())
            .unwrap();
        assert!((balanced_acc - 1.0).abs() < 1e-6);
    }
}