tensorlogic_train/
metrics.rs

1//! Metrics for evaluating model performance.
2
3use crate::{TrainError, TrainResult};
4use scirs2_core::ndarray::{ArrayView, Ix2};
5use std::collections::HashMap;
6
7/// Trait for metrics.
8pub trait Metric {
9    /// Compute metric value.
10    fn compute(
11        &self,
12        predictions: &ArrayView<f64, Ix2>,
13        targets: &ArrayView<f64, Ix2>,
14    ) -> TrainResult<f64>;
15
16    /// Get metric name.
17    fn name(&self) -> &str;
18
19    /// Reset metric state (for stateful metrics).
20    fn reset(&mut self) {}
21}
22
/// Classification accuracy: the fraction of samples whose predicted class
/// equals the target class.
#[derive(Debug, Clone)]
pub struct Accuracy {
    /// Decision threshold for binary classification (defaults to `0.5`).
    pub threshold: f64,
}

impl Default for Accuracy {
    /// Builds an `Accuracy` metric with a `0.5` threshold.
    fn default() -> Self {
        Accuracy { threshold: 0.5 }
    }
}
35
36impl Metric for Accuracy {
37    fn compute(
38        &self,
39        predictions: &ArrayView<f64, Ix2>,
40        targets: &ArrayView<f64, Ix2>,
41    ) -> TrainResult<f64> {
42        if predictions.shape() != targets.shape() {
43            return Err(TrainError::MetricsError(format!(
44                "Shape mismatch: predictions {:?} vs targets {:?}",
45                predictions.shape(),
46                targets.shape()
47            )));
48        }
49
50        let mut correct = 0;
51        let total = predictions.nrows();
52
53        for i in 0..total {
54            // Find predicted class (argmax)
55            let mut pred_class = 0;
56            let mut max_pred = predictions[[i, 0]];
57            for j in 1..predictions.ncols() {
58                if predictions[[i, j]] > max_pred {
59                    max_pred = predictions[[i, j]];
60                    pred_class = j;
61                }
62            }
63
64            // Find true class (argmax)
65            let mut true_class = 0;
66            let mut max_true = targets[[i, 0]];
67            for j in 1..targets.ncols() {
68                if targets[[i, j]] > max_true {
69                    max_true = targets[[i, j]];
70                    true_class = j;
71                }
72            }
73
74            if pred_class == true_class {
75                correct += 1;
76            }
77        }
78
79        Ok(correct as f64 / total as f64)
80    }
81
82    fn name(&self) -> &str {
83        "accuracy"
84    }
85}
86
/// Precision (positive predictive value) for multi-class classification:
/// the share of predictions of a class that are correct.
#[derive(Clone, Debug, Default)]
pub struct Precision {
    /// Class whose precision is reported; `None` (the default) selects the macro average.
    pub class_id: Option<usize>,
}
93
94impl Metric for Precision {
95    fn compute(
96        &self,
97        predictions: &ArrayView<f64, Ix2>,
98        targets: &ArrayView<f64, Ix2>,
99    ) -> TrainResult<f64> {
100        if predictions.shape() != targets.shape() {
101            return Err(TrainError::MetricsError(format!(
102                "Shape mismatch: predictions {:?} vs targets {:?}",
103                predictions.shape(),
104                targets.shape()
105            )));
106        }
107
108        let num_classes = predictions.ncols();
109        let mut true_positives = vec![0; num_classes];
110        let mut predicted_positives = vec![0; num_classes];
111
112        for i in 0..predictions.nrows() {
113            // Find predicted class
114            let mut pred_class = 0;
115            let mut max_pred = predictions[[i, 0]];
116            for j in 1..num_classes {
117                if predictions[[i, j]] > max_pred {
118                    max_pred = predictions[[i, j]];
119                    pred_class = j;
120                }
121            }
122
123            // Find true class
124            let mut true_class = 0;
125            let mut max_true = targets[[i, 0]];
126            for j in 1..num_classes {
127                if targets[[i, j]] > max_true {
128                    max_true = targets[[i, j]];
129                    true_class = j;
130                }
131            }
132
133            predicted_positives[pred_class] += 1;
134            if pred_class == true_class {
135                true_positives[pred_class] += 1;
136            }
137        }
138
139        if let Some(class_id) = self.class_id {
140            // Precision for specific class
141            if predicted_positives[class_id] == 0 {
142                Ok(0.0)
143            } else {
144                Ok(true_positives[class_id] as f64 / predicted_positives[class_id] as f64)
145            }
146        } else {
147            // Macro-averaged precision
148            let mut total_precision = 0.0;
149            let mut valid_classes = 0;
150
151            for class_id in 0..num_classes {
152                if predicted_positives[class_id] > 0 {
153                    total_precision +=
154                        true_positives[class_id] as f64 / predicted_positives[class_id] as f64;
155                    valid_classes += 1;
156                }
157            }
158
159            if valid_classes == 0 {
160                Ok(0.0)
161            } else {
162                Ok(total_precision / valid_classes as f64)
163            }
164        }
165    }
166
167    fn name(&self) -> &str {
168        "precision"
169    }
170}
171
/// Recall (sensitivity) for multi-class classification: the share of samples
/// of a class that are predicted as that class.
#[derive(Clone, Debug, Default)]
pub struct Recall {
    /// Class whose recall is reported; `None` (the default) selects the macro average.
    pub class_id: Option<usize>,
}
178
179impl Metric for Recall {
180    fn compute(
181        &self,
182        predictions: &ArrayView<f64, Ix2>,
183        targets: &ArrayView<f64, Ix2>,
184    ) -> TrainResult<f64> {
185        if predictions.shape() != targets.shape() {
186            return Err(TrainError::MetricsError(format!(
187                "Shape mismatch: predictions {:?} vs targets {:?}",
188                predictions.shape(),
189                targets.shape()
190            )));
191        }
192
193        let num_classes = predictions.ncols();
194        let mut true_positives = vec![0; num_classes];
195        let mut actual_positives = vec![0; num_classes];
196
197        for i in 0..predictions.nrows() {
198            // Find predicted class
199            let mut pred_class = 0;
200            let mut max_pred = predictions[[i, 0]];
201            for j in 1..num_classes {
202                if predictions[[i, j]] > max_pred {
203                    max_pred = predictions[[i, j]];
204                    pred_class = j;
205                }
206            }
207
208            // Find true class
209            let mut true_class = 0;
210            let mut max_true = targets[[i, 0]];
211            for j in 1..num_classes {
212                if targets[[i, j]] > max_true {
213                    max_true = targets[[i, j]];
214                    true_class = j;
215                }
216            }
217
218            actual_positives[true_class] += 1;
219            if pred_class == true_class {
220                true_positives[pred_class] += 1;
221            }
222        }
223
224        if let Some(class_id) = self.class_id {
225            // Recall for specific class
226            if actual_positives[class_id] == 0 {
227                Ok(0.0)
228            } else {
229                Ok(true_positives[class_id] as f64 / actual_positives[class_id] as f64)
230            }
231        } else {
232            // Macro-averaged recall
233            let mut total_recall = 0.0;
234            let mut valid_classes = 0;
235
236            for class_id in 0..num_classes {
237                if actual_positives[class_id] > 0 {
238                    total_recall +=
239                        true_positives[class_id] as f64 / actual_positives[class_id] as f64;
240                    valid_classes += 1;
241                }
242            }
243
244            if valid_classes == 0 {
245                Ok(0.0)
246            } else {
247                Ok(total_recall / valid_classes as f64)
248            }
249        }
250    }
251
252    fn name(&self) -> &str {
253        "recall"
254    }
255}
256
/// F1 score: the harmonic mean of precision and recall.
#[derive(Clone, Debug, Default)]
pub struct F1Score {
    /// Class whose F1 is reported; `None` (the default) selects the macro average.
    pub class_id: Option<usize>,
}
263
264impl Metric for F1Score {
265    fn compute(
266        &self,
267        predictions: &ArrayView<f64, Ix2>,
268        targets: &ArrayView<f64, Ix2>,
269    ) -> TrainResult<f64> {
270        let precision = Precision {
271            class_id: self.class_id,
272        }
273        .compute(predictions, targets)?;
274        let recall = Recall {
275            class_id: self.class_id,
276        }
277        .compute(predictions, targets)?;
278
279        if precision + recall == 0.0 {
280            Ok(0.0)
281        } else {
282            Ok(2.0 * precision * recall / (precision + recall))
283        }
284    }
285
286    fn name(&self) -> &str {
287        "f1_score"
288    }
289}
290
291/// Metric tracker for managing multiple metrics.
292pub struct MetricTracker {
293    /// Metrics to track.
294    metrics: Vec<Box<dyn Metric>>,
295    /// History of metric values.
296    history: HashMap<String, Vec<f64>>,
297}
298
299impl MetricTracker {
300    /// Create a new metric tracker.
301    pub fn new() -> Self {
302        Self {
303            metrics: Vec::new(),
304            history: HashMap::new(),
305        }
306    }
307
308    /// Add a metric to track.
309    pub fn add(&mut self, metric: Box<dyn Metric>) {
310        let name = metric.name().to_string();
311        self.history.insert(name, Vec::new());
312        self.metrics.push(metric);
313    }
314
315    /// Compute all metrics.
316    pub fn compute_all(
317        &mut self,
318        predictions: &ArrayView<f64, Ix2>,
319        targets: &ArrayView<f64, Ix2>,
320    ) -> TrainResult<HashMap<String, f64>> {
321        let mut results = HashMap::new();
322
323        for metric in &self.metrics {
324            let value = metric.compute(predictions, targets)?;
325            let name = metric.name().to_string();
326
327            results.insert(name.clone(), value);
328
329            if let Some(history) = self.history.get_mut(&name) {
330                history.push(value);
331            }
332        }
333
334        Ok(results)
335    }
336
337    /// Get history for a specific metric.
338    pub fn get_history(&self, metric_name: &str) -> Option<&Vec<f64>> {
339        self.history.get(metric_name)
340    }
341
342    /// Reset all metrics.
343    pub fn reset(&mut self) {
344        for metric in &mut self.metrics {
345            metric.reset();
346        }
347    }
348
349    /// Clear history.
350    pub fn clear_history(&mut self) {
351        for history in self.history.values_mut() {
352            history.clear();
353        }
354    }
355}
356
357impl Default for MetricTracker {
358    fn default() -> Self {
359        Self::new()
360    }
361}
362
/// Multi-class confusion matrix.
///
/// Entry `[t][p]` counts the samples whose true class is `t` and whose
/// predicted class is `p`.
#[derive(Clone, Debug)]
pub struct ConfusionMatrix {
    /// Class count; the matrix is `num_classes × num_classes`.
    num_classes: usize,
    /// Counts indexed as `[true_label][predicted_label]`.
    matrix: Vec<Vec<usize>>,
}
371
372impl ConfusionMatrix {
373    /// Create a new confusion matrix.
374    ///
375    /// # Arguments
376    /// * `num_classes` - Number of classes
377    pub fn new(num_classes: usize) -> Self {
378        Self {
379            num_classes,
380            matrix: vec![vec![0; num_classes]; num_classes],
381        }
382    }
383
384    /// Compute confusion matrix from predictions and targets.
385    ///
386    /// # Arguments
387    /// * `predictions` - Model predictions (one-hot or class probabilities)
388    /// * `targets` - True labels (one-hot encoded)
389    ///
390    /// # Returns
391    /// Confusion matrix
392    pub fn compute(
393        predictions: &ArrayView<f64, Ix2>,
394        targets: &ArrayView<f64, Ix2>,
395    ) -> TrainResult<Self> {
396        if predictions.shape() != targets.shape() {
397            return Err(TrainError::MetricsError(format!(
398                "Shape mismatch: predictions {:?} vs targets {:?}",
399                predictions.shape(),
400                targets.shape()
401            )));
402        }
403
404        let num_classes = predictions.ncols();
405        let mut matrix = vec![vec![0; num_classes]; num_classes];
406
407        for i in 0..predictions.nrows() {
408            // Find predicted class (argmax)
409            let mut pred_class = 0;
410            let mut max_pred = predictions[[i, 0]];
411            for j in 1..num_classes {
412                if predictions[[i, j]] > max_pred {
413                    max_pred = predictions[[i, j]];
414                    pred_class = j;
415                }
416            }
417
418            // Find true class (argmax)
419            let mut true_class = 0;
420            let mut max_true = targets[[i, 0]];
421            for j in 1..num_classes {
422                if targets[[i, j]] > max_true {
423                    max_true = targets[[i, j]];
424                    true_class = j;
425                }
426            }
427
428            matrix[true_class][pred_class] += 1;
429        }
430
431        Ok(Self {
432            num_classes,
433            matrix,
434        })
435    }
436
437    /// Get the confusion matrix.
438    pub fn matrix(&self) -> &Vec<Vec<usize>> {
439        &self.matrix
440    }
441
442    /// Get value at (true_class, pred_class).
443    pub fn get(&self, true_class: usize, pred_class: usize) -> usize {
444        self.matrix[true_class][pred_class]
445    }
446
447    /// Compute per-class precision.
448    pub fn precision_per_class(&self) -> Vec<f64> {
449        let mut precisions = Vec::with_capacity(self.num_classes);
450
451        for pred_class in 0..self.num_classes {
452            let mut predicted_positive = 0;
453            let mut true_positive = 0;
454
455            for true_class in 0..self.num_classes {
456                predicted_positive += self.matrix[true_class][pred_class];
457                if true_class == pred_class {
458                    true_positive += self.matrix[true_class][pred_class];
459                }
460            }
461
462            let precision = if predicted_positive == 0 {
463                0.0
464            } else {
465                true_positive as f64 / predicted_positive as f64
466            };
467            precisions.push(precision);
468        }
469
470        precisions
471    }
472
473    /// Compute per-class recall.
474    pub fn recall_per_class(&self) -> Vec<f64> {
475        let mut recalls = Vec::with_capacity(self.num_classes);
476
477        for true_class in 0..self.num_classes {
478            let mut actual_positive = 0;
479            let mut true_positive = 0;
480
481            for pred_class in 0..self.num_classes {
482                actual_positive += self.matrix[true_class][pred_class];
483                if true_class == pred_class {
484                    true_positive += self.matrix[true_class][pred_class];
485                }
486            }
487
488            let recall = if actual_positive == 0 {
489                0.0
490            } else {
491                true_positive as f64 / actual_positive as f64
492            };
493            recalls.push(recall);
494        }
495
496        recalls
497    }
498
499    /// Compute per-class F1 scores.
500    pub fn f1_per_class(&self) -> Vec<f64> {
501        let precisions = self.precision_per_class();
502        let recalls = self.recall_per_class();
503
504        precisions
505            .iter()
506            .zip(recalls.iter())
507            .map(|(p, r)| {
508                if p + r == 0.0 {
509                    0.0
510                } else {
511                    2.0 * p * r / (p + r)
512                }
513            })
514            .collect()
515    }
516
517    /// Compute overall accuracy.
518    pub fn accuracy(&self) -> f64 {
519        let mut correct = 0;
520        let mut total = 0;
521
522        for i in 0..self.num_classes {
523            for j in 0..self.num_classes {
524                total += self.matrix[i][j];
525                if i == j {
526                    correct += self.matrix[i][j];
527                }
528            }
529        }
530
531        if total == 0 {
532            0.0
533        } else {
534            correct as f64 / total as f64
535        }
536    }
537
538    /// Get total number of predictions.
539    pub fn total_predictions(&self) -> usize {
540        let mut total = 0;
541        for i in 0..self.num_classes {
542            for j in 0..self.num_classes {
543                total += self.matrix[i][j];
544            }
545        }
546        total
547    }
548}
549
550impl std::fmt::Display for ConfusionMatrix {
551    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
552        writeln!(f, "Confusion Matrix:")?;
553        write!(f, "     ")?;
554
555        for j in 0..self.num_classes {
556            write!(f, "{:5}", j)?;
557        }
558        writeln!(f)?;
559
560        for i in 0..self.num_classes {
561            write!(f, "{:3}| ", i)?;
562            for j in 0..self.num_classes {
563                write!(f, "{:5}", self.matrix[i][j])?;
564            }
565            writeln!(f)?;
566        }
567
568        Ok(())
569    }
570}
571
/// Receiver operating characteristic (ROC) curve of a binary classifier.
///
/// The three vectors are parallel: point `i` is `(fpr[i], tpr[i])`, associated
/// with `thresholds[i]`.
#[derive(Clone, Debug)]
pub struct RocCurve {
    /// False-positive rate of each point.
    pub fpr: Vec<f64>,
    /// True-positive rate of each point.
    pub tpr: Vec<f64>,
    /// Score threshold associated with each point.
    pub thresholds: Vec<f64>,
}
582
583impl RocCurve {
584    /// Compute ROC curve for binary classification.
585    ///
586    /// # Arguments
587    /// * `predictions` - Predicted probabilities for positive class
588    /// * `targets` - True binary labels (0 or 1)
589    ///
590    /// # Returns
591    /// ROC curve with FPR, TPR, and thresholds
592    pub fn compute(predictions: &[f64], targets: &[bool]) -> TrainResult<Self> {
593        if predictions.len() != targets.len() {
594            return Err(TrainError::MetricsError(format!(
595                "Length mismatch: predictions {} vs targets {}",
596                predictions.len(),
597                targets.len()
598            )));
599        }
600
601        // Create sorted indices by prediction score (descending)
602        let mut indices: Vec<usize> = (0..predictions.len()).collect();
603        indices.sort_by(|&a, &b| {
604            predictions[b]
605                .partial_cmp(&predictions[a])
606                .unwrap_or(std::cmp::Ordering::Equal)
607        });
608
609        let mut fpr = Vec::new();
610        let mut tpr = Vec::new();
611        let mut thresholds = Vec::new();
612
613        let num_positive = targets.iter().filter(|&&x| x).count();
614        let num_negative = targets.len() - num_positive;
615
616        let mut true_positives = 0;
617        let mut false_positives = 0;
618
619        // Start with all predictions as negative
620        fpr.push(0.0);
621        tpr.push(0.0);
622        thresholds.push(f64::INFINITY);
623
624        for &idx in &indices {
625            if targets[idx] {
626                true_positives += 1;
627            } else {
628                false_positives += 1;
629            }
630
631            let fpr_val = if num_negative == 0 {
632                0.0
633            } else {
634                false_positives as f64 / num_negative as f64
635            };
636            let tpr_val = if num_positive == 0 {
637                0.0
638            } else {
639                true_positives as f64 / num_positive as f64
640            };
641
642            fpr.push(fpr_val);
643            tpr.push(tpr_val);
644            thresholds.push(predictions[idx]);
645        }
646
647        Ok(Self {
648            fpr,
649            tpr,
650            thresholds,
651        })
652    }
653
654    /// Compute area under the ROC curve (AUC) using trapezoidal rule.
655    pub fn auc(&self) -> f64 {
656        let mut auc = 0.0;
657
658        for i in 1..self.fpr.len() {
659            let width = self.fpr[i] - self.fpr[i - 1];
660            let height = (self.tpr[i] + self.tpr[i - 1]) / 2.0;
661            auc += width * height;
662        }
663
664        auc
665    }
666}
667
/// Per-class classification report; all vectors are indexed by class.
#[derive(Clone, Debug)]
pub struct PerClassMetrics {
    /// Precision of each class.
    pub precision: Vec<f64>,
    /// Recall of each class.
    pub recall: Vec<f64>,
    /// F1 score of each class.
    pub f1_score: Vec<f64>,
    /// Support of each class (number of samples whose true label is that class).
    pub support: Vec<usize>,
}
680
681impl PerClassMetrics {
682    /// Compute per-class metrics from predictions and targets.
683    ///
684    /// # Arguments
685    /// * `predictions` - Model predictions (one-hot or class probabilities)
686    /// * `targets` - True labels (one-hot encoded)
687    ///
688    /// # Returns
689    /// Per-class metrics report
690    pub fn compute(
691        predictions: &ArrayView<f64, Ix2>,
692        targets: &ArrayView<f64, Ix2>,
693    ) -> TrainResult<Self> {
694        let confusion_matrix = ConfusionMatrix::compute(predictions, targets)?;
695
696        let precision = confusion_matrix.precision_per_class();
697        let recall = confusion_matrix.recall_per_class();
698        let f1_score = confusion_matrix.f1_per_class();
699
700        // Compute support (number of samples per class)
701        let num_classes = targets.ncols();
702        let mut support = vec![0; num_classes];
703
704        for i in 0..targets.nrows() {
705            // Find true class
706            let mut true_class = 0;
707            let mut max_true = targets[[i, 0]];
708            for j in 1..num_classes {
709                if targets[[i, j]] > max_true {
710                    max_true = targets[[i, j]];
711                    true_class = j;
712                }
713            }
714            support[true_class] += 1;
715        }
716
717        Ok(Self {
718            precision,
719            recall,
720            f1_score,
721            support,
722        })
723    }
724}
725
726impl std::fmt::Display for PerClassMetrics {
727    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
728        writeln!(f, "Per-Class Metrics:")?;
729        writeln!(f, "Class  Precision  Recall  F1-Score  Support")?;
730        writeln!(f, "-----  ---------  ------  --------  -------")?;
731
732        for i in 0..self.precision.len() {
733            writeln!(
734                f,
735                "{:5}  {:9.4}  {:6.4}  {:8.4}  {:7}",
736                i, self.precision[i], self.recall[i], self.f1_score[i], self.support[i]
737            )?;
738        }
739
740        // Compute macro averages
741        let macro_precision: f64 = self.precision.iter().sum::<f64>() / self.precision.len() as f64;
742        let macro_recall: f64 = self.recall.iter().sum::<f64>() / self.recall.len() as f64;
743        let macro_f1: f64 = self.f1_score.iter().sum::<f64>() / self.f1_score.len() as f64;
744        let total_support: usize = self.support.iter().sum();
745
746        writeln!(f, "-----  ---------  ------  --------  -------")?;
747        writeln!(
748            f,
749            "Macro  {:9.4}  {:6.4}  {:8.4}  {:7}",
750            macro_precision, macro_recall, macro_f1, total_support
751        )?;
752
753        Ok(())
754    }
755}
756
/// Matthews correlation coefficient (MCC), a correlation between predicted and
/// true classes that stays informative on imbalanced datasets.
/// Ranges from -1 to +1: +1 is perfect prediction, 0 is no better than random,
/// and -1 is total disagreement.
#[derive(Clone, Debug, Default)]
pub struct MatthewsCorrelationCoefficient;
762
763impl Metric for MatthewsCorrelationCoefficient {
764    fn compute(
765        &self,
766        predictions: &ArrayView<f64, Ix2>,
767        targets: &ArrayView<f64, Ix2>,
768    ) -> TrainResult<f64> {
769        let confusion_matrix = ConfusionMatrix::compute(predictions, targets)?;
770        let num_classes = confusion_matrix.num_classes;
771
772        // For binary classification, use the standard MCC formula
773        if num_classes == 2 {
774            let tp = confusion_matrix.matrix[1][1] as f64;
775            let tn = confusion_matrix.matrix[0][0] as f64;
776            let fp = confusion_matrix.matrix[0][1] as f64;
777            let fn_val = confusion_matrix.matrix[1][0] as f64;
778
779            let numerator = (tp * tn) - (fp * fn_val);
780            let denominator = ((tp + fp) * (tp + fn_val) * (tn + fp) * (tn + fn_val)).sqrt();
781
782            if denominator == 0.0 {
783                Ok(0.0)
784            } else {
785                Ok(numerator / denominator)
786            }
787        } else {
788            // Multi-class MCC formula
789            let mut s = 0.0;
790            let mut c = 0.0;
791            let t = confusion_matrix.total_predictions() as f64;
792
793            // Compute column sums
794            let mut p_k = vec![0.0; num_classes];
795            let mut t_k = vec![0.0; num_classes];
796
797            for k in 0..num_classes {
798                for l in 0..num_classes {
799                    p_k[k] += confusion_matrix.matrix[l][k] as f64;
800                    t_k[k] += confusion_matrix.matrix[k][l] as f64;
801                }
802            }
803
804            // Compute trace (correct predictions)
805            for k in 0..num_classes {
806                c += confusion_matrix.matrix[k][k] as f64;
807            }
808
809            // Compute sum of products
810            for k in 0..num_classes {
811                s += p_k[k] * t_k[k];
812            }
813
814            let numerator = (t * c) - s;
815            let denominator_1 = ((t * t) - s).sqrt();
816            let mut sum_p_sq = 0.0;
817            let mut sum_t_sq = 0.0;
818            for k in 0..num_classes {
819                sum_p_sq += p_k[k] * p_k[k];
820                sum_t_sq += t_k[k] * t_k[k];
821            }
822            let denominator_2 = ((t * t) - sum_p_sq).sqrt();
823            let denominator_3 = ((t * t) - sum_t_sq).sqrt();
824
825            let denominator = denominator_1 * denominator_2 * denominator_3;
826
827            if denominator == 0.0 {
828                Ok(0.0)
829            } else {
830                Ok(numerator / denominator)
831            }
832        }
833    }
834
835    fn name(&self) -> &str {
836        "mcc"
837    }
838}
839
/// Cohen's kappa: agreement between predicted and true classes, corrected for
/// the agreement expected by chance.
/// Ranges from -1 to +1: 1 is perfect agreement and 0 is chance-level agreement.
#[derive(Clone, Debug, Default)]
pub struct CohensKappa;
845
846impl Metric for CohensKappa {
847    fn compute(
848        &self,
849        predictions: &ArrayView<f64, Ix2>,
850        targets: &ArrayView<f64, Ix2>,
851    ) -> TrainResult<f64> {
852        let confusion_matrix = ConfusionMatrix::compute(predictions, targets)?;
853        let num_classes = confusion_matrix.num_classes;
854        let total = confusion_matrix.total_predictions() as f64;
855
856        // Observed agreement (accuracy)
857        let mut observed = 0.0;
858        for i in 0..num_classes {
859            observed += confusion_matrix.matrix[i][i] as f64;
860        }
861        observed /= total;
862
863        // Expected agreement by chance
864        let mut expected = 0.0;
865        for i in 0..num_classes {
866            let row_sum: f64 = (0..num_classes)
867                .map(|j| confusion_matrix.matrix[i][j] as f64)
868                .sum();
869            let col_sum: f64 = (0..num_classes)
870                .map(|j| confusion_matrix.matrix[j][i] as f64)
871                .sum();
872            expected += (row_sum / total) * (col_sum / total);
873        }
874
875        if expected >= 1.0 {
876            Ok(0.0)
877        } else {
878            Ok((observed - expected) / (1.0 - expected))
879        }
880    }
881
882    fn name(&self) -> &str {
883        "cohens_kappa"
884    }
885}
886
/// Top-K accuracy: a sample counts as correct when its true class is among the
/// `k` highest-scoring predicted classes.
#[derive(Debug, Clone)]
pub struct TopKAccuracy {
    /// How many of the best-ranked classes are accepted.
    pub k: usize,
}

impl TopKAccuracy {
    /// Creates a Top-K accuracy metric that accepts the `k` best-ranked classes.
    pub fn new(k: usize) -> Self {
        TopKAccuracy { k }
    }
}

impl Default for TopKAccuracy {
    /// Top-5 accuracy.
    fn default() -> Self {
        Self::new(5)
    }
}
907
908impl Metric for TopKAccuracy {
909    fn compute(
910        &self,
911        predictions: &ArrayView<f64, Ix2>,
912        targets: &ArrayView<f64, Ix2>,
913    ) -> TrainResult<f64> {
914        if predictions.shape() != targets.shape() {
915            return Err(TrainError::MetricsError(format!(
916                "Shape mismatch: predictions {:?} vs targets {:?}",
917                predictions.shape(),
918                targets.shape()
919            )));
920        }
921
922        let num_classes = predictions.ncols();
923        if self.k > num_classes {
924            return Err(TrainError::MetricsError(format!(
925                "K ({}) cannot be greater than number of classes ({})",
926                self.k, num_classes
927            )));
928        }
929
930        let mut correct = 0;
931        let total = predictions.nrows();
932
933        for i in 0..total {
934            // Find true class
935            let mut true_class = 0;
936            let mut max_true = targets[[i, 0]];
937            for j in 1..num_classes {
938                if targets[[i, j]] > max_true {
939                    max_true = targets[[i, j]];
940                    true_class = j;
941                }
942            }
943
944            // Get top K predictions
945            let mut indices: Vec<usize> = (0..num_classes).collect();
946            indices.sort_by(|&a, &b| {
947                predictions[[i, b]]
948                    .partial_cmp(&predictions[[i, a]])
949                    .unwrap_or(std::cmp::Ordering::Equal)
950            });
951
952            // Check if true class is in top K
953            if indices[..self.k].contains(&true_class) {
954                correct += 1;
955            }
956        }
957
958        Ok(correct as f64 / total as f64)
959    }
960
961    fn name(&self) -> &str {
962        "top_k_accuracy"
963    }
964}
965
/// Balanced accuracy: the mean of per-class recall, which keeps majority
/// classes from dominating the score on imbalanced datasets.
#[derive(Clone, Debug, Default)]
pub struct BalancedAccuracy;
970
971impl Metric for BalancedAccuracy {
972    fn compute(
973        &self,
974        predictions: &ArrayView<f64, Ix2>,
975        targets: &ArrayView<f64, Ix2>,
976    ) -> TrainResult<f64> {
977        let confusion_matrix = ConfusionMatrix::compute(predictions, targets)?;
978        let recalls = confusion_matrix.recall_per_class();
979
980        // Balanced accuracy is the average of recall per class
981        let sum: f64 = recalls.iter().sum();
982        Ok(sum / recalls.len() as f64)
983    }
984
985    fn name(&self) -> &str {
986        "balanced_accuracy"
987    }
988}
989
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    #[test]
    fn test_accuracy() {
        let metric = Accuracy::default();

        // Perfect predictions
        let predictions = array![[0.9, 0.1], [0.2, 0.8], [0.8, 0.2]];
        let targets = array![[1.0, 0.0], [0.0, 1.0], [1.0, 0.0]];

        let accuracy = metric
            .compute(&predictions.view(), &targets.view())
            .unwrap();
        assert_eq!(accuracy, 1.0);

        // Partial correct
        let predictions = array![[0.9, 0.1], [0.8, 0.2], [0.8, 0.2]];
        let targets = array![[1.0, 0.0], [0.0, 1.0], [1.0, 0.0]];

        let accuracy = metric
            .compute(&predictions.view(), &targets.view())
            .unwrap();
        assert!((accuracy - 2.0 / 3.0).abs() < 1e-6);
    }

    #[test]
    fn test_precision() {
        let metric = Precision::default();

        let predictions = array![[0.9, 0.1], [0.2, 0.8], [0.7, 0.3]];
        let targets = array![[1.0, 0.0], [0.0, 1.0], [0.0, 1.0]];

        let precision = metric
            .compute(&predictions.view(), &targets.view())
            .unwrap();
        assert!((0.0..=1.0).contains(&precision));
    }

    #[test]
    fn test_recall() {
        let metric = Recall::default();

        let predictions = array![[0.9, 0.1], [0.2, 0.8], [0.7, 0.3]];
        let targets = array![[1.0, 0.0], [0.0, 1.0], [0.0, 1.0]];

        let recall = metric
            .compute(&predictions.view(), &targets.view())
            .unwrap();
        assert!((0.0..=1.0).contains(&recall));
    }

    #[test]
    fn test_f1_score() {
        let metric = F1Score::default();

        let predictions = array![[0.9, 0.1], [0.2, 0.8], [0.7, 0.3]];
        let targets = array![[1.0, 0.0], [0.0, 1.0], [0.0, 1.0]];

        let f1 = metric
            .compute(&predictions.view(), &targets.view())
            .unwrap();
        assert!((0.0..=1.0).contains(&f1));
    }

    #[test]
    fn test_metric_tracker() {
        let mut tracker = MetricTracker::new();
        tracker.add(Box::new(Accuracy::default()));
        tracker.add(Box::new(F1Score::default()));

        let predictions = array![[0.9, 0.1], [0.2, 0.8]];
        let targets = array![[1.0, 0.0], [0.0, 1.0]];

        let results = tracker
            .compute_all(&predictions.view(), &targets.view())
            .unwrap();
        assert!(results.contains_key("accuracy"));
        assert!(results.contains_key("f1_score"));

        let history = tracker.get_history("accuracy").unwrap();
        assert_eq!(history.len(), 1);
    }

    #[test]
    fn test_confusion_matrix() {
        let predictions = array![
            [0.9, 0.1, 0.0],
            [0.1, 0.8, 0.1],
            [0.2, 0.1, 0.7],
            [0.8, 0.1, 0.1]
        ];
        let targets = array![
            [1.0, 0.0, 0.0],
            [0.0, 1.0, 0.0],
            [0.0, 0.0, 1.0],
            [1.0, 0.0, 0.0]
        ];

        let cm = ConfusionMatrix::compute(&predictions.view(), &targets.view()).unwrap();

        assert_eq!(cm.get(0, 0), 2); // Class 0 correctly predicted
        assert_eq!(cm.get(1, 1), 1); // Class 1 correctly predicted
        assert_eq!(cm.get(2, 2), 1); // Class 2 correctly predicted
        assert_eq!(cm.accuracy(), 1.0);
    }

    #[test]
    fn test_confusion_matrix_per_class_metrics() {
        let predictions = array![[0.9, 0.1], [0.2, 0.8], [0.7, 0.3], [0.1, 0.9]];
        let targets = array![[1.0, 0.0], [0.0, 1.0], [1.0, 0.0], [0.0, 1.0]];

        let cm = ConfusionMatrix::compute(&predictions.view(), &targets.view()).unwrap();

        let precision = cm.precision_per_class();
        let recall = cm.recall_per_class();
        let f1 = cm.f1_per_class();

        assert_eq!(precision.len(), 2);
        assert_eq!(recall.len(), 2);
        assert_eq!(f1.len(), 2);

        // All predictions correct
        assert_eq!(precision[0], 1.0);
        assert_eq!(precision[1], 1.0);
        assert_eq!(recall[0], 1.0);
        assert_eq!(recall[1], 1.0);
    }

    #[test]
    fn test_roc_curve() {
        let predictions = vec![0.9, 0.8, 0.4, 0.3, 0.1];
        let targets = vec![true, true, false, true, false];

        let roc = RocCurve::compute(&predictions, &targets).unwrap();

        assert!(!roc.fpr.is_empty());
        assert!(!roc.tpr.is_empty());
        assert!(!roc.thresholds.is_empty());
        assert_eq!(roc.fpr.len(), roc.tpr.len());

        let auc = roc.auc();
        assert!((0.0..=1.0).contains(&auc));
    }

    #[test]
    fn test_roc_auc_perfect() {
        let predictions = vec![0.9, 0.8, 0.3, 0.1];
        let targets = vec![true, true, false, false];

        let roc = RocCurve::compute(&predictions, &targets).unwrap();
        let auc = roc.auc();

        // Perfect classification should have AUC = 1.0
        assert!((auc - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_per_class_metrics() {
        let predictions = array![
            [0.9, 0.1, 0.0],
            [0.1, 0.8, 0.1],
            [0.2, 0.1, 0.7],
            [0.8, 0.1, 0.1]
        ];
        let targets = array![
            [1.0, 0.0, 0.0],
            [0.0, 1.0, 0.0],
            [0.0, 0.0, 1.0],
            [1.0, 0.0, 0.0]
        ];

        let metrics = PerClassMetrics::compute(&predictions.view(), &targets.view()).unwrap();

        assert_eq!(metrics.precision.len(), 3);
        assert_eq!(metrics.recall.len(), 3);
        assert_eq!(metrics.f1_score.len(), 3);
        assert_eq!(metrics.support.len(), 3);

        // Check support counts
        assert_eq!(metrics.support[0], 2);
        assert_eq!(metrics.support[1], 1);
        assert_eq!(metrics.support[2], 1);
    }

    #[test]
    fn test_matthews_correlation_coefficient() {
        let metric = MatthewsCorrelationCoefficient;

        // Perfect binary classification
        let predictions = array![[0.9, 0.1], [0.1, 0.9], [0.9, 0.1], [0.1, 0.9]];
        let targets = array![[1.0, 0.0], [0.0, 1.0], [1.0, 0.0], [0.0, 1.0]];

        let mcc = metric
            .compute(&predictions.view(), &targets.view())
            .unwrap();
        assert!((mcc - 1.0).abs() < 1e-6);

        // Uninformative predictions: every row is an exact tie between classes
        let predictions = array![[0.5, 0.5], [0.5, 0.5], [0.5, 0.5], [0.5, 0.5]];
        let targets = array![[1.0, 0.0], [0.0, 1.0], [1.0, 0.0], [0.0, 1.0]];

        let mcc = metric
            .compute(&predictions.view(), &targets.view())
            .unwrap();
        assert!(mcc.abs() < 0.1);
    }

    #[test]
    fn test_cohens_kappa() {
        let metric = CohensKappa;

        // Perfect agreement
        let predictions = array![[0.9, 0.1], [0.1, 0.9], [0.9, 0.1], [0.1, 0.9]];
        let targets = array![[1.0, 0.0], [0.0, 1.0], [1.0, 0.0], [0.0, 1.0]];

        let kappa = metric
            .compute(&predictions.view(), &targets.view())
            .unwrap();
        assert!((kappa - 1.0).abs() < 1e-6);

        // Constant predictor (always class 0): agreement no better than chance
        let predictions = array![[0.9, 0.1], [0.9, 0.1], [0.9, 0.1], [0.9, 0.1]];
        let targets = array![[1.0, 0.0], [0.0, 1.0], [1.0, 0.0], [0.0, 1.0]];

        let kappa = metric
            .compute(&predictions.view(), &targets.view())
            .unwrap();
        assert!((-1.0..=1.0).contains(&kappa));
    }

    #[test]
    fn test_top_k_accuracy() {
        let metric = TopKAccuracy::new(2);

        // Test with 3 classes
        let predictions = array![
            [0.7, 0.2, 0.1], // Correct class is 0, top-2 is (0, 1): hit
            [0.1, 0.6, 0.3], // Correct class is 1, top-2 is (1, 2): hit
            [0.3, 0.4, 0.3], // Correct class is 2, top-2 is (1, 0) (ties keep index order): miss
        ];
        let targets = array![[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]];

        let top_k = metric
            .compute(&predictions.view(), &targets.view())
            .unwrap();
        assert!((0.0..=1.0).contains(&top_k));
        // Exactly 2 of the 3 rows have their true class in the top-2
        assert!((top_k - 2.0 / 3.0).abs() < 1e-6);
    }

    #[test]
    fn test_balanced_accuracy() {
        let metric = BalancedAccuracy;

        // Perfect classification
        let predictions = array![[0.9, 0.1], [0.1, 0.9], [0.9, 0.1], [0.1, 0.9]];
        let targets = array![[1.0, 0.0], [0.0, 1.0], [1.0, 0.0], [0.0, 1.0]];

        let balanced_acc = metric
            .compute(&predictions.view(), &targets.view())
            .unwrap();
        assert!((balanced_acc - 1.0).abs() < 1e-6);

        // Imbalanced but perfect
        let predictions = array![[0.9, 0.1], [0.9, 0.1], [0.9, 0.1], [0.1, 0.9]];
        let targets = array![[1.0, 0.0], [1.0, 0.0], [1.0, 0.0], [0.0, 1.0]];

        let balanced_acc = metric
            .compute(&predictions.view(), &targets.view())
            .unwrap();
        assert!((balanced_acc - 1.0).abs() < 1e-6);
    }
}