
tensorlogic_train/metrics/advanced.rs

//! Advanced classification metrics.

use crate::{TrainError, TrainResult};
use scirs2_core::ndarray::{ArrayView, Ix2};

use super::Metric;

/// Confusion matrix for multi-class classification.
#[derive(Debug, Clone)]
pub struct ConfusionMatrix {
    /// Number of classes.
    pub(crate) num_classes: usize,
    /// Confusion matrix (rows=true labels, cols=predicted labels).
    pub(crate) matrix: Vec<Vec<usize>>,
}

impl ConfusionMatrix {
    /// Create a new confusion matrix.
    ///
    /// # Arguments
    /// * `num_classes` - Number of classes
    pub fn new(num_classes: usize) -> Self {
        Self {
            num_classes,
            matrix: vec![vec![0; num_classes]; num_classes],
        }
    }

    /// Compute confusion matrix from predictions and targets.
    ///
    /// # Arguments
    /// * `predictions` - Model predictions (one-hot or class probabilities)
    /// * `targets` - True labels (one-hot encoded)
    ///
    /// # Returns
    /// Confusion matrix
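    ///
    /// # Examples
    ///
    /// A minimal sketch of the intended call pattern, mirroring the unit
    /// tests at the bottom of this file. Marked `ignore` because the import
    /// path is an assumption about how the crate re-exports this type.
    ///
    /// ```ignore
    /// use scirs2_core::ndarray::array;
    /// use tensorlogic_train::metrics::advanced::ConfusionMatrix; // assumed path
    ///
    /// let predictions = array![[0.9, 0.1], [0.2, 0.8]];
    /// let targets = array![[1.0, 0.0], [0.0, 1.0]];
    /// let cm = ConfusionMatrix::compute(&predictions.view(), &targets.view()).unwrap();
    /// assert_eq!(cm.get(0, 0), 1);
    /// assert_eq!(cm.accuracy(), 1.0);
    /// ```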
    pub fn compute(
        predictions: &ArrayView<f64, Ix2>,
        targets: &ArrayView<f64, Ix2>,
    ) -> TrainResult<Self> {
        if predictions.shape() != targets.shape() {
            return Err(TrainError::MetricsError(format!(
                "Shape mismatch: predictions {:?} vs targets {:?}",
                predictions.shape(),
                targets.shape()
            )));
        }

        let num_classes = predictions.ncols();
        let mut matrix = vec![vec![0; num_classes]; num_classes];

        for i in 0..predictions.nrows() {
            // Find predicted class (argmax)
            let mut pred_class = 0;
            let mut max_pred = predictions[[i, 0]];
            for j in 1..num_classes {
                if predictions[[i, j]] > max_pred {
                    max_pred = predictions[[i, j]];
                    pred_class = j;
                }
            }

            // Find true class (argmax)
            let mut true_class = 0;
            let mut max_true = targets[[i, 0]];
            for j in 1..num_classes {
                if targets[[i, j]] > max_true {
                    max_true = targets[[i, j]];
                    true_class = j;
                }
            }

            matrix[true_class][pred_class] += 1;
        }

        Ok(Self {
            num_classes,
            matrix,
        })
    }

    /// Get the confusion matrix.
    pub fn matrix(&self) -> &Vec<Vec<usize>> {
        &self.matrix
    }

    /// Get value at (true_class, pred_class).
    pub fn get(&self, true_class: usize, pred_class: usize) -> usize {
        self.matrix[true_class][pred_class]
    }

    /// Compute per-class precision (TP / predicted positives, per matrix column).
    pub fn precision_per_class(&self) -> Vec<f64> {
        let mut precisions = Vec::with_capacity(self.num_classes);

        for pred_class in 0..self.num_classes {
            let mut predicted_positive = 0;
            let mut true_positive = 0;

            for true_class in 0..self.num_classes {
                predicted_positive += self.matrix[true_class][pred_class];
                if true_class == pred_class {
                    true_positive += self.matrix[true_class][pred_class];
                }
            }

            let precision = if predicted_positive == 0 {
                0.0
            } else {
                true_positive as f64 / predicted_positive as f64
            };
            precisions.push(precision);
        }

        precisions
    }

    /// Compute per-class recall (TP / actual positives, per matrix row).
    pub fn recall_per_class(&self) -> Vec<f64> {
        let mut recalls = Vec::with_capacity(self.num_classes);

        for true_class in 0..self.num_classes {
            let mut actual_positive = 0;
            let mut true_positive = 0;

            for pred_class in 0..self.num_classes {
                actual_positive += self.matrix[true_class][pred_class];
                if true_class == pred_class {
                    true_positive += self.matrix[true_class][pred_class];
                }
            }

            let recall = if actual_positive == 0 {
                0.0
            } else {
                true_positive as f64 / actual_positive as f64
            };
            recalls.push(recall);
        }

        recalls
    }

    /// Compute per-class F1 scores (harmonic mean: 2 * P * R / (P + R)).
    pub fn f1_per_class(&self) -> Vec<f64> {
        let precisions = self.precision_per_class();
        let recalls = self.recall_per_class();

        precisions
            .iter()
            .zip(recalls.iter())
            .map(|(p, r)| {
                if p + r == 0.0 {
                    0.0
                } else {
                    2.0 * p * r / (p + r)
                }
            })
            .collect()
    }

    /// Compute overall accuracy.
    pub fn accuracy(&self) -> f64 {
        let mut correct = 0;
        let mut total = 0;

        for i in 0..self.num_classes {
            for j in 0..self.num_classes {
                total += self.matrix[i][j];
                if i == j {
                    correct += self.matrix[i][j];
                }
            }
        }

        if total == 0 {
            0.0
        } else {
            correct as f64 / total as f64
        }
    }

    /// Get total number of predictions.
    pub fn total_predictions(&self) -> usize {
        let mut total = 0;
        for i in 0..self.num_classes {
            for j in 0..self.num_classes {
                total += self.matrix[i][j];
            }
        }
        total
    }
}

impl std::fmt::Display for ConfusionMatrix {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        writeln!(f, "Confusion Matrix:")?;
        write!(f, "     ")?;

        for j in 0..self.num_classes {
            write!(f, "{:5}", j)?;
        }
        writeln!(f)?;

        for i in 0..self.num_classes {
            write!(f, "{:3}| ", i)?;
            for j in 0..self.num_classes {
                write!(f, "{:5}", self.matrix[i][j])?;
            }
            writeln!(f)?;
        }

        Ok(())
    }
}

/// ROC curve and AUC computation utilities.
#[derive(Debug, Clone)]
pub struct RocCurve {
    /// False positive rates.
    pub fpr: Vec<f64>,
    /// True positive rates.
    pub tpr: Vec<f64>,
    /// Thresholds.
    pub thresholds: Vec<f64>,
}

impl RocCurve {
    /// Compute ROC curve for binary classification.
    ///
    /// # Arguments
    /// * `predictions` - Predicted probabilities for positive class
    /// * `targets` - True binary labels (0 or 1)
    ///
    /// # Returns
    /// ROC curve with FPR, TPR, and thresholds
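    ///
    /// # Examples
    ///
    /// A minimal sketch, mirroring `test_roc_auc_perfect` below; marked
    /// `ignore` since the import path is an assumption.
    ///
    /// ```ignore
    /// use tensorlogic_train::metrics::advanced::RocCurve; // assumed path
    ///
    /// let predictions = vec![0.9, 0.8, 0.3, 0.1];
    /// let targets = vec![true, true, false, false];
    /// let roc = RocCurve::compute(&predictions, &targets).unwrap();
    /// assert!((roc.auc() - 1.0).abs() < 1e-6); // perfectly separated scores
    /// ```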
    pub fn compute(predictions: &[f64], targets: &[bool]) -> TrainResult<Self> {
        if predictions.len() != targets.len() {
            return Err(TrainError::MetricsError(format!(
                "Length mismatch: predictions {} vs targets {}",
                predictions.len(),
                targets.len()
            )));
        }

        // Create sorted indices by prediction score (descending)
        let mut indices: Vec<usize> = (0..predictions.len()).collect();
        indices.sort_by(|&a, &b| {
            predictions[b]
                .partial_cmp(&predictions[a])
                .unwrap_or(std::cmp::Ordering::Equal)
        });
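
        // Note: tied scores are emitted as separate curve points rather than
        // merged into a single threshold step, so heavily tied score lists
        // trace a staircase where a diagonal segment is the usual convention.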

        let mut fpr = Vec::new();
        let mut tpr = Vec::new();
        let mut thresholds = Vec::new();

        let num_positive = targets.iter().filter(|&&x| x).count();
        let num_negative = targets.len() - num_positive;

        let mut true_positives = 0;
        let mut false_positives = 0;

        // Start with all predictions as negative
        fpr.push(0.0);
        tpr.push(0.0);
        thresholds.push(f64::INFINITY);

        for &idx in &indices {
            if targets[idx] {
                true_positives += 1;
            } else {
                false_positives += 1;
            }

            let fpr_val = if num_negative == 0 {
                0.0
            } else {
                false_positives as f64 / num_negative as f64
            };
            let tpr_val = if num_positive == 0 {
                0.0
            } else {
                true_positives as f64 / num_positive as f64
            };

            fpr.push(fpr_val);
            tpr.push(tpr_val);
            thresholds.push(predictions[idx]);
        }

        Ok(Self {
            fpr,
            tpr,
            thresholds,
        })
    }

    /// Compute area under the ROC curve (AUC) using trapezoidal rule.
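    ///
    /// auc = sum over i of (fpr[i] - fpr[i - 1]) * (tpr[i] + tpr[i - 1]) / 2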
    pub fn auc(&self) -> f64 {
        let mut auc = 0.0;

        for i in 1..self.fpr.len() {
            let width = self.fpr[i] - self.fpr[i - 1];
            let height = (self.tpr[i] + self.tpr[i - 1]) / 2.0;
            auc += width * height;
        }

        auc
    }
}

/// Per-class metrics report.
#[derive(Debug, Clone)]
pub struct PerClassMetrics {
    /// Precision per class.
    pub precision: Vec<f64>,
    /// Recall per class.
    pub recall: Vec<f64>,
    /// F1 score per class.
    pub f1_score: Vec<f64>,
    /// Support (number of samples) per class.
    pub support: Vec<usize>,
}

impl PerClassMetrics {
    /// Compute per-class metrics from predictions and targets.
    ///
    /// # Arguments
    /// * `predictions` - Model predictions (one-hot or class probabilities)
    /// * `targets` - True labels (one-hot encoded)
    ///
    /// # Returns
    /// Per-class metrics report
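    ///
    /// # Examples
    ///
    /// A minimal sketch, mirroring the unit tests below; marked `ignore`
    /// since the import path is an assumption.
    ///
    /// ```ignore
    /// use scirs2_core::ndarray::array;
    /// use tensorlogic_train::metrics::advanced::PerClassMetrics; // assumed path
    ///
    /// let predictions = array![[0.9, 0.1], [0.2, 0.8]];
    /// let targets = array![[1.0, 0.0], [0.0, 1.0]];
    /// let report = PerClassMetrics::compute(&predictions.view(), &targets.view()).unwrap();
    /// assert_eq!(report.support, vec![1, 1]);
    /// ```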
    pub fn compute(
        predictions: &ArrayView<f64, Ix2>,
        targets: &ArrayView<f64, Ix2>,
    ) -> TrainResult<Self> {
        let confusion_matrix = ConfusionMatrix::compute(predictions, targets)?;

        let precision = confusion_matrix.precision_per_class();
        let recall = confusion_matrix.recall_per_class();
        let f1_score = confusion_matrix.f1_per_class();

        // Compute support (number of samples per class)
        let num_classes = targets.ncols();
        let mut support = vec![0; num_classes];

        for i in 0..targets.nrows() {
            // Find true class
            let mut true_class = 0;
            let mut max_true = targets[[i, 0]];
            for j in 1..num_classes {
                if targets[[i, j]] > max_true {
                    max_true = targets[[i, j]];
                    true_class = j;
                }
            }
            support[true_class] += 1;
        }

        Ok(Self {
            precision,
            recall,
            f1_score,
            support,
        })
    }
}

impl std::fmt::Display for PerClassMetrics {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        writeln!(f, "Per-Class Metrics:")?;
        writeln!(f, "Class  Precision  Recall  F1-Score  Support")?;
        writeln!(f, "-----  ---------  ------  --------  -------")?;

        for i in 0..self.precision.len() {
            writeln!(
                f,
                "{:5}  {:9.4}  {:6.4}  {:8.4}  {:7}",
                i, self.precision[i], self.recall[i], self.f1_score[i], self.support[i]
            )?;
        }

        // Compute macro averages
        let macro_precision: f64 = self.precision.iter().sum::<f64>() / self.precision.len() as f64;
        let macro_recall: f64 = self.recall.iter().sum::<f64>() / self.recall.len() as f64;
        let macro_f1: f64 = self.f1_score.iter().sum::<f64>() / self.f1_score.len() as f64;
        let total_support: usize = self.support.iter().sum();

        writeln!(f, "-----  ---------  ------  --------  -------")?;
        writeln!(
            f,
            "Macro  {:9.4}  {:6.4}  {:8.4}  {:7}",
            macro_precision, macro_recall, macro_f1, total_support
        )?;

        Ok(())
    }
}

/// Matthews Correlation Coefficient (MCC) metric.
/// Ranges from -1 to +1, where +1 is perfect prediction, 0 is random, -1 is total disagreement.
/// Particularly useful for imbalanced datasets.
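///
/// For the binary case this reduces to
/// `(TP * TN - FP * FN) / sqrt((TP + FP) * (TP + FN) * (TN + FP) * (TN + FN))`;
/// for more than two classes the Gorodkin generalization is used.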
#[derive(Debug, Clone, Default)]
pub struct MatthewsCorrelationCoefficient;

impl Metric for MatthewsCorrelationCoefficient {
    fn compute(
        &self,
        predictions: &ArrayView<f64, Ix2>,
        targets: &ArrayView<f64, Ix2>,
    ) -> TrainResult<f64> {
        let confusion_matrix = ConfusionMatrix::compute(predictions, targets)?;
        let num_classes = confusion_matrix.num_classes;

        // For binary classification, use the standard MCC formula
        if num_classes == 2 {
            let tp = confusion_matrix.matrix[1][1] as f64;
            let tn = confusion_matrix.matrix[0][0] as f64;
            let fp = confusion_matrix.matrix[0][1] as f64;
            let fn_val = confusion_matrix.matrix[1][0] as f64;

            let numerator = (tp * tn) - (fp * fn_val);
            let denominator = ((tp + fp) * (tp + fn_val) * (tn + fp) * (tn + fn_val)).sqrt();

            if denominator == 0.0 {
                Ok(0.0)
            } else {
                Ok(numerator / denominator)
            }
        } else {
            // Multi-class MCC formula
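            // (Gorodkin's R_K statistic:
            //  mcc = (t * c - sum_k p_k * t_k)
            //        / (sqrt(t^2 - sum_k p_k^2) * sqrt(t^2 - sum_k t_k^2)),
            //  with t = total samples, c = correct predictions,
            //  p_k = count predicted as class k, t_k = count truly in class k.)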
            let mut s = 0.0;
            let mut c = 0.0;
            let t = confusion_matrix.total_predictions() as f64;

            // Compute column sums (predicted counts) and row sums (true counts)
            let mut p_k = vec![0.0; num_classes];
            let mut t_k = vec![0.0; num_classes];

            for k in 0..num_classes {
                for l in 0..num_classes {
                    p_k[k] += confusion_matrix.matrix[l][k] as f64;
                    t_k[k] += confusion_matrix.matrix[k][l] as f64;
                }
            }

            // Compute trace (correct predictions)
            for k in 0..num_classes {
                c += confusion_matrix.matrix[k][k] as f64;
            }

            // Compute sum of products
            for k in 0..num_classes {
                s += p_k[k] * t_k[k];
            }

            let numerator = (t * c) - s;
            let mut sum_p_sq = 0.0;
            let mut sum_t_sq = 0.0;
            for k in 0..num_classes {
                sum_p_sq += p_k[k] * p_k[k];
                sum_t_sq += t_k[k] * t_k[k];
            }
            // Gorodkin's denominator has exactly two square-root factors.
            let denominator = ((t * t) - sum_p_sq).sqrt() * ((t * t) - sum_t_sq).sqrt();

            if denominator == 0.0 {
                Ok(0.0)
            } else {
                Ok(numerator / denominator)
            }
        }
    }

    fn name(&self) -> &str {
        "mcc"
    }
}

/// Cohen's Kappa statistic.
/// Measures inter-rater agreement, accounting for chance agreement.
/// Ranges from -1 to +1, where 1 is perfect agreement, 0 is random chance.
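///
/// kappa = (p_o - p_e) / (1 - p_e), where p_o is the observed agreement
/// (accuracy) and p_e the agreement expected from the marginal frequencies.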
#[derive(Debug, Clone, Default)]
pub struct CohensKappa;

impl Metric for CohensKappa {
    fn compute(
        &self,
        predictions: &ArrayView<f64, Ix2>,
        targets: &ArrayView<f64, Ix2>,
    ) -> TrainResult<f64> {
        let confusion_matrix = ConfusionMatrix::compute(predictions, targets)?;
        let num_classes = confusion_matrix.num_classes;
        let total = confusion_matrix.total_predictions() as f64;
        if total == 0.0 {
            // No samples: agreement is undefined, so report 0.0 rather than NaN.
            return Ok(0.0);
        }

        // Observed agreement (accuracy)
        let mut observed = 0.0;
        for i in 0..num_classes {
            observed += confusion_matrix.matrix[i][i] as f64;
        }
        observed /= total;

        // Expected agreement by chance
        let mut expected = 0.0;
        for i in 0..num_classes {
            let row_sum: f64 = (0..num_classes)
                .map(|j| confusion_matrix.matrix[i][j] as f64)
                .sum();
            let col_sum: f64 = (0..num_classes)
                .map(|j| confusion_matrix.matrix[j][i] as f64)
                .sum();
            expected += (row_sum / total) * (col_sum / total);
        }

        if expected >= 1.0 {
            Ok(0.0)
        } else {
            Ok((observed - expected) / (1.0 - expected))
        }
    }

    fn name(&self) -> &str {
        "cohens_kappa"
    }
}

/// Balanced accuracy metric.
/// Average of recall per class, useful for imbalanced datasets.
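///
/// balanced_accuracy = (1 / K) * sum_k recall_k, for K classes.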
#[derive(Debug, Clone, Default)]
pub struct BalancedAccuracy;

impl Metric for BalancedAccuracy {
    fn compute(
        &self,
        predictions: &ArrayView<f64, Ix2>,
        targets: &ArrayView<f64, Ix2>,
    ) -> TrainResult<f64> {
        let confusion_matrix = ConfusionMatrix::compute(predictions, targets)?;
        let recalls = confusion_matrix.recall_per_class();

        // Balanced accuracy is the average of recall per class
        let sum: f64 = recalls.iter().sum();
        Ok(sum / recalls.len() as f64)
    }

    fn name(&self) -> &str {
        "balanced_accuracy"
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    #[test]
    fn test_confusion_matrix() {
        let predictions = array![
            [0.9, 0.1, 0.0],
            [0.1, 0.8, 0.1],
            [0.2, 0.1, 0.7],
            [0.8, 0.1, 0.1]
        ];
        let targets = array![
            [1.0, 0.0, 0.0],
            [0.0, 1.0, 0.0],
            [0.0, 0.0, 1.0],
            [1.0, 0.0, 0.0]
        ];

        let cm = ConfusionMatrix::compute(&predictions.view(), &targets.view()).unwrap();

        assert_eq!(cm.get(0, 0), 2); // Class 0 correctly predicted
        assert_eq!(cm.get(1, 1), 1); // Class 1 correctly predicted
        assert_eq!(cm.get(2, 2), 1); // Class 2 correctly predicted
        assert_eq!(cm.accuracy(), 1.0);
    }

    #[test]
    fn test_confusion_matrix_per_class_metrics() {
        let predictions = array![[0.9, 0.1], [0.2, 0.8], [0.7, 0.3], [0.1, 0.9]];
        let targets = array![[1.0, 0.0], [0.0, 1.0], [1.0, 0.0], [0.0, 1.0]];

        let cm = ConfusionMatrix::compute(&predictions.view(), &targets.view()).unwrap();

        let precision = cm.precision_per_class();
        let recall = cm.recall_per_class();
        let f1 = cm.f1_per_class();

        assert_eq!(precision.len(), 2);
        assert_eq!(recall.len(), 2);
        assert_eq!(f1.len(), 2);

        // All predictions correct
        assert_eq!(precision[0], 1.0);
        assert_eq!(precision[1], 1.0);
        assert_eq!(recall[0], 1.0);
        assert_eq!(recall[1], 1.0);
    }

    #[test]
    fn test_roc_curve() {
        let predictions = vec![0.9, 0.8, 0.4, 0.3, 0.1];
        let targets = vec![true, true, false, true, false];

        let roc = RocCurve::compute(&predictions, &targets).unwrap();

        assert!(!roc.fpr.is_empty());
        assert!(!roc.tpr.is_empty());
        assert!(!roc.thresholds.is_empty());
        assert_eq!(roc.fpr.len(), roc.tpr.len());

        let auc = roc.auc();
        assert!((0.0..=1.0).contains(&auc));
    }

    #[test]
    fn test_roc_auc_perfect() {
        let predictions = vec![0.9, 0.8, 0.3, 0.1];
        let targets = vec![true, true, false, false];

        let roc = RocCurve::compute(&predictions, &targets).unwrap();
        let auc = roc.auc();

        // Perfect classification should have AUC = 1.0
        assert!((auc - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_per_class_metrics() {
        let predictions = array![
            [0.9, 0.1, 0.0],
            [0.1, 0.8, 0.1],
            [0.2, 0.1, 0.7],
            [0.8, 0.1, 0.1]
        ];
        let targets = array![
            [1.0, 0.0, 0.0],
            [0.0, 1.0, 0.0],
            [0.0, 0.0, 1.0],
            [1.0, 0.0, 0.0]
        ];

        let metrics = PerClassMetrics::compute(&predictions.view(), &targets.view()).unwrap();

        assert_eq!(metrics.precision.len(), 3);
        assert_eq!(metrics.recall.len(), 3);
        assert_eq!(metrics.f1_score.len(), 3);
        assert_eq!(metrics.support.len(), 3);

        // Check support counts
        assert_eq!(metrics.support[0], 2);
        assert_eq!(metrics.support[1], 1);
        assert_eq!(metrics.support[2], 1);
    }

    #[test]
    fn test_matthews_correlation_coefficient() {
        let metric = MatthewsCorrelationCoefficient;

        // Perfect binary classification
        let predictions = array![[0.9, 0.1], [0.1, 0.9], [0.9, 0.1], [0.1, 0.9]];
        let targets = array![[1.0, 0.0], [0.0, 1.0], [1.0, 0.0], [0.0, 1.0]];

        let mcc = metric
            .compute(&predictions.view(), &targets.view())
            .unwrap();
        assert!((mcc - 1.0).abs() < 1e-6);

        // Random classification
        let predictions = array![[0.5, 0.5], [0.5, 0.5], [0.5, 0.5], [0.5, 0.5]];
        let targets = array![[1.0, 0.0], [0.0, 1.0], [1.0, 0.0], [0.0, 1.0]];

        let mcc = metric
            .compute(&predictions.view(), &targets.view())
            .unwrap();
        assert!(mcc.abs() < 0.1);
    }

    #[test]
    fn test_cohens_kappa() {
        let metric = CohensKappa;

        // Perfect agreement
        let predictions = array![[0.9, 0.1], [0.1, 0.9], [0.9, 0.1], [0.1, 0.9]];
        let targets = array![[1.0, 0.0], [0.0, 1.0], [1.0, 0.0], [0.0, 1.0]];

        let kappa = metric
            .compute(&predictions.view(), &targets.view())
            .unwrap();
        assert!((kappa - 1.0).abs() < 1e-6);

        // Random agreement
        let predictions = array![[0.9, 0.1], [0.9, 0.1], [0.9, 0.1], [0.9, 0.1]];
        let targets = array![[1.0, 0.0], [0.0, 1.0], [1.0, 0.0], [0.0, 1.0]];

        let kappa = metric
            .compute(&predictions.view(), &targets.view())
            .unwrap();
        assert!((-1.0..=1.0).contains(&kappa));
    }

    #[test]
    fn test_balanced_accuracy() {
        let metric = BalancedAccuracy;

        // Perfect classification
        let predictions = array![[0.9, 0.1], [0.1, 0.9], [0.9, 0.1], [0.1, 0.9]];
        let targets = array![[1.0, 0.0], [0.0, 1.0], [1.0, 0.0], [0.0, 1.0]];

        let balanced_acc = metric
            .compute(&predictions.view(), &targets.view())
            .unwrap();
        assert!((balanced_acc - 1.0).abs() < 1e-6);

        // Imbalanced but perfect
        let predictions = array![[0.9, 0.1], [0.9, 0.1], [0.9, 0.1], [0.1, 0.9]];
        let targets = array![[1.0, 0.0], [1.0, 0.0], [1.0, 0.0], [0.0, 1.0]];

        let balanced_acc = metric
            .compute(&predictions.view(), &targets.view())
            .unwrap();
        assert!((balanced_acc - 1.0).abs() < 1e-6);
    }
}