use crate::{TrainError, TrainResult};
use scirs2_core::ndarray::{ArrayView, Ix2};
use std::collections::HashMap;

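/// A metric computed from model predictions and targets.
///
/// Implementations in this module classify each row by argmax, so targets
/// may be one-hot or probabilistic.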
pub trait Metric {
    /// Computes the metric from prediction and target matrices of shape
    /// `(n_samples, n_classes)`.
    fn compute(
        &self,
        predictions: &ArrayView<f64, Ix2>,
        targets: &ArrayView<f64, Ix2>,
    ) -> TrainResult<f64>;

    /// Returns the metric's name, used as its key in tracking and reporting.
    fn name(&self) -> &str;

    /// Resets any internal state. The default implementation is a no-op.
    fn reset(&mut self) {}
}

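/// Classification accuracy: the fraction of samples whose argmax-predicted
/// class matches the argmax of the target row.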
#[derive(Debug, Clone)]
pub struct Accuracy {
    /// Decision threshold for binary outputs. `compute` currently classifies
    /// rows by argmax, so this field is not yet consulted.
    pub threshold: f64,
}

impl Default for Accuracy {
    fn default() -> Self {
        Self { threshold: 0.5 }
    }
}

impl Metric for Accuracy {
    fn compute(
        &self,
        predictions: &ArrayView<f64, Ix2>,
        targets: &ArrayView<f64, Ix2>,
    ) -> TrainResult<f64> {
        if predictions.shape() != targets.shape() {
            return Err(TrainError::MetricsError(format!(
                "Shape mismatch: predictions {:?} vs targets {:?}",
                predictions.shape(),
                targets.shape()
            )));
        }

        let total = predictions.nrows();
        if total == 0 {
            return Err(TrainError::MetricsError(
                "Cannot compute accuracy on an empty batch".to_string(),
            ));
        }

        let mut correct = 0;

        for i in 0..total {
            // Argmax over the prediction row.
            let mut pred_class = 0;
            let mut max_pred = predictions[[i, 0]];
            for j in 1..predictions.ncols() {
                if predictions[[i, j]] > max_pred {
                    max_pred = predictions[[i, j]];
                    pred_class = j;
                }
            }

            // Argmax over the target row.
            let mut true_class = 0;
            let mut max_true = targets[[i, 0]];
            for j in 1..targets.ncols() {
                if targets[[i, j]] > max_true {
                    max_true = targets[[i, j]];
                    true_class = j;
                }
            }

            if pred_class == true_class {
                correct += 1;
            }
        }

        Ok(correct as f64 / total as f64)
    }

    fn name(&self) -> &str {
        "accuracy"
    }
}

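/// Precision: true positives over predicted positives. With `class_id: None`
/// this is the macro average over classes that received at least one
/// prediction; otherwise it is the precision of the given class.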
#[derive(Debug, Clone, Default)]
pub struct Precision {
    /// Restrict the metric to a single class; `None` macro-averages.
    pub class_id: Option<usize>,
}

impl Metric for Precision {
    fn compute(
        &self,
        predictions: &ArrayView<f64, Ix2>,
        targets: &ArrayView<f64, Ix2>,
    ) -> TrainResult<f64> {
        if predictions.shape() != targets.shape() {
            return Err(TrainError::MetricsError(format!(
                "Shape mismatch: predictions {:?} vs targets {:?}",
                predictions.shape(),
                targets.shape()
            )));
        }

        let num_classes = predictions.ncols();
        let mut true_positives = vec![0; num_classes];
        let mut predicted_positives = vec![0; num_classes];

        for i in 0..predictions.nrows() {
            // Argmax over the prediction row.
            let mut pred_class = 0;
            let mut max_pred = predictions[[i, 0]];
            for j in 1..num_classes {
                if predictions[[i, j]] > max_pred {
                    max_pred = predictions[[i, j]];
                    pred_class = j;
                }
            }

            // Argmax over the target row.
            let mut true_class = 0;
            let mut max_true = targets[[i, 0]];
            for j in 1..num_classes {
                if targets[[i, j]] > max_true {
                    max_true = targets[[i, j]];
                    true_class = j;
                }
            }

            predicted_positives[pred_class] += 1;
            if pred_class == true_class {
                true_positives[pred_class] += 1;
            }
        }

        if let Some(class_id) = self.class_id {
            if class_id >= num_classes {
                return Err(TrainError::MetricsError(format!(
                    "class_id {} out of range for {} classes",
                    class_id, num_classes
                )));
            }
            if predicted_positives[class_id] == 0 {
                Ok(0.0)
            } else {
                Ok(true_positives[class_id] as f64 / predicted_positives[class_id] as f64)
            }
        } else {
            // Macro average over classes with at least one prediction.
            let mut total_precision = 0.0;
            let mut valid_classes = 0;

            for class_id in 0..num_classes {
                if predicted_positives[class_id] > 0 {
                    total_precision +=
                        true_positives[class_id] as f64 / predicted_positives[class_id] as f64;
                    valid_classes += 1;
                }
            }

            if valid_classes == 0 {
                Ok(0.0)
            } else {
                Ok(total_precision / valid_classes as f64)
            }
        }
    }

    fn name(&self) -> &str {
        "precision"
    }
}

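/// Recall: true positives over actual positives. With `class_id: None` this
/// is the macro average over classes present in the targets; otherwise it is
/// the recall of the given class.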
#[derive(Debug, Clone, Default)]
pub struct Recall {
    /// Restrict the metric to a single class; `None` macro-averages.
    pub class_id: Option<usize>,
}

impl Metric for Recall {
    fn compute(
        &self,
        predictions: &ArrayView<f64, Ix2>,
        targets: &ArrayView<f64, Ix2>,
    ) -> TrainResult<f64> {
        if predictions.shape() != targets.shape() {
            return Err(TrainError::MetricsError(format!(
                "Shape mismatch: predictions {:?} vs targets {:?}",
                predictions.shape(),
                targets.shape()
            )));
        }

        let num_classes = predictions.ncols();
        let mut true_positives = vec![0; num_classes];
        let mut actual_positives = vec![0; num_classes];

        for i in 0..predictions.nrows() {
            // Argmax over the prediction row.
            let mut pred_class = 0;
            let mut max_pred = predictions[[i, 0]];
            for j in 1..num_classes {
                if predictions[[i, j]] > max_pred {
                    max_pred = predictions[[i, j]];
                    pred_class = j;
                }
            }

            // Argmax over the target row.
            let mut true_class = 0;
            let mut max_true = targets[[i, 0]];
            for j in 1..num_classes {
                if targets[[i, j]] > max_true {
                    max_true = targets[[i, j]];
                    true_class = j;
                }
            }

            actual_positives[true_class] += 1;
            if pred_class == true_class {
                true_positives[pred_class] += 1;
            }
        }

        if let Some(class_id) = self.class_id {
            if class_id >= num_classes {
                return Err(TrainError::MetricsError(format!(
                    "class_id {} out of range for {} classes",
                    class_id, num_classes
                )));
            }
            if actual_positives[class_id] == 0 {
                Ok(0.0)
            } else {
                Ok(true_positives[class_id] as f64 / actual_positives[class_id] as f64)
            }
        } else {
            // Macro average over classes that appear in the targets.
            let mut total_recall = 0.0;
            let mut valid_classes = 0;

            for class_id in 0..num_classes {
                if actual_positives[class_id] > 0 {
                    total_recall +=
                        true_positives[class_id] as f64 / actual_positives[class_id] as f64;
                    valid_classes += 1;
                }
            }

            if valid_classes == 0 {
                Ok(0.0)
            } else {
                Ok(total_recall / valid_classes as f64)
            }
        }
    }

    fn name(&self) -> &str {
        "recall"
    }
}

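/// F1 score: the harmonic mean `2 * p * r / (p + r)` of precision and
/// recall, with the same `class_id` semantics as [`Precision`] and
/// [`Recall`].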
#[derive(Debug, Clone, Default)]
pub struct F1Score {
    /// Restrict the metric to a single class; `None` macro-averages.
    pub class_id: Option<usize>,
}

impl Metric for F1Score {
    fn compute(
        &self,
        predictions: &ArrayView<f64, Ix2>,
        targets: &ArrayView<f64, Ix2>,
    ) -> TrainResult<f64> {
        let precision = Precision {
            class_id: self.class_id,
        }
        .compute(predictions, targets)?;
        let recall = Recall {
            class_id: self.class_id,
        }
        .compute(predictions, targets)?;

        // Guard against 0/0 when both precision and recall are zero.
        if precision + recall == 0.0 {
            Ok(0.0)
        } else {
            Ok(2.0 * precision * recall / (precision + recall))
        }
    }

    fn name(&self) -> &str {
        "f1_score"
    }
}

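/// Tracks a set of metrics and records their values over time (for example,
/// once per validation pass).
///
/// Illustrative usage; marked `ignore` because the enclosing crate path is
/// assumed here:
///
/// ```ignore
/// let mut tracker = MetricTracker::new();
/// tracker.add(Box::new(Accuracy::default()));
/// let results = tracker.compute_all(&predictions.view(), &targets.view())?;
/// println!("accuracy = {}", results["accuracy"]);
/// ```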
pub struct MetricTracker {
    /// Registered metrics, in insertion order.
    metrics: Vec<Box<dyn Metric>>,
    /// Recorded values for each metric, keyed by name.
    history: HashMap<String, Vec<f64>>,
}

impl MetricTracker {
    /// Creates an empty tracker.
    pub fn new() -> Self {
        Self {
            metrics: Vec::new(),
            history: HashMap::new(),
        }
    }

    /// Registers a metric and initializes an empty history for it.
    pub fn add(&mut self, metric: Box<dyn Metric>) {
        let name = metric.name().to_string();
        self.history.insert(name, Vec::new());
        self.metrics.push(metric);
    }

    /// Computes every registered metric, appends each value to its history,
    /// and returns the results keyed by metric name.
    pub fn compute_all(
        &mut self,
        predictions: &ArrayView<f64, Ix2>,
        targets: &ArrayView<f64, Ix2>,
    ) -> TrainResult<HashMap<String, f64>> {
        let mut results = HashMap::new();

        for metric in &self.metrics {
            let value = metric.compute(predictions, targets)?;
            let name = metric.name().to_string();

            results.insert(name.clone(), value);

            if let Some(history) = self.history.get_mut(&name) {
                history.push(value);
            }
        }

        Ok(results)
    }

    /// Returns the recorded history for a metric, if it is registered.
    pub fn get_history(&self, metric_name: &str) -> Option<&Vec<f64>> {
        self.history.get(metric_name)
    }

    /// Resets the internal state of every registered metric.
    pub fn reset(&mut self) {
        for metric in &mut self.metrics {
            metric.reset();
        }
    }

    /// Clears all recorded histories without unregistering metrics.
    pub fn clear_history(&mut self) {
        for history in self.history.values_mut() {
            history.clear();
        }
    }
}

impl Default for MetricTracker {
    fn default() -> Self {
        Self::new()
    }
}

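/// A confusion matrix with counts indexed as `matrix[true_class][pred_class]`.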
#[derive(Debug, Clone)]
pub struct ConfusionMatrix {
    num_classes: usize,
    matrix: Vec<Vec<usize>>,
}

impl ConfusionMatrix {
    /// Creates an empty matrix for `num_classes` classes.
    pub fn new(num_classes: usize) -> Self {
        Self {
            num_classes,
            matrix: vec![vec![0; num_classes]; num_classes],
        }
    }

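    /// Builds the matrix from `(n_samples, n_classes)` prediction and target
    /// matrices, classifying each row by argmax.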
    pub fn compute(
        predictions: &ArrayView<f64, Ix2>,
        targets: &ArrayView<f64, Ix2>,
    ) -> TrainResult<Self> {
        if predictions.shape() != targets.shape() {
            return Err(TrainError::MetricsError(format!(
                "Shape mismatch: predictions {:?} vs targets {:?}",
                predictions.shape(),
                targets.shape()
            )));
        }

        let num_classes = predictions.ncols();
        let mut matrix = vec![vec![0; num_classes]; num_classes];

        for i in 0..predictions.nrows() {
            // Argmax over the prediction row.
            let mut pred_class = 0;
            let mut max_pred = predictions[[i, 0]];
            for j in 1..num_classes {
                if predictions[[i, j]] > max_pred {
                    max_pred = predictions[[i, j]];
                    pred_class = j;
                }
            }

            // Argmax over the target row.
            let mut true_class = 0;
            let mut max_true = targets[[i, 0]];
            for j in 1..num_classes {
                if targets[[i, j]] > max_true {
                    max_true = targets[[i, j]];
                    true_class = j;
                }
            }

            matrix[true_class][pred_class] += 1;
        }

        Ok(Self {
            num_classes,
            matrix,
        })
    }

    /// Returns the raw counts, indexed as `matrix[true_class][pred_class]`.
    pub fn matrix(&self) -> &Vec<Vec<usize>> {
        &self.matrix
    }

    /// Returns the count for a (true class, predicted class) pair.
    pub fn get(&self, true_class: usize, pred_class: usize) -> usize {
        self.matrix[true_class][pred_class]
    }

    /// Per-class precision: each diagonal entry over its column sum.
    pub fn precision_per_class(&self) -> Vec<f64> {
        let mut precisions = Vec::with_capacity(self.num_classes);

        for pred_class in 0..self.num_classes {
            let mut predicted_positive = 0;
            let mut true_positive = 0;

            for true_class in 0..self.num_classes {
                predicted_positive += self.matrix[true_class][pred_class];
                if true_class == pred_class {
                    true_positive += self.matrix[true_class][pred_class];
                }
            }

            let precision = if predicted_positive == 0 {
                0.0
            } else {
                true_positive as f64 / predicted_positive as f64
            };
            precisions.push(precision);
        }

        precisions
    }

    /// Per-class recall: each diagonal entry over its row sum.
    pub fn recall_per_class(&self) -> Vec<f64> {
        let mut recalls = Vec::with_capacity(self.num_classes);

        for true_class in 0..self.num_classes {
            let mut actual_positive = 0;
            let mut true_positive = 0;

            for pred_class in 0..self.num_classes {
                actual_positive += self.matrix[true_class][pred_class];
                if true_class == pred_class {
                    true_positive += self.matrix[true_class][pred_class];
                }
            }

            let recall = if actual_positive == 0 {
                0.0
            } else {
                true_positive as f64 / actual_positive as f64
            };
            recalls.push(recall);
        }

        recalls
    }

    /// Per-class F1: the harmonic mean of per-class precision and recall.
    pub fn f1_per_class(&self) -> Vec<f64> {
        let precisions = self.precision_per_class();
        let recalls = self.recall_per_class();

        precisions
            .iter()
            .zip(recalls.iter())
            .map(|(p, r)| {
                if p + r == 0.0 {
                    0.0
                } else {
                    2.0 * p * r / (p + r)
                }
            })
            .collect()
    }

    /// Overall accuracy: the diagonal sum over the total count, or 0.0 for
    /// an empty matrix.
    pub fn accuracy(&self) -> f64 {
        let mut correct = 0;
        let mut total = 0;

        for i in 0..self.num_classes {
            for j in 0..self.num_classes {
                total += self.matrix[i][j];
                if i == j {
                    correct += self.matrix[i][j];
                }
            }
        }

        if total == 0 {
            0.0
        } else {
            correct as f64 / total as f64
        }
    }

    /// Total number of samples recorded in the matrix.
    pub fn total_predictions(&self) -> usize {
        let mut total = 0;
        for i in 0..self.num_classes {
            for j in 0..self.num_classes {
                total += self.matrix[i][j];
            }
        }
        total
    }
}

impl std::fmt::Display for ConfusionMatrix {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        writeln!(f, "Confusion Matrix:")?;
        // Pad to align the column header with the "{:3}| " row labels below.
        write!(f, "     ")?;

        for j in 0..self.num_classes {
            write!(f, "{:5}", j)?;
        }
        writeln!(f)?;

        // One row per true class, one column per predicted class.
        for i in 0..self.num_classes {
            write!(f, "{:3}| ", i)?;
            for j in 0..self.num_classes {
                write!(f, "{:5}", self.matrix[i][j])?;
            }
            writeln!(f)?;
        }

        Ok(())
    }
}

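/// A receiver operating characteristic (ROC) curve: false positive rate,
/// true positive rate, and the decision threshold at each point.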
#[derive(Debug, Clone)]
pub struct RocCurve {
    /// False positive rate at each threshold.
    pub fpr: Vec<f64>,
    /// True positive rate at each threshold.
    pub tpr: Vec<f64>,
    /// Decision thresholds, starting at `f64::INFINITY`.
    pub thresholds: Vec<f64>,
}

impl RocCurve {
    /// Computes the ROC curve for binary targets from raw scores.
    ///
    /// Each sample adds one point to the curve; tied scores are not merged
    /// into a single point.
    pub fn compute(predictions: &[f64], targets: &[bool]) -> TrainResult<Self> {
        if predictions.len() != targets.len() {
            return Err(TrainError::MetricsError(format!(
                "Length mismatch: predictions {} vs targets {}",
                predictions.len(),
                targets.len()
            )));
        }

        // Sort sample indices by descending score.
        let mut indices: Vec<usize> = (0..predictions.len()).collect();
        indices.sort_by(|&a, &b| {
            predictions[b]
                .partial_cmp(&predictions[a])
                .unwrap_or(std::cmp::Ordering::Equal)
        });

        let mut fpr = Vec::new();
        let mut tpr = Vec::new();
        let mut thresholds = Vec::new();

        let num_positive = targets.iter().filter(|&&x| x).count();
        let num_negative = targets.len() - num_positive;

        let mut true_positives = 0;
        let mut false_positives = 0;

        // The curve starts at (0, 0) with an infinite threshold.
        fpr.push(0.0);
        tpr.push(0.0);
        thresholds.push(f64::INFINITY);

        for &idx in &indices {
            if targets[idx] {
                true_positives += 1;
            } else {
                false_positives += 1;
            }

            let fpr_val = if num_negative == 0 {
                0.0
            } else {
                false_positives as f64 / num_negative as f64
            };
            let tpr_val = if num_positive == 0 {
                0.0
            } else {
                true_positives as f64 / num_positive as f64
            };

            fpr.push(fpr_val);
            tpr.push(tpr_val);
            thresholds.push(predictions[idx]);
        }

        Ok(Self {
            fpr,
            tpr,
            thresholds,
        })
    }

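    /// Area under the ROC curve, integrated with the trapezoidal rule.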
    pub fn auc(&self) -> f64 {
        let mut auc = 0.0;

        for i in 1..self.fpr.len() {
            let width = self.fpr[i] - self.fpr[i - 1];
            let height = (self.tpr[i] + self.tpr[i - 1]) / 2.0;
            auc += width * height;
        }

        auc
    }
}

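/// Per-class precision, recall, F1 score, and support, derived from a
/// confusion matrix.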
#[derive(Debug, Clone)]
pub struct PerClassMetrics {
    /// Precision for each class.
    pub precision: Vec<f64>,
    /// Recall for each class.
    pub recall: Vec<f64>,
    /// F1 score for each class.
    pub f1_score: Vec<f64>,
    /// Number of target samples in each class.
    pub support: Vec<usize>,
}

impl PerClassMetrics {
    /// Computes per-class metrics from prediction and target matrices of
    /// shape `(n_samples, n_classes)`.
    pub fn compute(
        predictions: &ArrayView<f64, Ix2>,
        targets: &ArrayView<f64, Ix2>,
    ) -> TrainResult<Self> {
        let confusion_matrix = ConfusionMatrix::compute(predictions, targets)?;

        let precision = confusion_matrix.precision_per_class();
        let recall = confusion_matrix.recall_per_class();
        let f1_score = confusion_matrix.f1_per_class();

        // Support: how many target samples fall into each class.
        let num_classes = targets.ncols();
        let mut support = vec![0; num_classes];

        for i in 0..targets.nrows() {
            let mut true_class = 0;
            let mut max_true = targets[[i, 0]];
            for j in 1..num_classes {
                if targets[[i, j]] > max_true {
                    max_true = targets[[i, j]];
                    true_class = j;
                }
            }
            support[true_class] += 1;
        }

        Ok(Self {
            precision,
            recall,
            f1_score,
            support,
        })
    }
}

impl std::fmt::Display for PerClassMetrics {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        writeln!(f, "Per-Class Metrics:")?;
        writeln!(f, "Class Precision Recall F1-Score Support")?;
        writeln!(f, "----- --------- ------ -------- -------")?;

        for i in 0..self.precision.len() {
            writeln!(
                f,
                "{:5} {:9.4} {:6.4} {:8.4} {:7}",
                i, self.precision[i], self.recall[i], self.f1_score[i], self.support[i]
            )?;
        }

        let macro_precision: f64 =
            self.precision.iter().sum::<f64>() / self.precision.len() as f64;
        let macro_recall: f64 = self.recall.iter().sum::<f64>() / self.recall.len() as f64;
        let macro_f1: f64 = self.f1_score.iter().sum::<f64>() / self.f1_score.len() as f64;
        let total_support: usize = self.support.iter().sum();

        writeln!(f, "----- --------- ------ -------- -------")?;
        writeln!(
            f,
            "Macro {:9.4} {:6.4} {:8.4} {:7}",
            macro_precision, macro_recall, macro_f1, total_support
        )?;

        Ok(())
    }
}

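/// Matthews correlation coefficient.
///
/// For binary problems: `(tp*tn - fp*fn) / sqrt((tp+fp)(tp+fn)(tn+fp)(tn+fn))`.
/// For multiclass problems the Gorodkin generalization is used: with `t`
/// total samples, `c` correct predictions, `p_k` predicted counts, and `t_k`
/// actual counts per class,
/// `(c*t - sum_k p_k*t_k) / (sqrt(t^2 - sum_k p_k^2) * sqrt(t^2 - sum_k t_k^2))`.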
#[derive(Debug, Clone, Default)]
pub struct MatthewsCorrelationCoefficient;

impl Metric for MatthewsCorrelationCoefficient {
    fn compute(
        &self,
        predictions: &ArrayView<f64, Ix2>,
        targets: &ArrayView<f64, Ix2>,
    ) -> TrainResult<f64> {
        let confusion_matrix = ConfusionMatrix::compute(predictions, targets)?;
        let num_classes = confusion_matrix.num_classes;

        if num_classes == 2 {
            // Binary MCC from the 2x2 matrix; class 1 is the positive class.
            let tp = confusion_matrix.matrix[1][1] as f64;
            let tn = confusion_matrix.matrix[0][0] as f64;
            let fp = confusion_matrix.matrix[0][1] as f64;
            let fn_val = confusion_matrix.matrix[1][0] as f64;

            let numerator = (tp * tn) - (fp * fn_val);
            let denominator = ((tp + fp) * (tp + fn_val) * (tn + fp) * (tn + fn_val)).sqrt();

            if denominator == 0.0 {
                Ok(0.0)
            } else {
                Ok(numerator / denominator)
            }
        } else {
            // Multiclass MCC (Gorodkin's R_K statistic).
            let mut s = 0.0;
            let mut c = 0.0;
            let t = confusion_matrix.total_predictions() as f64;

            // p_k: samples predicted as class k; t_k: samples actually in class k.
            let mut p_k = vec![0.0; num_classes];
            let mut t_k = vec![0.0; num_classes];

            for k in 0..num_classes {
                for l in 0..num_classes {
                    p_k[k] += confusion_matrix.matrix[l][k] as f64;
                    t_k[k] += confusion_matrix.matrix[k][l] as f64;
                }
            }

            // c: total correct predictions (the diagonal).
            for k in 0..num_classes {
                c += confusion_matrix.matrix[k][k] as f64;
            }

            // s: sum over classes of p_k * t_k.
            for k in 0..num_classes {
                s += p_k[k] * t_k[k];
            }

            let numerator = (t * c) - s;

            // R_K denominator: sqrt(t^2 - sum_k p_k^2) * sqrt(t^2 - sum_k t_k^2).
            let mut sum_p_sq = 0.0;
            let mut sum_t_sq = 0.0;
            for k in 0..num_classes {
                sum_p_sq += p_k[k] * p_k[k];
                sum_t_sq += t_k[k] * t_k[k];
            }
            let denominator = ((t * t) - sum_p_sq).sqrt() * ((t * t) - sum_t_sq).sqrt();

            if denominator == 0.0 {
                Ok(0.0)
            } else {
                Ok(numerator / denominator)
            }
        }
    }

    fn name(&self) -> &str {
        "mcc"
    }
}

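/// Cohen's kappa: chance-corrected agreement between predictions and
/// targets, `(p_o - p_e) / (1 - p_e)`.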
#[derive(Debug, Clone, Default)]
pub struct CohensKappa;

impl Metric for CohensKappa {
    fn compute(
        &self,
        predictions: &ArrayView<f64, Ix2>,
        targets: &ArrayView<f64, Ix2>,
    ) -> TrainResult<f64> {
        let confusion_matrix = ConfusionMatrix::compute(predictions, targets)?;
        let num_classes = confusion_matrix.num_classes;
        let total = confusion_matrix.total_predictions() as f64;
        if total == 0.0 {
            return Err(TrainError::MetricsError(
                "Cannot compute Cohen's kappa on an empty batch".to_string(),
            ));
        }

        // Observed agreement p_o: fraction of samples on the diagonal.
        let mut observed = 0.0;
        for i in 0..num_classes {
            observed += confusion_matrix.matrix[i][i] as f64;
        }
        observed /= total;

        // Expected agreement p_e under independent marginals.
        let mut expected = 0.0;
        for i in 0..num_classes {
            let row_sum: f64 = (0..num_classes)
                .map(|j| confusion_matrix.matrix[i][j] as f64)
                .sum();
            let col_sum: f64 = (0..num_classes)
                .map(|j| confusion_matrix.matrix[j][i] as f64)
                .sum();
            expected += (row_sum / total) * (col_sum / total);
        }

        // Kappa is undefined when expected agreement is already perfect;
        // return 0.0 by convention.
        if expected >= 1.0 {
            Ok(0.0)
        } else {
            Ok((observed - expected) / (1.0 - expected))
        }
    }

    fn name(&self) -> &str {
        "cohens_kappa"
    }
}

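/// Top-k accuracy: the fraction of samples whose true class appears among
/// the `k` highest-scoring predicted classes.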
#[derive(Debug, Clone)]
pub struct TopKAccuracy {
    /// Number of top-scoring classes to consider.
    pub k: usize,
}

impl Default for TopKAccuracy {
    fn default() -> Self {
        Self { k: 5 }
    }
}

impl TopKAccuracy {
    /// Creates the metric for a given `k`.
    pub fn new(k: usize) -> Self {
        Self { k }
    }
}

impl Metric for TopKAccuracy {
    fn compute(
        &self,
        predictions: &ArrayView<f64, Ix2>,
        targets: &ArrayView<f64, Ix2>,
    ) -> TrainResult<f64> {
        if predictions.shape() != targets.shape() {
            return Err(TrainError::MetricsError(format!(
                "Shape mismatch: predictions {:?} vs targets {:?}",
                predictions.shape(),
                targets.shape()
            )));
        }

        let num_classes = predictions.ncols();
        if self.k > num_classes {
            return Err(TrainError::MetricsError(format!(
                "K ({}) cannot be greater than number of classes ({})",
                self.k, num_classes
            )));
        }

        let total = predictions.nrows();
        if total == 0 {
            return Err(TrainError::MetricsError(
                "Cannot compute top-k accuracy on an empty batch".to_string(),
            ));
        }

        let mut correct = 0;

        for i in 0..total {
            // Argmax over the target row.
            let mut true_class = 0;
            let mut max_true = targets[[i, 0]];
            for j in 1..num_classes {
                if targets[[i, j]] > max_true {
                    max_true = targets[[i, j]];
                    true_class = j;
                }
            }

            // Sort class indices by descending prediction score.
            let mut indices: Vec<usize> = (0..num_classes).collect();
            indices.sort_by(|&a, &b| {
                predictions[[i, b]]
                    .partial_cmp(&predictions[[i, a]])
                    .unwrap_or(std::cmp::Ordering::Equal)
            });

            // Count a hit if the true class is among the top k.
            if indices[..self.k].contains(&true_class) {
                correct += 1;
            }
        }

        Ok(correct as f64 / total as f64)
    }

    fn name(&self) -> &str {
        "top_k_accuracy"
    }
}

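/// Balanced accuracy: the unweighted mean of per-class recall, robust to
/// class imbalance.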
#[derive(Debug, Clone, Default)]
pub struct BalancedAccuracy;

impl Metric for BalancedAccuracy {
    fn compute(
        &self,
        predictions: &ArrayView<f64, Ix2>,
        targets: &ArrayView<f64, Ix2>,
    ) -> TrainResult<f64> {
        let confusion_matrix = ConfusionMatrix::compute(predictions, targets)?;
        let recalls = confusion_matrix.recall_per_class();

        let sum: f64 = recalls.iter().sum();
        Ok(sum / recalls.len() as f64)
    }

    fn name(&self) -> &str {
        "balanced_accuracy"
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    #[test]
    fn test_accuracy() {
        let metric = Accuracy::default();

        let predictions = array![[0.9, 0.1], [0.2, 0.8], [0.8, 0.2]];
        let targets = array![[1.0, 0.0], [0.0, 1.0], [1.0, 0.0]];

        let accuracy = metric
            .compute(&predictions.view(), &targets.view())
            .unwrap();
        assert_eq!(accuracy, 1.0);

        let predictions = array![[0.9, 0.1], [0.8, 0.2], [0.8, 0.2]];
        let targets = array![[1.0, 0.0], [0.0, 1.0], [1.0, 0.0]];

        let accuracy = metric
            .compute(&predictions.view(), &targets.view())
            .unwrap();
        assert!((accuracy - 2.0 / 3.0).abs() < 1e-6);
    }

    #[test]
    fn test_precision() {
        let metric = Precision::default();

        let predictions = array![[0.9, 0.1], [0.2, 0.8], [0.7, 0.3]];
        let targets = array![[1.0, 0.0], [0.0, 1.0], [0.0, 1.0]];

        let precision = metric
            .compute(&predictions.view(), &targets.view())
            .unwrap();
        assert!((0.0..=1.0).contains(&precision));
    }

    #[test]
    fn test_recall() {
        let metric = Recall::default();

        let predictions = array![[0.9, 0.1], [0.2, 0.8], [0.7, 0.3]];
        let targets = array![[1.0, 0.0], [0.0, 1.0], [0.0, 1.0]];

        let recall = metric
            .compute(&predictions.view(), &targets.view())
            .unwrap();
        assert!((0.0..=1.0).contains(&recall));
    }

    #[test]
    fn test_f1_score() {
        let metric = F1Score::default();

        let predictions = array![[0.9, 0.1], [0.2, 0.8], [0.7, 0.3]];
        let targets = array![[1.0, 0.0], [0.0, 1.0], [0.0, 1.0]];

        let f1 = metric
            .compute(&predictions.view(), &targets.view())
            .unwrap();
        assert!((0.0..=1.0).contains(&f1));
    }

    #[test]
    fn test_metric_tracker() {
        let mut tracker = MetricTracker::new();
        tracker.add(Box::new(Accuracy::default()));
        tracker.add(Box::new(F1Score::default()));

        let predictions = array![[0.9, 0.1], [0.2, 0.8]];
        let targets = array![[1.0, 0.0], [0.0, 1.0]];

        let results = tracker
            .compute_all(&predictions.view(), &targets.view())
            .unwrap();
        assert!(results.contains_key("accuracy"));
        assert!(results.contains_key("f1_score"));

        let history = tracker.get_history("accuracy").unwrap();
        assert_eq!(history.len(), 1);
    }

    #[test]
    fn test_confusion_matrix() {
        let predictions = array![
            [0.9, 0.1, 0.0],
            [0.1, 0.8, 0.1],
            [0.2, 0.1, 0.7],
            [0.8, 0.1, 0.1]
        ];
        let targets = array![
            [1.0, 0.0, 0.0],
            [0.0, 1.0, 0.0],
            [0.0, 0.0, 1.0],
            [1.0, 0.0, 0.0]
        ];

        let cm = ConfusionMatrix::compute(&predictions.view(), &targets.view()).unwrap();

        assert_eq!(cm.get(0, 0), 2);
        assert_eq!(cm.get(1, 1), 1);
        assert_eq!(cm.get(2, 2), 1);
        assert_eq!(cm.accuracy(), 1.0);
    }

    #[test]
    fn test_confusion_matrix_per_class_metrics() {
        let predictions = array![[0.9, 0.1], [0.2, 0.8], [0.7, 0.3], [0.1, 0.9]];
        let targets = array![[1.0, 0.0], [0.0, 1.0], [1.0, 0.0], [0.0, 1.0]];

        let cm = ConfusionMatrix::compute(&predictions.view(), &targets.view()).unwrap();

        let precision = cm.precision_per_class();
        let recall = cm.recall_per_class();
        let f1 = cm.f1_per_class();

        assert_eq!(precision.len(), 2);
        assert_eq!(recall.len(), 2);
        assert_eq!(f1.len(), 2);

        assert_eq!(precision[0], 1.0);
        assert_eq!(precision[1], 1.0);
        assert_eq!(recall[0], 1.0);
        assert_eq!(recall[1], 1.0);
    }

    #[test]
    fn test_roc_curve() {
        let predictions = vec![0.9, 0.8, 0.4, 0.3, 0.1];
        let targets = vec![true, true, false, true, false];

        let roc = RocCurve::compute(&predictions, &targets).unwrap();

        assert!(!roc.fpr.is_empty());
        assert!(!roc.tpr.is_empty());
        assert!(!roc.thresholds.is_empty());
        assert_eq!(roc.fpr.len(), roc.tpr.len());

        let auc = roc.auc();
        assert!((0.0..=1.0).contains(&auc));
    }

    #[test]
    fn test_roc_auc_perfect() {
        let predictions = vec![0.9, 0.8, 0.3, 0.1];
        let targets = vec![true, true, false, false];

        let roc = RocCurve::compute(&predictions, &targets).unwrap();
        let auc = roc.auc();

        assert!((auc - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_per_class_metrics() {
        let predictions = array![
            [0.9, 0.1, 0.0],
            [0.1, 0.8, 0.1],
            [0.2, 0.1, 0.7],
            [0.8, 0.1, 0.1]
        ];
        let targets = array![
            [1.0, 0.0, 0.0],
            [0.0, 1.0, 0.0],
            [0.0, 0.0, 1.0],
            [1.0, 0.0, 0.0]
        ];

        let metrics = PerClassMetrics::compute(&predictions.view(), &targets.view()).unwrap();

        assert_eq!(metrics.precision.len(), 3);
        assert_eq!(metrics.recall.len(), 3);
        assert_eq!(metrics.f1_score.len(), 3);
        assert_eq!(metrics.support.len(), 3);

        assert_eq!(metrics.support[0], 2);
        assert_eq!(metrics.support[1], 1);
        assert_eq!(metrics.support[2], 1);
    }

    #[test]
    fn test_matthews_correlation_coefficient() {
        let metric = MatthewsCorrelationCoefficient;

        let predictions = array![[0.9, 0.1], [0.1, 0.9], [0.9, 0.1], [0.1, 0.9]];
        let targets = array![[1.0, 0.0], [0.0, 1.0], [1.0, 0.0], [0.0, 1.0]];

        let mcc = metric
            .compute(&predictions.view(), &targets.view())
            .unwrap();
        assert!((mcc - 1.0).abs() < 1e-6);

        let predictions = array![[0.5, 0.5], [0.5, 0.5], [0.5, 0.5], [0.5, 0.5]];
        let targets = array![[1.0, 0.0], [0.0, 1.0], [1.0, 0.0], [0.0, 1.0]];

        let mcc = metric
            .compute(&predictions.view(), &targets.view())
            .unwrap();
        assert!(mcc.abs() < 0.1);
    }

    #[test]
    fn test_cohens_kappa() {
        let metric = CohensKappa;

        let predictions = array![[0.9, 0.1], [0.1, 0.9], [0.9, 0.1], [0.1, 0.9]];
        let targets = array![[1.0, 0.0], [0.0, 1.0], [1.0, 0.0], [0.0, 1.0]];

        let kappa = metric
            .compute(&predictions.view(), &targets.view())
            .unwrap();
        assert!((kappa - 1.0).abs() < 1e-6);

        let predictions = array![[0.9, 0.1], [0.9, 0.1], [0.9, 0.1], [0.9, 0.1]];
        let targets = array![[1.0, 0.0], [0.0, 1.0], [1.0, 0.0], [0.0, 1.0]];

        let kappa = metric
            .compute(&predictions.view(), &targets.view())
            .unwrap();
        assert!((-1.0..=1.0).contains(&kappa));
    }

    #[test]
    fn test_top_k_accuracy() {
        let metric = TopKAccuracy::new(2);

        let predictions = array![[0.7, 0.2, 0.1], [0.1, 0.6, 0.3], [0.3, 0.4, 0.3]];
        let targets = array![[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]];

        let top_k = metric
            .compute(&predictions.view(), &targets.view())
            .unwrap();
        assert!((0.0..=1.0).contains(&top_k));
        assert!(top_k >= 0.66);
    }

    #[test]
    fn test_balanced_accuracy() {
        let metric = BalancedAccuracy;

        let predictions = array![[0.9, 0.1], [0.1, 0.9], [0.9, 0.1], [0.1, 0.9]];
        let targets = array![[1.0, 0.0], [0.0, 1.0], [1.0, 0.0], [0.0, 1.0]];

        let balanced_acc = metric
            .compute(&predictions.view(), &targets.view())
            .unwrap();
        assert!((balanced_acc - 1.0).abs() < 1e-6);

        let predictions = array![[0.9, 0.1], [0.9, 0.1], [0.9, 0.1], [0.1, 0.9]];
        let targets = array![[1.0, 0.0], [1.0, 0.0], [1.0, 0.0], [0.0, 1.0]];

        let balanced_acc = metric
            .compute(&predictions.view(), &targets.view())
            .unwrap();
        assert!((balanced_acc - 1.0).abs() < 1e-6);
    }
}