// tensorlogic_infer/backend_tests.rs

//! Backend compatibility test templates.
//!
//! This module provides comprehensive test templates that backend developers
//! can use to verify their implementations of the `TlExecutor` and `TlAutodiff`
//! traits.
//!
//! # Usage
//!
//! Backend developers should implement the `BackendTestAdapter` trait for their
//! executor and then use the provided test functions to validate correctness.
//!
//! ```ignore
//! use tensorlogic_infer::backend_tests::*;
//!
//! struct MyBackendAdapter;
//!
//! impl BackendTestAdapter for MyBackendAdapter {
//!     type Executor = MyExecutor;
//!     type Tensor = MyTensor;
//!
//!     fn create_executor() -> Self::Executor {
//!         MyExecutor::new()
//!     }
//!
//!     fn create_tensor_from_data(data: &[f64], shape: &[usize]) -> Self::Tensor {
//!         MyTensor::from_data(data, shape)
//!     }
//!
//!     fn tensor_to_vec(tensor: &Self::Tensor) -> Vec<f64> {
//!         tensor.to_vec()
//!     }
//!
//!     fn tensor_shape(tensor: &Self::Tensor) -> Vec<usize> {
//!         tensor.shape().to_vec()
//!     }
//! }
//!
//! // Run individual test templates
//! test_backend_elem_unary::<MyBackendAdapter>().unwrap();
//! test_backend_einsum::<MyBackendAdapter>().unwrap();
//! test_backend_forward::<MyBackendAdapter>().unwrap();
//! ```
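//!
//! To run the whole basic suite and get a printed summary, use the suite
//! runner (shown here with the same illustrative `MyBackendAdapter`):
//!
//! ```ignore
//! let results = run_all_basic_tests::<MyBackendAdapter>();
//! print_test_summary(&results);
//! assert!(results.iter().all(|(_, result)| result.is_ok()));
//! ```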

use crate::ops::{ElemOp, ReduceOp};
use crate::traits::{TlAutodiff, TlExecutor};

/// Adapter trait for backend testing.
///
/// Backend developers implement this trait to adapt their executor
/// to the test framework.
pub trait BackendTestAdapter {
    /// The executor type being tested
    type Executor: TlExecutor<Tensor = Self::Tensor>;
    /// The tensor type used by the executor
    type Tensor: Clone;

    /// Create a new executor instance for testing
    fn create_executor() -> Self::Executor;

    /// Create a tensor from raw data and shape
    fn create_tensor_from_data(data: &[f64], shape: &[usize]) -> Self::Tensor;

    /// Convert tensor to a flat vector for comparison
    fn tensor_to_vec(tensor: &Self::Tensor) -> Vec<f64>;

    /// Get the shape of a tensor
    fn tensor_shape(tensor: &Self::Tensor) -> Vec<usize>;

    /// Create a scalar tensor
    fn create_scalar(value: f64) -> Self::Tensor {
        Self::create_tensor_from_data(&[value], &[])
    }

    /// Create a 1D tensor (vector)
    fn create_vector(data: &[f64]) -> Self::Tensor {
        Self::create_tensor_from_data(data, &[data.len()])
    }

    /// Create a 2D tensor (matrix)
    fn create_matrix(data: &[f64], rows: usize, cols: usize) -> Self::Tensor {
        assert_eq!(data.len(), rows * cols);
        Self::create_tensor_from_data(data, &[rows, cols])
    }
}

/// Test result with optional failure message
pub type TestResult = Result<(), String>;

/// Tolerance for floating-point comparisons
pub const DEFAULT_TOLERANCE: f64 = 1e-6;
/// Compare two vectors element-wise with tolerance.
///
/// Two values are considered close if either their absolute difference or
/// their relative difference (scaled by the expected magnitude) is within
/// `tolerance`.
pub fn assert_vec_close(actual: &[f64], expected: &[f64], tolerance: f64) -> TestResult {
    if actual.len() != expected.len() {
        return Err(format!(
            "Length mismatch: got {}, expected {}",
            actual.len(),
            expected.len()
        ));
    }

    for (i, (a, e)) in actual.iter().zip(expected.iter()).enumerate() {
        let diff = (a - e).abs();
        if diff > tolerance && diff / (e.abs() + 1e-10) > tolerance {
            return Err(format!(
                "Value mismatch at index {}: got {}, expected {}, diff {}",
                i, a, e, diff
            ));
        }
    }

    Ok(())
}

//
// ===== BASIC OPERATION TESTS =====
//

/// Test basic element-wise unary operations
pub fn test_backend_elem_unary<A: BackendTestAdapter>() -> TestResult
where
    <A::Executor as TlExecutor>::Error: std::fmt::Debug,
{
    let mut executor = A::create_executor();

    // Test OneMinus
    let x = A::create_vector(&[1.0, 0.5, 0.0]);
    let result = executor
        .elem_op(ElemOp::OneMinus, &x)
        .map_err(|e| format!("OneMinus failed: {:?}", e))?;
    let output = A::tensor_to_vec(&result);
    assert_vec_close(&output, &[0.0, 0.5, 1.0], DEFAULT_TOLERANCE)?;

    // Test Relu
    let x = A::create_vector(&[-2.0, -1.0, 0.0, 1.0, 2.0]);
    let result = executor
        .elem_op(ElemOp::Relu, &x)
        .map_err(|e| format!("Relu failed: {:?}", e))?;
    let output = A::tensor_to_vec(&result);
    assert_vec_close(&output, &[0.0, 0.0, 0.0, 1.0, 2.0], DEFAULT_TOLERANCE)?;

    // Test Sigmoid
    let x = A::create_vector(&[0.0]);
    let result = executor
        .elem_op(ElemOp::Sigmoid, &x)
        .map_err(|e| format!("Sigmoid failed: {:?}", e))?;
    let output = A::tensor_to_vec(&result);
    assert_vec_close(&output, &[0.5], DEFAULT_TOLERANCE)?;

    Ok(())
}

/// Test basic element-wise binary operations
pub fn test_backend_elem_binary<A: BackendTestAdapter>() -> TestResult
where
    <A::Executor as TlExecutor>::Error: std::fmt::Debug,
{
    let mut executor = A::create_executor();

    // Test Add
    let x = A::create_vector(&[1.0, 2.0, 3.0]);
    let y = A::create_vector(&[4.0, 5.0, 6.0]);
    let result = executor
        .elem_op_binary(ElemOp::Add, &x, &y)
        .map_err(|e| format!("Add failed: {:?}", e))?;
    let output = A::tensor_to_vec(&result);
    assert_vec_close(&output, &[5.0, 7.0, 9.0], DEFAULT_TOLERANCE)?;

    // Test Multiply
    let result = executor
        .elem_op_binary(ElemOp::Multiply, &x, &y)
        .map_err(|e| format!("Multiply failed: {:?}", e))?;
    let output = A::tensor_to_vec(&result);
    assert_vec_close(&output, &[4.0, 10.0, 18.0], DEFAULT_TOLERANCE)?;

    // Test Subtract
    let result = executor
        .elem_op_binary(ElemOp::Subtract, &x, &y)
        .map_err(|e| format!("Subtract failed: {:?}", e))?;
    let output = A::tensor_to_vec(&result);
    assert_vec_close(&output, &[-3.0, -3.0, -3.0], DEFAULT_TOLERANCE)?;

    // Test Divide
    let result = executor
        .elem_op_binary(ElemOp::Divide, &y, &x)
        .map_err(|e| format!("Divide failed: {:?}", e))?;
    let output = A::tensor_to_vec(&result);
    assert_vec_close(&output, &[4.0, 2.5, 2.0], DEFAULT_TOLERANCE)?;

    Ok(())
}

/// Test reduction operations
pub fn test_backend_reduce<A: BackendTestAdapter>() -> TestResult
where
    <A::Executor as TlExecutor>::Error: std::fmt::Debug,
{
    let mut executor = A::create_executor();

    // Test Sum reduction
    let x = A::create_matrix(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], 2, 3);

    // Sum over axis 0
    let result = executor
        .reduce(ReduceOp::Sum, &x, &[0])
        .map_err(|e| format!("Sum reduce failed: {:?}", e))?;
    let output = A::tensor_to_vec(&result);
    assert_vec_close(&output, &[5.0, 7.0, 9.0], DEFAULT_TOLERANCE)?;

    // Sum over axis 1
    let result = executor
        .reduce(ReduceOp::Sum, &x, &[1])
        .map_err(|e| format!("Sum reduce failed: {:?}", e))?;
    let output = A::tensor_to_vec(&result);
    assert_vec_close(&output, &[6.0, 15.0], DEFAULT_TOLERANCE)?;

    // Test Max reduction
    let x = A::create_vector(&[1.0, 5.0, 3.0, 2.0]);
    let result = executor
        .reduce(ReduceOp::Max, &x, &[0])
        .map_err(|e| format!("Max reduce failed: {:?}", e))?;
    let output = A::tensor_to_vec(&result);
    assert_vec_close(&output, &[5.0], DEFAULT_TOLERANCE)?;

    // Test Mean reduction
    let x = A::create_vector(&[2.0, 4.0, 6.0, 8.0]);
    let result = executor
        .reduce(ReduceOp::Mean, &x, &[0])
        .map_err(|e| format!("Mean reduce failed: {:?}", e))?;
    let output = A::tensor_to_vec(&result);
    assert_vec_close(&output, &[5.0], DEFAULT_TOLERANCE)?;

    Ok(())
}

/// Test einsum operations
pub fn test_backend_einsum<A: BackendTestAdapter>() -> TestResult
where
    <A::Executor as TlExecutor>::Error: std::fmt::Debug,
{
    let mut executor = A::create_executor();

    // Test vector dot product: "i,i->"
    let a = A::create_vector(&[1.0, 2.0, 3.0]);
    let b = A::create_vector(&[4.0, 5.0, 6.0]);
    let result = executor
        .einsum("i,i->", &[a.clone(), b.clone()])
        .map_err(|e| format!("Einsum dot product failed: {:?}", e))?;
    let output = A::tensor_to_vec(&result);
    assert_vec_close(&output, &[32.0], DEFAULT_TOLERANCE)?; // 1*4 + 2*5 + 3*6 = 32

    // Test matrix-vector multiply: "ij,j->i"
    let mat = A::create_matrix(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], 2, 3);
    let vec = A::create_vector(&[1.0, 2.0, 3.0]);
    let result = executor
        .einsum("ij,j->i", &[mat, vec])
        .map_err(|e| format!("Einsum matvec failed: {:?}", e))?;
    let output = A::tensor_to_vec(&result);
    assert_vec_close(&output, &[14.0, 32.0], DEFAULT_TOLERANCE)?;

    // Test matrix-matrix multiply: "ij,jk->ik"
    let a = A::create_matrix(&[1.0, 2.0, 3.0, 4.0], 2, 2);
    let b = A::create_matrix(&[5.0, 6.0, 7.0, 8.0], 2, 2);
    let result = executor
        .einsum("ij,jk->ik", &[a, b])
        .map_err(|e| format!("Einsum matmul failed: {:?}", e))?;
    let output = A::tensor_to_vec(&result);
    assert_vec_close(&output, &[19.0, 22.0, 43.0, 50.0], DEFAULT_TOLERANCE)?;

    Ok(())
}
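
/// Supplementary einsum template (optional): batched matrix multiply "bij,bjk->bik".
///
/// This is a sketch beyond the core templates above; backends that do not
/// support higher-rank einsum specs may simply skip it.
pub fn test_backend_einsum_batched<A: BackendTestAdapter>() -> TestResult
where
    <A::Executor as TlExecutor>::Error: std::fmt::Debug,
{
    let mut executor = A::create_executor();

    // Two batches of 2x2 matrices: batch 0 is a general matmul, batch 1
    // multiplies the identity by a diagonal matrix.
    let a = A::create_tensor_from_data(&[1.0, 2.0, 3.0, 4.0, 1.0, 0.0, 0.0, 1.0], &[2, 2, 2]);
    let b = A::create_tensor_from_data(&[5.0, 6.0, 7.0, 8.0, 2.0, 0.0, 0.0, 2.0], &[2, 2, 2]);
    let result = executor
        .einsum("bij,bjk->bik", &[a, b])
        .map_err(|e| format!("Einsum batched matmul failed: {:?}", e))?;
    let output = A::tensor_to_vec(&result);
    assert_vec_close(
        &output,
        &[19.0, 22.0, 43.0, 50.0, 2.0, 0.0, 0.0, 2.0],
        DEFAULT_TOLERANCE,
    )?;

    Ok(())
}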

//
// ===== AUTODIFF TESTS =====
//

/// Test forward pass execution
///
/// Note: This is a placeholder test. Backend developers should implement
/// their own forward pass tests based on their specific graph execution
/// requirements and tensor injection mechanisms.
pub fn test_backend_forward<A>() -> TestResult
where
    A: BackendTestAdapter,
    A::Executor: TlAutodiff<Tensor = A::Tensor>,
    <A::Executor as TlExecutor>::Error: std::fmt::Debug,
{
    // This is a simplified test that backend developers should customize
    // for their specific implementation. The test validates that the
    // forward pass can be called without panicking.

    // Backend-specific graph construction and execution should go here
    Ok(())
}

//
// ===== NUMERICAL STABILITY TESTS =====
//
/// Test handling of edge cases (NaN, Inf, zeros)
pub fn test_backend_edge_cases<A: BackendTestAdapter>() -> TestResult
where
    <A::Executor as TlExecutor>::Error: std::fmt::Debug,
{
    let mut executor = A::create_executor();

    // Test division by zero handling
    let x = A::create_vector(&[1.0, 2.0, 3.0]);
    let y = A::create_vector(&[1.0, 0.0, 3.0]);
    let result = executor.elem_op_binary(ElemOp::Divide, &x, &y);

    // The backend may either produce Inf/NaN or return an error - both are acceptable
    match result {
        Ok(tensor) => {
            let output = A::tensor_to_vec(&tensor);
            assert_eq!(output.len(), 3);
            assert!((output[0] - 1.0).abs() < DEFAULT_TOLERANCE);
            assert!(output[1].is_infinite() || output[1].is_nan());
        }
        Err(_) => {
            // Erroring on division by zero is also acceptable
        }
    }

    // Test Relu with very large values
    let x = A::create_vector(&[1e10, -1e10, 0.0]);
    let result = executor
        .elem_op(ElemOp::Relu, &x)
        .map_err(|e| format!("Relu with large values failed: {:?}", e))?;
    let output = A::tensor_to_vec(&result);
    assert_vec_close(&output, &[1e10, 0.0, 0.0], 1e4)?;

    Ok(())
}

//
// ===== SHAPE HANDLING TESTS =====
//

/// Test various tensor shapes
pub fn test_backend_shapes<A: BackendTestAdapter>() -> TestResult
where
    <A::Executor as TlExecutor>::Error: std::fmt::Debug,
{
    let mut executor = A::create_executor();

    // Test scalar operations
    let scalar1 = A::create_scalar(5.0);
    let scalar2 = A::create_scalar(3.0);
    let result = executor
        .elem_op_binary(ElemOp::Add, &scalar1, &scalar2)
        .map_err(|e| format!("Scalar add failed: {:?}", e))?;
    let output = A::tensor_to_vec(&result);
    assert_vec_close(&output, &[8.0], DEFAULT_TOLERANCE)?;

    // Broadcasting is intentionally not exercised here: backends may choose
    // not to support it, so any broadcasting test should be optional.

    // Test reduction with an empty axis list (reduce over all axes)
    let x = A::create_vector(&[1.0, 2.0, 3.0]);
    let result = executor
        .reduce(ReduceOp::Sum, &x, &[])
        .map_err(|e| format!("Empty axes reduce failed: {:?}", e))?;
    let output = A::tensor_to_vec(&result);
    assert_vec_close(&output, &[6.0], DEFAULT_TOLERANCE)?;

    Ok(())
}

//
// ===== PERFORMANCE/STRESS TESTS =====
//

/// Test performance with large tensors
pub fn test_backend_large_tensors<A: BackendTestAdapter>() -> TestResult
where
    <A::Executor as TlExecutor>::Error: std::fmt::Debug,
{
    let mut executor = A::create_executor();

    // Create large vectors
    let size = 10000;
    let data1: Vec<f64> = (0..size).map(|i| i as f64).collect();
    let data2: Vec<f64> = (0..size).map(|i| (size - i) as f64).collect();

    let x = A::create_vector(&data1);
    let y = A::create_vector(&data2);

    // Test large vector addition
    let result = executor
        .elem_op_binary(ElemOp::Add, &x, &y)
        .map_err(|e| format!("Large vector add failed: {:?}", e))?;

    let output = A::tensor_to_vec(&result);
    assert_eq!(output.len(), size);

    // Verify a few values: data1[i] + data2[i] = i + (size - i) = 10000 for every i
    assert_vec_close(
        &output[0..3],
        &[10000.0, 10000.0, 10000.0],
        DEFAULT_TOLERANCE,
    )?;

    Ok(())
}

/// Test memory efficiency with repeated operations
pub fn test_backend_memory_efficiency<A: BackendTestAdapter>() -> TestResult
where
    <A::Executor as TlExecutor>::Error: std::fmt::Debug,
{
    let mut executor = A::create_executor();

    // Perform many operations to test memory management
    let mut x = A::create_vector(&[1.0, 2.0, 3.0]);

    for i in 0..100 {
        let y = A::create_scalar((i + 1) as f64);
        x = executor
            .elem_op_binary(ElemOp::Add, &x, &y)
            .map_err(|e| format!("Memory efficiency test failed at iteration {}: {:?}", i, e))?;
    }

    // Each element has accumulated 1 + 2 + ... + 100 = 100 * 101 / 2 = 5050,
    // so the initial [1.0, 2.0, 3.0] becomes [5051.0, 5052.0, 5053.0].
    let output = A::tensor_to_vec(&x);
    assert_vec_close(&output, &[5051.0, 5052.0, 5053.0], DEFAULT_TOLERANCE)?;

    Ok(())
}

//
// ===== GRADIENT CHECKING =====
//

/// Compute a numerical gradient using central finite differences.
///
/// The function `f` is assumed to produce a scalar (single-element) output;
/// only the first element of its result is used.
pub fn numerical_gradient<A, F>(f: F, x: &A::Tensor, epsilon: f64) -> Vec<f64>
where
    A: BackendTestAdapter,
    F: Fn(&A::Tensor) -> A::Tensor,
{
    let x_vec = A::tensor_to_vec(x);
    let shape = A::tensor_shape(x);
    let mut grad = vec![0.0; x_vec.len()];

    for i in 0..x_vec.len() {
        // Compute f(x + epsilon)
        let mut x_plus = x_vec.clone();
        x_plus[i] += epsilon;
        let x_plus_tensor = A::create_tensor_from_data(&x_plus, &shape);
        let f_plus = A::tensor_to_vec(&f(&x_plus_tensor));

        // Compute f(x - epsilon)
        let mut x_minus = x_vec.clone();
        x_minus[i] -= epsilon;
        let x_minus_tensor = A::create_tensor_from_data(&x_minus, &shape);
        let f_minus = A::tensor_to_vec(&f(&x_minus_tensor));

        // Central difference: (f(x+eps) - f(x-eps)) / (2*eps)
        grad[i] = (f_plus[0] - f_minus[0]) / (2.0 * epsilon);
    }

    grad
}
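
/// Compare an analytic gradient against the numerical gradient from
/// [`numerical_gradient`].
///
/// This is a small convenience sketch rather than part of the core templates:
/// it assumes the backend can expose its analytic gradient as a flat
/// `Vec<f64>` (how that is obtained is backend-specific).
pub fn check_gradient<A, F>(
    f: F,
    x: &A::Tensor,
    analytic_grad: &[f64],
    epsilon: f64,
    tolerance: f64,
) -> TestResult
where
    A: BackendTestAdapter,
    F: Fn(&A::Tensor) -> A::Tensor,
{
    // Numerically differentiate f at x and compare element-wise.
    let numeric_grad = numerical_gradient::<A, F>(f, x, epsilon);
    assert_vec_close(analytic_grad, &numeric_grad, tolerance)
}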

//
// ===== COMPREHENSIVE TEST SUITE =====
//

/// Run all basic operation tests
pub fn run_all_basic_tests<A: BackendTestAdapter>() -> Vec<(String, TestResult)>
where
    <A::Executor as TlExecutor>::Error: std::fmt::Debug,
{
    vec![
        ("elem_unary".to_string(), test_backend_elem_unary::<A>()),
        ("elem_binary".to_string(), test_backend_elem_binary::<A>()),
        ("reduce".to_string(), test_backend_reduce::<A>()),
        ("einsum".to_string(), test_backend_einsum::<A>()),
        ("edge_cases".to_string(), test_backend_edge_cases::<A>()),
        ("shapes".to_string(), test_backend_shapes::<A>()),
    ]
}

/// Run all performance tests
pub fn run_all_performance_tests<A: BackendTestAdapter>() -> Vec<(String, TestResult)>
where
    <A::Executor as TlExecutor>::Error: std::fmt::Debug,
{
    vec![
        (
            "large_tensors".to_string(),
            test_backend_large_tensors::<A>(),
        ),
        (
            "memory_efficiency".to_string(),
            test_backend_memory_efficiency::<A>(),
        ),
    ]
}
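
/// Run the basic and performance suites together.
///
/// A convenience sketch layered on the two runners above; backends can just
/// as well call them individually.
pub fn run_all_tests<A: BackendTestAdapter>() -> Vec<(String, TestResult)>
where
    <A::Executor as TlExecutor>::Error: std::fmt::Debug,
{
    let mut results = run_all_basic_tests::<A>();
    results.extend(run_all_performance_tests::<A>());
    results
}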

/// Print test results summary
pub fn print_test_summary(results: &[(String, TestResult)]) {
    println!("\n=== Backend Test Results ===");
    let mut passed = 0;
    let mut failed = 0;

    for (name, result) in results {
        match result {
            Ok(()) => {
                println!("✓ {}", name);
                passed += 1;
            }
            Err(msg) => {
                println!("✗ {} - {}", name, msg);
                failed += 1;
            }
        }
    }

    println!("\nPassed: {}, Failed: {}", passed, failed);
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_assert_vec_close() {
        assert!(assert_vec_close(&[1.0, 2.0, 3.0], &[1.0, 2.0, 3.0], 1e-10).is_ok());
        assert!(assert_vec_close(&[1.0, 2.0, 3.0], &[1.0, 2.0, 3.1], 0.2).is_ok());
        assert!(assert_vec_close(&[1.0, 2.0, 3.0], &[1.0, 2.0, 3.1], 0.01).is_err());
        assert!(assert_vec_close(&[1.0, 2.0], &[1.0, 2.0, 3.0], 1e-10).is_err());
    }

    #[test]
    fn test_default_tolerance() {
        // Verify tolerance is sensible (not a compile-time constant check)
        let tolerance = DEFAULT_TOLERANCE;
        let max_tolerance = 1e-5;
        assert!(tolerance > 0.0);
        assert!(tolerance < max_tolerance);
    }
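
    // Extra illustration of the combined absolute/relative comparison in
    // assert_vec_close: a 0.5 absolute difference on a value around 1e6 is
    // only ~5e-7 relative, so it passes at DEFAULT_TOLERANCE, while an
    // equally large relative difference fails.
    #[test]
    fn test_assert_vec_close_relative_tolerance() {
        assert!(assert_vec_close(&[1_000_000.5], &[1_000_000.0], DEFAULT_TOLERANCE).is_ok());
        assert!(assert_vec_close(&[2.0], &[1.0], DEFAULT_TOLERANCE).is_err());
    }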
539}