Skip to main content

tensorlogic_train/metrics/
tracker.rs

1//! Metric tracker for managing multiple metrics.
2
3use crate::TrainResult;
4use scirs2_core::ndarray::{ArrayView, Ix2};
5use std::collections::HashMap;
6
7use super::Metric;
8
9/// Metric tracker for managing multiple metrics.
10pub struct MetricTracker {
11    /// Metrics to track.
12    metrics: Vec<Box<dyn Metric>>,
13    /// History of metric values.
14    history: HashMap<String, Vec<f64>>,
15}
16
17impl MetricTracker {
18    /// Create a new metric tracker.
19    pub fn new() -> Self {
20        Self {
21            metrics: Vec::new(),
22            history: HashMap::new(),
23        }
24    }
25
26    /// Add a metric to track.
27    pub fn add(&mut self, metric: Box<dyn Metric>) {
28        let name = metric.name().to_string();
29        self.history.insert(name, Vec::new());
30        self.metrics.push(metric);
31    }
32
33    /// Compute all metrics.
34    pub fn compute_all(
35        &mut self,
36        predictions: &ArrayView<f64, Ix2>,
37        targets: &ArrayView<f64, Ix2>,
38    ) -> TrainResult<HashMap<String, f64>> {
39        let mut results = HashMap::new();
40
41        for metric in &self.metrics {
42            let value = metric.compute(predictions, targets)?;
43            let name = metric.name().to_string();
44
45            results.insert(name.clone(), value);
46
47            if let Some(history) = self.history.get_mut(&name) {
48                history.push(value);
49            }
50        }
51
52        Ok(results)
53    }
54
55    /// Get history for a specific metric.
56    pub fn get_history(&self, metric_name: &str) -> Option<&Vec<f64>> {
57        self.history.get(metric_name)
58    }
59
60    /// Reset all metrics.
61    pub fn reset(&mut self) {
62        for metric in &mut self.metrics {
63            metric.reset();
64        }
65    }
66
67    /// Clear history.
68    pub fn clear_history(&mut self) {
69        for history in self.history.values_mut() {
70            history.clear();
71        }
72    }
73}
74
75impl Default for MetricTracker {
76    fn default() -> Self {
77        Self::new()
78    }
79}
80
#[cfg(test)]
mod tests {
    use super::*;
    use crate::metrics::{Accuracy, F1Score};
    use scirs2_core::ndarray::array;

    /// Builds a tracker holding the accuracy and F1 metrics used by these tests.
    fn accuracy_and_f1_tracker() -> MetricTracker {
        let mut tracker = MetricTracker::new();
        tracker.add(Box::new(Accuracy::default()));
        tracker.add(Box::new(F1Score::default()));
        tracker
    }

    #[test]
    fn test_metric_tracker() {
        let mut tracker = accuracy_and_f1_tracker();

        let predictions = array![[0.9, 0.1], [0.2, 0.8]];
        let targets = array![[1.0, 0.0], [0.0, 1.0]];

        let results = tracker
            .compute_all(&predictions.view(), &targets.view())
            .unwrap();
        assert!(results.contains_key("accuracy"));
        assert!(results.contains_key("f1_score"));

        // Every reported metric gets exactly one history entry per call.
        assert_eq!(tracker.get_history("accuracy").unwrap().len(), 1);
        assert_eq!(tracker.get_history("f1_score").unwrap().len(), 1);
    }

    #[test]
    fn test_history_accumulates_survives_reset_and_clears() {
        let mut tracker = accuracy_and_f1_tracker();

        let predictions = array![[0.9, 0.1], [0.2, 0.8]];
        let targets = array![[1.0, 0.0], [0.0, 1.0]];

        for _ in 0..3 {
            tracker
                .compute_all(&predictions.view(), &targets.view())
                .unwrap();
        }
        assert_eq!(tracker.get_history("accuracy").unwrap().len(), 3);
        assert_eq!(tracker.get_history("f1_score").unwrap().len(), 3);

        // Resetting the metrics must not touch the recorded history.
        tracker.reset();
        assert_eq!(tracker.get_history("accuracy").unwrap().len(), 3);

        // Clearing empties the histories but keeps the metric names registered.
        tracker.clear_history();
        assert!(tracker.get_history("accuracy").unwrap().is_empty());
        assert!(tracker.get_history("f1_score").unwrap().is_empty());
        assert!(tracker.get_history("unknown").is_none());
    }
}