Skip to main content

trustformers_models/comprehensive_testing/
reference_comparison.rs

1//! Reference value comparison utilities
2
3use anyhow::{Error, Result};
4use trustformers_core::tensor::Tensor;
5
6use super::types::NumericalDifferences;
7
8/// Reference value comparison utilities
/// Compares tensors element-wise against reference values using an absolute tolerance.
#[derive(Debug, Clone)]
pub struct ReferenceComparator {
    /// Maximum absolute element-wise difference that still counts as a match.
    tolerance: f32,
}
12
13impl ReferenceComparator {
14    /// Create a new reference comparator
15    pub fn new(tolerance: f32) -> Self {
16        Self { tolerance }
17    }
18
19    /// Get the tolerance value
20    pub fn tolerance(&self) -> f32 {
21        self.tolerance
22    }
23
24    /// Compare model output with reference values
25    pub fn compare_with_reference(
26        &self,
27        actual: &Tensor,
28        expected: &Tensor,
29    ) -> Result<NumericalDifferences> {
30        match (actual, expected) {
31            (Tensor::F32(actual_arr), Tensor::F32(expected_arr)) => {
32                if actual_arr.shape() != expected_arr.shape() {
33                    return Err(Error::msg("Tensor shapes don't match"));
34                }
35
36                let diffs: Vec<f32> = actual_arr
37                    .iter()
38                    .zip(expected_arr.iter())
39                    .map(|(a, e)| (a - e).abs())
40                    .collect();
41
42                let max_abs_diff = diffs.iter().cloned().fold(0.0, f32::max);
43                let mean_abs_diff = diffs.iter().sum::<f32>() / diffs.len() as f32;
44                let rms_diff =
45                    (diffs.iter().map(|d| d * d).sum::<f32>() / diffs.len() as f32).sqrt();
46                let within_tolerance = diffs.iter().filter(|&&d| d <= self.tolerance).count();
47                let within_tolerance_percent =
48                    (within_tolerance as f32 / diffs.len() as f32) * 100.0;
49
50                Ok(NumericalDifferences {
51                    max_abs_diff,
52                    mean_abs_diff,
53                    rms_diff,
54                    within_tolerance_percent,
55                })
56            },
57            _ => Err(Error::msg("Unsupported tensor types for comparison")),
58        }
59    }
60
61    /// Validate that differences are within acceptable bounds
62    pub fn validate_differences(&self, differences: &NumericalDifferences) -> bool {
63        differences.max_abs_diff <= self.tolerance && differences.within_tolerance_percent >= 95.0
64    }
65}