1use crate::EvalError;
2
/// Aggregate error statistics comparing predicted per-frame counts with ground truth.
///
/// Produced by `evaluate_counts`. Every field is zero when no frames were evaluated.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct CountingMetrics {
    /// Number of frames (paired ground-truth / prediction entries) that were compared.
    pub num_frames: usize,
    /// Mean absolute error: average of `|prediction - ground_truth|` over all frames.
    pub mae: f32,
    /// Root mean squared error: square root of the average squared per-frame error.
    pub rmse: f32,
    /// Largest single-frame absolute error observed.
    pub max_abs_error: usize,
}
10
11pub fn evaluate_counts(
12 ground_truth: &[usize],
13 predictions: &[usize],
14) -> Result<CountingMetrics, EvalError> {
15 if ground_truth.len() != predictions.len() {
16 return Err(EvalError::CountLengthMismatch {
17 ground_truth: ground_truth.len(),
18 predictions: predictions.len(),
19 });
20 }
21
22 if ground_truth.is_empty() {
23 return Ok(CountingMetrics {
24 num_frames: 0,
25 mae: 0.0,
26 rmse: 0.0,
27 max_abs_error: 0,
28 });
29 }
30
31 let mut abs_error_sum = 0.0f32;
32 let mut sq_error_sum = 0.0f32;
33 let mut max_abs_error = 0usize;
34 for (>, &prediction) in ground_truth.iter().zip(predictions.iter()) {
35 let error = prediction as i64 - gt as i64;
36 let abs_error = error.unsigned_abs() as usize;
37 abs_error_sum += abs_error as f32;
38 sq_error_sum += (error as f32).powi(2);
39 max_abs_error = max_abs_error.max(abs_error);
40 }
41
42 let denom = ground_truth.len() as f32;
43 Ok(CountingMetrics {
44 num_frames: ground_truth.len(),
45 mae: abs_error_sum / denom,
46 rmse: (sq_error_sum / denom).sqrt(),
47 max_abs_error,
48 })
49}