tensorlogic_train/metrics/
tracker.rs1use crate::TrainResult;
4use scirs2_core::ndarray::{ArrayView, Ix2};
5use std::collections::HashMap;
6
7use super::Metric;
8
9pub struct MetricTracker {
11 metrics: Vec<Box<dyn Metric>>,
13 history: HashMap<String, Vec<f64>>,
15}
16
17impl MetricTracker {
18 pub fn new() -> Self {
20 Self {
21 metrics: Vec::new(),
22 history: HashMap::new(),
23 }
24 }
25
26 pub fn add(&mut self, metric: Box<dyn Metric>) {
28 let name = metric.name().to_string();
29 self.history.insert(name, Vec::new());
30 self.metrics.push(metric);
31 }
32
33 pub fn compute_all(
35 &mut self,
36 predictions: &ArrayView<f64, Ix2>,
37 targets: &ArrayView<f64, Ix2>,
38 ) -> TrainResult<HashMap<String, f64>> {
39 let mut results = HashMap::new();
40
41 for metric in &self.metrics {
42 let value = metric.compute(predictions, targets)?;
43 let name = metric.name().to_string();
44
45 results.insert(name.clone(), value);
46
47 if let Some(history) = self.history.get_mut(&name) {
48 history.push(value);
49 }
50 }
51
52 Ok(results)
53 }
54
55 pub fn get_history(&self, metric_name: &str) -> Option<&Vec<f64>> {
57 self.history.get(metric_name)
58 }
59
60 pub fn reset(&mut self) {
62 for metric in &mut self.metrics {
63 metric.reset();
64 }
65 }
66
67 pub fn clear_history(&mut self) {
69 for history in self.history.values_mut() {
70 history.clear();
71 }
72 }
73}
74
75impl Default for MetricTracker {
76 fn default() -> Self {
77 Self::new()
78 }
79}
80
#[cfg(test)]
mod tests {
    use super::*;
    use crate::metrics::{Accuracy, F1Score};
    use scirs2_core::ndarray::array;

    /// Every registered metric appears in the results and gets one history entry per call.
    #[test]
    fn test_metric_tracker() {
        let mut tracker = MetricTracker::new();
        tracker.add(Box::new(Accuracy::default()));
        tracker.add(Box::new(F1Score::default()));

        let predictions = array![[0.9, 0.1], [0.2, 0.8]];
        let targets = array![[1.0, 0.0], [0.0, 1.0]];

        let results = tracker
            .compute_all(&predictions.view(), &targets.view())
            .unwrap();
        assert!(results.contains_key("accuracy"));
        assert!(results.contains_key("f1_score"));

        let history = tracker.get_history("accuracy").unwrap();
        assert_eq!(history.len(), 1);
        assert_eq!(tracker.get_history("f1_score").map(Vec::len), Some(1));
    }

    /// History grows by one per `compute_all` and ends with the last returned value.
    /// `reset` leaves it alone, and `clear_history` empties it without unregistering
    /// the metric.
    #[test]
    fn test_history_accumulates_and_clears() {
        let mut tracker = MetricTracker::new();
        tracker.add(Box::new(Accuracy::default()));

        let predictions = array![[0.9, 0.1], [0.2, 0.8]];
        let targets = array![[1.0, 0.0], [0.0, 1.0]];

        let mut last = None;
        for _ in 0..3 {
            let results = tracker
                .compute_all(&predictions.view(), &targets.view())
                .unwrap();
            last = results.get("accuracy").copied();
        }

        let history = tracker.get_history("accuracy").unwrap();
        assert_eq!(history.len(), 3);
        assert_eq!(history.last().copied(), last);

        tracker.reset();
        assert_eq!(tracker.get_history("accuracy").map(Vec::len), Some(3));

        tracker.clear_history();
        assert_eq!(tracker.get_history("accuracy").map(Vec::len), Some(0));
        assert!(tracker.get_history("unknown").is_none());
    }
}