tensorlogic_scirs_backend/
comparison.rs1use scirs2_core::ndarray::ArrayD;
7use thiserror::Error;
8
9#[derive(Debug, Error)]
11pub enum ComparisonError {
12 #[error("Shape mismatch: {0:?} vs {1:?}")]
14 ShapeMismatch(Vec<usize>, Vec<usize>),
15 #[error("Empty tensors")]
17 EmptyTensors,
18}
19
/// Relative/absolute tolerance pair for approximate `f64` comparison.
///
/// Two finite values `a` and `b` are close when
/// `|a - b| <= atol + rtol * |b|`; see [`Tolerance::is_close`].
#[derive(Debug, Clone)]
pub struct Tolerance {
    /// Relative tolerance, scaled by the magnitude of the second operand.
    pub rtol: f64,
    /// Absolute tolerance, applied regardless of magnitude.
    pub atol: f64,
}

impl Default for Tolerance {
    /// Returns `rtol = 1e-5`, `atol = 1e-8`.
    fn default() -> Self {
        Self::new(1e-5, 1e-8)
    }
}

impl Tolerance {
    /// Creates a tolerance from explicit relative and absolute bounds.
    pub fn new(rtol: f64, atol: f64) -> Self {
        Self { rtol, atol }
    }

    /// Tight preset: `rtol = 1e-12`, `atol = 1e-15`.
    pub fn strict() -> Self {
        Self::new(1e-12, 1e-15)
    }

    /// Relaxed preset: `rtol = 1e-3`, `atol = 1e-6`.
    pub fn loose() -> Self {
        Self::new(1e-3, 1e-6)
    }

    /// Returns `true` when `a` is within tolerance of `b`.
    ///
    /// The test is asymmetric: the relative term is scaled by `|b|`, so `b`
    /// acts as the reference value.
    ///
    /// Non-finite operands are compared exactly. Same-signed infinities are
    /// close to each other, an infinity is never close to a finite value or
    /// to the opposite infinity, and NaN is never close to anything
    /// (including another NaN).
    pub fn is_close(&self, a: f64, b: f64) -> bool {
        if !(a.is_finite() && b.is_finite()) {
            // The tolerance formula breaks down here: `rtol * inf` is `inf`,
            // which would accept any finite `a` against an infinite `b`, and
            // `inf - inf` is NaN, which would reject equal infinities.
            // IEEE equality gives the intended answer (NaN != NaN).
            return a == b;
        }
        (a - b).abs() <= self.atol + self.rtol * b.abs()
    }
}
68
/// Statistics describing an element-wise comparison of two tensors.
#[derive(Debug, Clone)]
pub struct ComparisonResult {
    /// `true` when no element was counted as a mismatch.
    pub all_close: bool,
    /// Largest absolute difference observed.
    pub max_abs_diff: f64,
    /// Mean absolute difference.
    pub mean_abs_diff: f64,
    /// Largest relative difference observed.
    pub max_rel_diff: f64,
    /// Number of elements counted as mismatches.
    pub mismatch_count: usize,
    /// Total number of elements compared.
    pub total_elements: usize,
    /// Position (in iteration order) of the element with the largest
    /// absolute difference.
    pub max_diff_index: usize,
    /// Number of mismatches attributed to NaN values.
    pub nan_mismatches: usize,
    /// Number of mismatches attributed to infinite values.
    pub inf_mismatches: usize,
}

impl ComparisonResult {
    /// Fraction of elements that matched, in `[0.0, 1.0]`.
    ///
    /// Returns `1.0` when `total_elements` is zero. Because the fields are
    /// public, `mismatch_count` may exceed `total_elements` in a hand-built
    /// value; the subtraction saturates so the ratio is `0.0` in that case
    /// instead of underflowing.
    pub fn match_ratio(&self) -> f64 {
        if self.total_elements == 0 {
            return 1.0;
        }
        let matched = self.total_elements.saturating_sub(self.mismatch_count);
        matched as f64 / self.total_elements as f64
    }

    /// Renders a one-line, human-readable description of the result.
    ///
    /// The line starts with `MATCH` when `all_close` is set and with
    /// `MISMATCH` otherwise.
    pub fn summary(&self) -> String {
        if !self.all_close {
            return format!(
                "MISMATCH: {}/{} elements differ, max_diff={:.2e}, mean_diff={:.2e}",
                self.mismatch_count, self.total_elements, self.max_abs_diff, self.mean_abs_diff
            );
        }
        format!(
            "MATCH: {} elements, max_diff={:.2e}",
            self.total_elements, self.max_abs_diff
        )
    }
}
117
118pub fn compare_tensors(
126 a: &ArrayD<f64>,
127 b: &ArrayD<f64>,
128 tol: &Tolerance,
129) -> Result<ComparisonResult, ComparisonError> {
130 if a.shape() != b.shape() {
131 return Err(ComparisonError::ShapeMismatch(
132 a.shape().to_vec(),
133 b.shape().to_vec(),
134 ));
135 }
136 if a.is_empty() {
137 return Err(ComparisonError::EmptyTensors);
138 }
139
140 let mut max_abs_diff = 0.0_f64;
141 let mut sum_abs_diff = 0.0_f64;
142 let mut max_rel_diff = 0.0_f64;
143 let mut mismatch_count = 0_usize;
144 let mut max_diff_index = 0_usize;
145 let mut nan_mismatches = 0_usize;
146 let mut inf_mismatches = 0_usize;
147
148 for (i, (&va, &vb)) in a.iter().zip(b.iter()).enumerate() {
149 if va.is_nan() != vb.is_nan() {
151 nan_mismatches += 1;
152 mismatch_count += 1;
153 continue;
154 }
155 if va.is_nan() && vb.is_nan() {
156 continue;
158 }
159
160 if va.is_infinite() != vb.is_infinite() {
162 inf_mismatches += 1;
163 mismatch_count += 1;
164 continue;
165 }
166 if va.is_infinite() && vb.is_infinite() {
167 if va.signum() == vb.signum() {
168 continue;
170 }
171 inf_mismatches += 1;
173 mismatch_count += 1;
174 continue;
175 }
176
177 let abs_diff = (va - vb).abs();
179 sum_abs_diff += abs_diff;
180
181 if abs_diff > max_abs_diff {
182 max_abs_diff = abs_diff;
183 max_diff_index = i;
184 }
185
186 let rel_diff = if vb.abs() > 1e-15 {
187 abs_diff / vb.abs()
188 } else {
189 abs_diff
190 };
191 if rel_diff > max_rel_diff {
192 max_rel_diff = rel_diff;
193 }
194
195 if !tol.is_close(va, vb) {
196 mismatch_count += 1;
197 }
198 }
199
200 let total = a.len();
201 Ok(ComparisonResult {
202 all_close: mismatch_count == 0,
203 max_abs_diff,
204 mean_abs_diff: sum_abs_diff / total as f64,
205 max_rel_diff,
206 mismatch_count,
207 total_elements: total,
208 max_diff_index,
209 nan_mismatches,
210 inf_mismatches,
211 })
212}
213
214pub fn assert_tensors_close(a: &ArrayD<f64>, b: &ArrayD<f64>, tol: &Tolerance) {
218 match compare_tensors(a, b, tol) {
219 Ok(result) if result.all_close => {}
220 Ok(result) => panic!(
221 "Tensors not close: {}\nMax diff at index {}: {:.2e}",
222 result.summary(),
223 result.max_diff_index,
224 result.max_abs_diff
225 ),
226 Err(e) => panic!("Tensor comparison failed: {e}"),
227 }
228}
229
230pub fn abs_diff(a: &ArrayD<f64>, b: &ArrayD<f64>) -> Result<ArrayD<f64>, ComparisonError> {
234 if a.shape() != b.shape() {
235 return Err(ComparisonError::ShapeMismatch(
236 a.shape().to_vec(),
237 b.shape().to_vec(),
238 ));
239 }
240 let diff = a - b;
241 Ok(diff.mapv(f64::abs))
242}
243
244pub fn is_finite(tensor: &ArrayD<f64>) -> bool {
246 tensor.iter().all(|v| v.is_finite())
247}
248
249pub fn count_non_finite(tensor: &ArrayD<f64>) -> (usize, usize) {
253 let nan_count = tensor.iter().filter(|v| v.is_nan()).count();
254 let inf_count = tensor.iter().filter(|v| v.is_infinite()).count();
255 (nan_count, inf_count)
256}
257
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::{arr1, ArrayD};

    /// Builds a dynamic-dimensional 1-D tensor from a slice.
    fn arr_1d(values: &[f64]) -> ArrayD<f64> {
        arr1(values).into_dyn()
    }

    /// Compares two 1-D tensors with the default tolerance, failing the test
    /// if the comparison itself returns an error.
    fn compare_default(lhs: &[f64], rhs: &[f64]) -> ComparisonResult {
        compare_tensors(&arr_1d(lhs), &arr_1d(rhs), &Tolerance::default())
            .expect("comparison failed")
    }

    /// Checks that a tolerance carries the expected `rtol` / `atol` values.
    fn assert_tolerance(tol: &Tolerance, rtol: f64, atol: f64) {
        assert!((tol.rtol - rtol).abs() < 1e-20);
        assert!((tol.atol - atol).abs() < 1e-20);
    }

    #[test]
    fn test_tolerance_default() {
        assert_tolerance(&Tolerance::default(), 1e-5, 1e-8);
    }

    #[test]
    fn test_tolerance_is_close_true() {
        assert!(Tolerance::default().is_close(1.0, 1.0 + 1e-9));
    }

    #[test]
    fn test_tolerance_is_close_false() {
        assert!(!Tolerance::default().is_close(1.0, 2.0));
    }

    #[test]
    fn test_tolerance_strict() {
        assert_tolerance(&Tolerance::strict(), 1e-12, 1e-15);
    }

    #[test]
    fn test_tolerance_loose() {
        assert_tolerance(&Tolerance::loose(), 1e-3, 1e-6);
    }

    #[test]
    fn test_compare_identical() {
        let outcome = compare_default(&[1.0, 2.0, 3.0], &[1.0, 2.0, 3.0]);
        assert!(outcome.all_close);
        assert!((outcome.max_abs_diff - 0.0).abs() < 1e-20);
        assert_eq!(outcome.mismatch_count, 0);
    }

    #[test]
    fn test_compare_close() {
        let outcome = compare_default(&[1.0, 2.0, 3.0], &[1.0 + 1e-9, 2.0 + 1e-9, 3.0 + 1e-9]);
        assert!(outcome.all_close);
    }

    #[test]
    fn test_compare_different() {
        let outcome = compare_default(&[1.0, 2.0, 3.0], &[1.0, 2.0, 100.0]);
        assert!(!outcome.all_close);
        assert!(outcome.mismatch_count > 0);
    }

    #[test]
    fn test_compare_shape_mismatch() {
        let short = arr_1d(&[1.0, 2.0]);
        let long = arr_1d(&[1.0, 2.0, 3.0]);
        assert!(compare_tensors(&short, &long, &Tolerance::default()).is_err());
    }

    #[test]
    fn test_compare_empty() {
        let lhs: ArrayD<f64> = ArrayD::zeros(vec![0]);
        let rhs: ArrayD<f64> = ArrayD::zeros(vec![0]);
        assert!(compare_tensors(&lhs, &rhs, &Tolerance::default()).is_err());
    }

    #[test]
    fn test_compare_nan_both() {
        let outcome = compare_default(&[f64::NAN, 1.0], &[f64::NAN, 1.0]);
        assert!(outcome.all_close);
        assert_eq!(outcome.nan_mismatches, 0);
    }

    #[test]
    fn test_compare_nan_one() {
        let outcome = compare_default(&[f64::NAN, 1.0], &[1.0, 1.0]);
        assert!(!outcome.all_close);
        assert_eq!(outcome.nan_mismatches, 1);
    }

    #[test]
    fn test_compare_inf_matching() {
        let outcome = compare_default(&[f64::INFINITY, 1.0], &[f64::INFINITY, 1.0]);
        assert!(outcome.all_close);
        assert_eq!(outcome.inf_mismatches, 0);
    }

    #[test]
    fn test_compare_match_ratio() {
        let outcome = compare_default(&[1.0, 2.0, 3.0, 4.0], &[1.0, 2.0, 3.0, 100.0]);
        assert!((outcome.match_ratio() - 0.75).abs() < 1e-10);
    }

    #[test]
    fn test_compare_summary() {
        let matching = compare_default(&[1.0, 2.0], &[1.0, 2.0]);
        assert!(matching.summary().contains("MATCH"));

        let differing = compare_default(&[1.0, 2.0], &[1.0, 100.0]);
        assert!(differing.summary().contains("MISMATCH"));
    }

    #[test]
    fn test_assert_tensors_close_passes() {
        let lhs = arr_1d(&[1.0, 2.0, 3.0]);
        let rhs = arr_1d(&[1.0, 2.0, 3.0]);
        assert_tensors_close(&lhs, &rhs, &Tolerance::default());
    }

    #[test]
    fn test_is_finite_true() {
        assert!(is_finite(&arr_1d(&[1.0, 2.0, 3.0])));
    }

    #[test]
    fn test_count_non_finite() {
        let tensor = arr_1d(&[1.0, f64::NAN, f64::INFINITY]);
        let (nan_count, inf_count) = count_non_finite(&tensor);
        assert_eq!(nan_count, 1);
        assert_eq!(inf_count, 1);
    }
}