Skip to main content

tensorlogic_scirs_backend/
comparison.rs

1//! Tensor comparison utilities for testing and validation.
2//!
3//! Provides configurable tolerance-based tensor comparison, element-wise diff
4//! analysis, and assertion helpers for backend validation and gradient checking.
5
6use scirs2_core::ndarray::ArrayD;
7use thiserror::Error;
8
9/// Errors that can occur during tensor comparison.
10#[derive(Debug, Error)]
11pub enum ComparisonError {
12    /// The two tensors have different shapes.
13    #[error("Shape mismatch: {0:?} vs {1:?}")]
14    ShapeMismatch(Vec<usize>, Vec<usize>),
15    /// Both tensors are empty (zero elements).
16    #[error("Empty tensors")]
17    EmptyTensors,
18}
19
/// Tolerance configuration for tensor comparison.
///
/// For finite values this uses the NumPy-style closeness criterion
/// `|a - b| <= atol + rtol * |b|`, where `b` is treated as the reference
/// value (the criterion is not symmetric in `a` and `b`).
#[derive(Debug, Clone)]
pub struct Tolerance {
    /// Relative tolerance (default 1e-5)
    pub rtol: f64,
    /// Absolute tolerance (default 1e-8)
    pub atol: f64,
}

impl Default for Tolerance {
    /// `rtol = 1e-5`, `atol = 1e-8`.
    fn default() -> Self {
        Tolerance {
            rtol: 1e-5,
            atol: 1e-8,
        }
    }
}

impl Tolerance {
    /// Create a new tolerance with the given relative and absolute tolerances.
    pub fn new(rtol: f64, atol: f64) -> Self {
        Tolerance { rtol, atol }
    }

    /// Strict tolerance suitable for exact-ish comparisons
    /// (`rtol = 1e-12`, `atol = 1e-15`).
    pub fn strict() -> Self {
        Tolerance {
            rtol: 1e-12,
            atol: 1e-15,
        }
    }

    /// Loose tolerance suitable for approximate comparisons such as gradient
    /// checking (`rtol = 1e-3`, `atol = 1e-6`).
    pub fn loose() -> Self {
        Tolerance {
            rtol: 1e-3,
            atol: 1e-6,
        }
    }

    /// Check whether `a` is close to the reference value `b`.
    ///
    /// For finite inputs this is `|a - b| <= atol + rtol * |b|`.
    ///
    /// An infinite value is close only to an infinity of the same sign. Without
    /// this guard, `is_close(x, inf)` would evaluate `inf <= inf` and report
    /// every finite `x` as close to infinity, while `is_close(inf, x)` would be
    /// `false`. NaN is never close to anything, including NaN, because every
    /// comparison against NaN is false.
    pub fn is_close(&self, a: f64, b: f64) -> bool {
        if a.is_infinite() || b.is_infinite() {
            return a == b;
        }
        (a - b).abs() <= self.atol + self.rtol * b.abs()
    }
}
68
/// Result of comparing two tensors element-wise.
#[derive(Debug, Clone)]
pub struct ComparisonResult {
    /// Whether all elements are within tolerance.
    pub all_close: bool,
    /// Maximum absolute difference over the finite element pairs.
    pub max_abs_diff: f64,
    /// Mean absolute difference: the sum of absolute differences over the
    /// finite element pairs, divided by `total_elements` (non-finite pairs
    /// contribute nothing to the sum).
    pub mean_abs_diff: f64,
    /// Maximum relative difference (relative to `|b|`; when `|b|` is below
    /// 1e-15 the absolute difference is used instead).
    pub max_rel_diff: f64,
    /// Number of elements that differ beyond tolerance, including NaN and
    /// Inf mismatches.
    pub mismatch_count: usize,
    /// Total number of elements compared.
    pub total_elements: usize,
    /// Flattened (iteration-order) index of the maximum absolute difference.
    pub max_diff_index: usize,
    /// Number of NaN mismatches (one is NaN, other is not).
    pub nan_mismatches: usize,
    /// Number of Inf mismatches (one is infinite, other is not, or signs differ).
    pub inf_mismatches: usize,
}

impl ComparisonResult {
    /// Fraction of elements that match within tolerance, in `[0.0, 1.0]`.
    ///
    /// Returns `1.0` when no elements were compared. The fields are public,
    /// so a hand-built value may have `mismatch_count > total_elements`; the
    /// subtraction saturates (yielding `0.0`) instead of overflowing.
    pub fn match_ratio(&self) -> f64 {
        if self.total_elements == 0 {
            1.0
        } else {
            let matched = self.total_elements.saturating_sub(self.mismatch_count);
            matched as f64 / self.total_elements as f64
        }
    }

    /// Human-readable summary of the comparison.
    ///
    /// On mismatch, non-zero NaN/Inf mismatch counts are appended. Those
    /// elements do not contribute to `max_diff`/`mean_diff`, so without the
    /// counts a purely non-finite mismatch would print misleading zero diffs.
    pub fn summary(&self) -> String {
        if self.all_close {
            return format!(
                "MATCH: {} elements, max_diff={:.2e}",
                self.total_elements, self.max_abs_diff
            );
        }
        let mut text = format!(
            "MISMATCH: {}/{} elements differ, max_diff={:.2e}, mean_diff={:.2e}",
            self.mismatch_count, self.total_elements, self.max_abs_diff, self.mean_abs_diff
        );
        if self.nan_mismatches > 0 {
            text.push_str(&format!(", nan_mismatches={}", self.nan_mismatches));
        }
        if self.inf_mismatches > 0 {
            text.push_str(&format!(", inf_mismatches={}", self.inf_mismatches));
        }
        text
    }
}
117
118/// Compare two tensors element-wise with configurable tolerance.
119///
120/// Handles NaN and Inf specially:
121/// - Both NaN → considered matching
122/// - One NaN, one not → nan_mismatch
123/// - Both ±Inf with same sign → matching
124/// - One Inf, one not (or different signs) → inf_mismatch
125pub fn compare_tensors(
126    a: &ArrayD<f64>,
127    b: &ArrayD<f64>,
128    tol: &Tolerance,
129) -> Result<ComparisonResult, ComparisonError> {
130    if a.shape() != b.shape() {
131        return Err(ComparisonError::ShapeMismatch(
132            a.shape().to_vec(),
133            b.shape().to_vec(),
134        ));
135    }
136    if a.is_empty() {
137        return Err(ComparisonError::EmptyTensors);
138    }
139
140    let mut max_abs_diff = 0.0_f64;
141    let mut sum_abs_diff = 0.0_f64;
142    let mut max_rel_diff = 0.0_f64;
143    let mut mismatch_count = 0_usize;
144    let mut max_diff_index = 0_usize;
145    let mut nan_mismatches = 0_usize;
146    let mut inf_mismatches = 0_usize;
147
148    for (i, (&va, &vb)) in a.iter().zip(b.iter()).enumerate() {
149        // Handle NaN cases
150        if va.is_nan() != vb.is_nan() {
151            nan_mismatches += 1;
152            mismatch_count += 1;
153            continue;
154        }
155        if va.is_nan() && vb.is_nan() {
156            // Both NaN → considered matching
157            continue;
158        }
159
160        // Handle Inf cases
161        if va.is_infinite() != vb.is_infinite() {
162            inf_mismatches += 1;
163            mismatch_count += 1;
164            continue;
165        }
166        if va.is_infinite() && vb.is_infinite() {
167            if va.signum() == vb.signum() {
168                // Both same-sign Inf → matching
169                continue;
170            }
171            // Different sign Inf → mismatch
172            inf_mismatches += 1;
173            mismatch_count += 1;
174            continue;
175        }
176
177        // Normal finite comparison
178        let abs_diff = (va - vb).abs();
179        sum_abs_diff += abs_diff;
180
181        if abs_diff > max_abs_diff {
182            max_abs_diff = abs_diff;
183            max_diff_index = i;
184        }
185
186        let rel_diff = if vb.abs() > 1e-15 {
187            abs_diff / vb.abs()
188        } else {
189            abs_diff
190        };
191        if rel_diff > max_rel_diff {
192            max_rel_diff = rel_diff;
193        }
194
195        if !tol.is_close(va, vb) {
196            mismatch_count += 1;
197        }
198    }
199
200    let total = a.len();
201    Ok(ComparisonResult {
202        all_close: mismatch_count == 0,
203        max_abs_diff,
204        mean_abs_diff: sum_abs_diff / total as f64,
205        max_rel_diff,
206        mismatch_count,
207        total_elements: total,
208        max_diff_index,
209        nan_mismatches,
210        inf_mismatches,
211    })
212}
213
214/// Assert two tensors are close (panics with detailed message if not).
215///
216/// Intended for use in tests where a panic on mismatch is appropriate.
217pub fn assert_tensors_close(a: &ArrayD<f64>, b: &ArrayD<f64>, tol: &Tolerance) {
218    match compare_tensors(a, b, tol) {
219        Ok(result) if result.all_close => {}
220        Ok(result) => panic!(
221            "Tensors not close: {}\nMax diff at index {}: {:.2e}",
222            result.summary(),
223            result.max_diff_index,
224            result.max_abs_diff
225        ),
226        Err(e) => panic!("Tensor comparison failed: {e}"),
227    }
228}
229
230/// Compute element-wise absolute difference tensor.
231///
232/// Returns a tensor of the same shape where each element is `|a_i - b_i|`.
233pub fn abs_diff(a: &ArrayD<f64>, b: &ArrayD<f64>) -> Result<ArrayD<f64>, ComparisonError> {
234    if a.shape() != b.shape() {
235        return Err(ComparisonError::ShapeMismatch(
236            a.shape().to_vec(),
237            b.shape().to_vec(),
238        ));
239    }
240    let diff = a - b;
241    Ok(diff.mapv(f64::abs))
242}
243
244/// Check if a tensor contains only finite values (no NaN or Inf).
245pub fn is_finite(tensor: &ArrayD<f64>) -> bool {
246    tensor.iter().all(|v| v.is_finite())
247}
248
249/// Count non-finite values in a tensor.
250///
251/// Returns `(nan_count, inf_count)`.
252pub fn count_non_finite(tensor: &ArrayD<f64>) -> (usize, usize) {
253    let nan_count = tensor.iter().filter(|v| v.is_nan()).count();
254    let inf_count = tensor.iter().filter(|v| v.is_infinite()).count();
255    (nan_count, inf_count)
256}
257
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::{arr1, ArrayD};

    /// Build a 1-D dynamic-dimension array from a slice.
    fn arr_1d(values: &[f64]) -> ArrayD<f64> {
        arr1(values).into_dyn()
    }

    #[test]
    fn test_tolerance_default() {
        let tol = Tolerance::default();
        assert!((tol.rtol - 1e-5).abs() < 1e-20);
        assert!((tol.atol - 1e-8).abs() < 1e-20);
    }

    #[test]
    fn test_tolerance_is_close_true() {
        let tol = Tolerance::default();
        assert!(tol.is_close(1.0, 1.0 + 1e-9));
    }

    #[test]
    fn test_tolerance_is_close_false() {
        let tol = Tolerance::default();
        assert!(!tol.is_close(1.0, 2.0));
    }

    #[test]
    fn test_tolerance_strict() {
        let tol = Tolerance::strict();
        assert!((tol.rtol - 1e-12).abs() < 1e-20);
        assert!((tol.atol - 1e-15).abs() < 1e-20);
    }

    #[test]
    fn test_tolerance_loose() {
        let tol = Tolerance::loose();
        assert!((tol.rtol - 1e-3).abs() < 1e-20);
        assert!((tol.atol - 1e-6).abs() < 1e-20);
    }

    #[test]
    fn test_compare_identical() {
        let a = arr_1d(&[1.0, 2.0, 3.0]);
        let b = arr_1d(&[1.0, 2.0, 3.0]);
        let result = compare_tensors(&a, &b, &Tolerance::default()).expect("comparison failed");
        assert!(result.all_close);
        assert!((result.max_abs_diff - 0.0).abs() < 1e-20);
        assert_eq!(result.mismatch_count, 0);
    }

    #[test]
    fn test_compare_close() {
        let a = arr_1d(&[1.0, 2.0, 3.0]);
        let b = arr_1d(&[1.0 + 1e-9, 2.0 + 1e-9, 3.0 + 1e-9]);
        let result = compare_tensors(&a, &b, &Tolerance::default()).expect("comparison failed");
        assert!(result.all_close);
    }

    #[test]
    fn test_compare_different() {
        let a = arr_1d(&[1.0, 2.0, 3.0]);
        let b = arr_1d(&[1.0, 2.0, 100.0]);
        let result = compare_tensors(&a, &b, &Tolerance::default()).expect("comparison failed");
        assert!(!result.all_close);
        assert!(result.mismatch_count > 0);
    }

    #[test]
    fn test_compare_shape_mismatch() {
        let a = arr_1d(&[1.0, 2.0]);
        let b = arr_1d(&[1.0, 2.0, 3.0]);
        let result = compare_tensors(&a, &b, &Tolerance::default());
        assert!(result.is_err());
    }

    #[test]
    fn test_compare_shape_mismatch_reports_shapes() {
        let a = arr_1d(&[1.0, 2.0]);
        let b = arr_1d(&[1.0, 2.0, 3.0]);
        match compare_tensors(&a, &b, &Tolerance::default()) {
            Err(ComparisonError::ShapeMismatch(shape_a, shape_b)) => {
                assert_eq!(shape_a, vec![2]);
                assert_eq!(shape_b, vec![3]);
            }
            other => panic!("expected ShapeMismatch, got {other:?}"),
        }
    }

    #[test]
    fn test_compare_empty() {
        let a: ArrayD<f64> = ArrayD::zeros(vec![0]);
        let b: ArrayD<f64> = ArrayD::zeros(vec![0]);
        let result = compare_tensors(&a, &b, &Tolerance::default());
        assert!(result.is_err());
        assert!(matches!(result, Err(ComparisonError::EmptyTensors)));
    }

    #[test]
    fn test_compare_nan_both() {
        let a = arr_1d(&[f64::NAN, 1.0]);
        let b = arr_1d(&[f64::NAN, 1.0]);
        let result = compare_tensors(&a, &b, &Tolerance::default()).expect("comparison failed");
        assert!(result.all_close);
        assert_eq!(result.nan_mismatches, 0);
    }

    #[test]
    fn test_compare_nan_one() {
        let a = arr_1d(&[f64::NAN, 1.0]);
        let b = arr_1d(&[1.0, 1.0]);
        let result = compare_tensors(&a, &b, &Tolerance::default()).expect("comparison failed");
        assert!(!result.all_close);
        assert_eq!(result.nan_mismatches, 1);
    }

    #[test]
    fn test_compare_nan_mismatch_counts_toward_ratio() {
        let a = arr_1d(&[f64::NAN, 1.0, 2.0, 3.0]);
        let b = arr_1d(&[0.0, 1.0, 2.0, 3.0]);
        let result = compare_tensors(&a, &b, &Tolerance::default()).expect("comparison failed");
        assert_eq!(result.mismatch_count, 1);
        assert!((result.match_ratio() - 0.75).abs() < 1e-10);
    }

    #[test]
    fn test_compare_inf_matching() {
        let a = arr_1d(&[f64::INFINITY, 1.0]);
        let b = arr_1d(&[f64::INFINITY, 1.0]);
        let result = compare_tensors(&a, &b, &Tolerance::default()).expect("comparison failed");
        assert!(result.all_close);
        assert_eq!(result.inf_mismatches, 0);
    }

    #[test]
    fn test_compare_inf_opposite_signs() {
        let a = arr_1d(&[f64::INFINITY, 1.0]);
        let b = arr_1d(&[f64::NEG_INFINITY, 1.0]);
        let result = compare_tensors(&a, &b, &Tolerance::default()).expect("comparison failed");
        assert!(!result.all_close);
        assert_eq!(result.inf_mismatches, 1);
        assert_eq!(result.mismatch_count, 1);
    }

    #[test]
    fn test_compare_inf_vs_finite() {
        let a = arr_1d(&[f64::INFINITY, 1.0]);
        let b = arr_1d(&[1.0, 1.0]);
        let result = compare_tensors(&a, &b, &Tolerance::default()).expect("comparison failed");
        assert!(!result.all_close);
        assert_eq!(result.inf_mismatches, 1);
        assert_eq!(result.nan_mismatches, 0);
    }

    #[test]
    fn test_compare_max_diff_index() {
        let a = arr_1d(&[1.0, 2.0, 3.0]);
        let b = arr_1d(&[1.0, 5.0, 3.5]);
        let result = compare_tensors(&a, &b, &Tolerance::default()).expect("comparison failed");
        assert_eq!(result.max_diff_index, 1);
        assert!((result.max_abs_diff - 3.0).abs() < 1e-12);
        assert_eq!(result.mismatch_count, 2);
    }

    #[test]
    fn test_compare_2d_flattened_index() {
        let a: ArrayD<f64> = ArrayD::zeros(vec![2, 2]);
        let b = ArrayD::from_shape_vec(vec![2, 2], vec![0.0, 0.0, 1.0, 0.0])
            .expect("shape matches data length");
        let result = compare_tensors(&a, &b, &Tolerance::default()).expect("comparison failed");
        assert_eq!(result.total_elements, 4);
        assert_eq!(result.max_diff_index, 2);
        assert_eq!(result.mismatch_count, 1);
    }

    #[test]
    fn test_compare_match_ratio() {
        let a = arr_1d(&[1.0, 2.0, 3.0, 4.0]);
        let b = arr_1d(&[1.0, 2.0, 3.0, 100.0]);
        let result = compare_tensors(&a, &b, &Tolerance::default()).expect("comparison failed");
        assert!((result.match_ratio() - 0.75).abs() < 1e-10);
    }

    #[test]
    fn test_compare_summary() {
        let a = arr_1d(&[1.0, 2.0]);
        let b = arr_1d(&[1.0, 2.0]);
        let result = compare_tensors(&a, &b, &Tolerance::default()).expect("comparison failed");
        assert!(result.summary().contains("MATCH"));

        let c = arr_1d(&[1.0, 100.0]);
        let result2 = compare_tensors(&a, &c, &Tolerance::default()).expect("comparison failed");
        assert!(result2.summary().contains("MISMATCH"));
    }

    #[test]
    fn test_assert_tensors_close_passes() {
        let a = arr_1d(&[1.0, 2.0, 3.0]);
        let b = arr_1d(&[1.0, 2.0, 3.0]);
        assert_tensors_close(&a, &b, &Tolerance::default());
    }

    #[test]
    #[should_panic(expected = "Tensors not close")]
    fn test_assert_tensors_close_panics_on_mismatch() {
        let a = arr_1d(&[1.0, 2.0]);
        let b = arr_1d(&[1.0, 3.0]);
        assert_tensors_close(&a, &b, &Tolerance::default());
    }

    #[test]
    #[should_panic(expected = "Tensor comparison failed")]
    fn test_assert_tensors_close_panics_on_shape_mismatch() {
        let a = arr_1d(&[1.0, 2.0]);
        let b = arr_1d(&[1.0]);
        assert_tensors_close(&a, &b, &Tolerance::default());
    }

    #[test]
    fn test_abs_diff_values() {
        let a = arr_1d(&[1.0, -2.0, 3.0]);
        let b = arr_1d(&[2.0, 2.0, 3.0]);
        let diff = abs_diff(&a, &b).expect("shapes match");
        assert_eq!(diff, arr_1d(&[1.0, 4.0, 0.0]));
    }

    #[test]
    fn test_abs_diff_shape_mismatch() {
        let a = arr_1d(&[1.0]);
        let b = arr_1d(&[1.0, 2.0]);
        assert!(matches!(
            abs_diff(&a, &b),
            Err(ComparisonError::ShapeMismatch(..))
        ));
    }

    #[test]
    fn test_is_finite_true() {
        let a = arr_1d(&[1.0, 2.0, 3.0]);
        assert!(is_finite(&a));
    }

    #[test]
    fn test_is_finite_false() {
        assert!(!is_finite(&arr_1d(&[1.0, f64::NAN])));
        assert!(!is_finite(&arr_1d(&[f64::NEG_INFINITY])));
    }

    #[test]
    fn test_count_non_finite() {
        let a = arr_1d(&[1.0, f64::NAN, f64::INFINITY]);
        let (nan_count, inf_count) = count_non_finite(&a);
        assert_eq!(nan_count, 1);
        assert_eq!(inf_count, 1);
    }

    #[test]
    fn test_count_non_finite_both_infinities() {
        let a = arr_1d(&[f64::NEG_INFINITY, f64::INFINITY, f64::NAN, f64::NAN, 0.0]);
        assert_eq!(count_non_finite(&a), (2, 2));
    }
}