trustformers_models/comprehensive_testing/
reference_comparison.rs1use anyhow::{Error, Result};
4use trustformers_core::tensor::Tensor;
5
6use super::types::NumericalDifferences;
7
/// Compares tensors against reference tensors using an absolute,
/// element-wise tolerance.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct ReferenceComparator {
    /// Largest absolute element-wise difference that still counts as a match.
    tolerance: f32,
}
12
13impl ReferenceComparator {
14 pub fn new(tolerance: f32) -> Self {
16 Self { tolerance }
17 }
18
19 pub fn tolerance(&self) -> f32 {
21 self.tolerance
22 }
23
24 pub fn compare_with_reference(
26 &self,
27 actual: &Tensor,
28 expected: &Tensor,
29 ) -> Result<NumericalDifferences> {
30 match (actual, expected) {
31 (Tensor::F32(actual_arr), Tensor::F32(expected_arr)) => {
32 if actual_arr.shape() != expected_arr.shape() {
33 return Err(Error::msg("Tensor shapes don't match"));
34 }
35
36 let diffs: Vec<f32> = actual_arr
37 .iter()
38 .zip(expected_arr.iter())
39 .map(|(a, e)| (a - e).abs())
40 .collect();
41
42 let max_abs_diff = diffs.iter().cloned().fold(0.0, f32::max);
43 let mean_abs_diff = diffs.iter().sum::<f32>() / diffs.len() as f32;
44 let rms_diff =
45 (diffs.iter().map(|d| d * d).sum::<f32>() / diffs.len() as f32).sqrt();
46 let within_tolerance = diffs.iter().filter(|&&d| d <= self.tolerance).count();
47 let within_tolerance_percent =
48 (within_tolerance as f32 / diffs.len() as f32) * 100.0;
49
50 Ok(NumericalDifferences {
51 max_abs_diff,
52 mean_abs_diff,
53 rms_diff,
54 within_tolerance_percent,
55 })
56 },
57 _ => Err(Error::msg("Unsupported tensor types for comparison")),
58 }
59 }
60
61 pub fn validate_differences(&self, differences: &NumericalDifferences) -> bool {
63 differences.max_abs_diff <= self.tolerance && differences.within_tolerance_percent >= 95.0
64 }
65}