use crate::ops::{ElemOp, ReduceOp};
use crate::traits::{TlAutodiff, TlExecutor};
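
/// Adapter that exposes a backend's executor and tensor types to the shared
/// conformance tests in this module.
///
/// A minimal sketch of an implementation, assuming a hypothetical backend with
/// `CpuExecutor` and `CpuTensor` types (the names are illustrative and not part
/// of this crate):
///
/// ```rust,ignore
/// struct CpuAdapter;
///
/// impl BackendTestAdapter for CpuAdapter {
///     type Executor = CpuExecutor;
///     type Tensor = CpuTensor;
///
///     fn create_executor() -> Self::Executor {
///         CpuExecutor::new()
///     }
///
///     fn create_tensor_from_data(data: &[f64], shape: &[usize]) -> Self::Tensor {
///         CpuTensor::from_slice(data, shape)
///     }
///
///     fn tensor_to_vec(tensor: &Self::Tensor) -> Vec<f64> {
///         tensor.to_vec()
///     }
///
///     fn tensor_shape(tensor: &Self::Tensor) -> Vec<usize> {
///         tensor.shape().to_vec()
///     }
/// }
/// ```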
pub trait BackendTestAdapter {
    /// Executor under test; its tensor type must match [`Self::Tensor`].
    type Executor: TlExecutor<Tensor = Self::Tensor>;
    /// Backend tensor type.
    type Tensor: Clone;

    /// Construct a fresh executor for a test run.
    fn create_executor() -> Self::Executor;

    /// Build a tensor from flat data and a shape.
    fn create_tensor_from_data(data: &[f64], shape: &[usize]) -> Self::Tensor;

    /// Read a tensor back as a flat `Vec<f64>` (row-major, as the conformance
    /// tests below assume).
    fn tensor_to_vec(tensor: &Self::Tensor) -> Vec<f64>;

    /// Return the shape of a tensor.
    fn tensor_shape(tensor: &Self::Tensor) -> Vec<usize>;

    /// Build a rank-0 (scalar) tensor.
    fn create_scalar(value: f64) -> Self::Tensor {
        Self::create_tensor_from_data(&[value], &[])
    }

    /// Build a rank-1 tensor from a slice.
    fn create_vector(data: &[f64]) -> Self::Tensor {
        Self::create_tensor_from_data(data, &[data.len()])
    }

    /// Build a `rows x cols` matrix from row-major data.
    fn create_matrix(data: &[f64], rows: usize, cols: usize) -> Self::Tensor {
        assert_eq!(data.len(), rows * cols);
        Self::create_tensor_from_data(data, &[rows, cols])
    }
}

/// Result type used by the backend conformance tests.
pub type TestResult = Result<(), String>;

/// Default tolerance for floating-point comparisons in these tests.
pub const DEFAULT_TOLERANCE: f64 = 1e-6;
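
/// Compares two slices element-wise. `tolerance` is applied as both an absolute
/// and a relative bound: an element only fails if its error exceeds both, and
/// the `1e-10` term guards the relative check against division by zero.
///
/// A small illustration (values chosen here for exposition, not taken from the
/// test suite):
///
/// ```rust,ignore
/// // 0.05 absolute error but only ~5e-5 relative error: passes at 1e-3.
/// assert!(assert_vec_close(&[1000.05], &[1000.0], 1e-3).is_ok());
/// // Both the absolute and the relative error exceed 1e-3: fails.
/// assert!(assert_vec_close(&[1.01], &[1.0], 1e-3).is_err());
/// ```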
pub fn assert_vec_close(actual: &[f64], expected: &[f64], tolerance: f64) -> TestResult {
    if actual.len() != expected.len() {
        return Err(format!(
            "Length mismatch: got {}, expected {}",
            actual.len(),
            expected.len()
        ));
    }

    for (i, (a, e)) in actual.iter().zip(expected.iter()).enumerate() {
        let diff = (a - e).abs();
        if diff > tolerance && diff / (e.abs() + 1e-10) > tolerance {
            return Err(format!(
                "Value mismatch at index {}: got {}, expected {}, diff {}",
                i, a, e, diff
            ));
        }
    }

    Ok(())
}
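
/// Exercises unary element-wise ops (`OneMinus`, `Relu`, `Sigmoid`) against
/// known outputs.
///
/// Each conformance test is generic over an adapter; a backend would typically
/// wire it up from its own test module. A sketch, assuming the hypothetical
/// `CpuAdapter` from the trait example above:
///
/// ```rust,ignore
/// #[test]
/// fn cpu_elem_unary() {
///     test_backend_elem_unary::<CpuAdapter>().unwrap();
/// }
/// ```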
pub fn test_backend_elem_unary<A: BackendTestAdapter>() -> TestResult
where
    <A::Executor as TlExecutor>::Error: std::fmt::Debug,
{
    let mut executor = A::create_executor();

    let x = A::create_vector(&[1.0, 0.5, 0.0]);
    let result = executor
        .elem_op(ElemOp::OneMinus, &x)
        .map_err(|e| format!("OneMinus failed: {:?}", e))?;
    let output = A::tensor_to_vec(&result);
    assert_vec_close(&output, &[0.0, 0.5, 1.0], DEFAULT_TOLERANCE)?;

    let x = A::create_vector(&[-2.0, -1.0, 0.0, 1.0, 2.0]);
    let result = executor
        .elem_op(ElemOp::Relu, &x)
        .map_err(|e| format!("Relu failed: {:?}", e))?;
    let output = A::tensor_to_vec(&result);
    assert_vec_close(&output, &[0.0, 0.0, 0.0, 1.0, 2.0], DEFAULT_TOLERANCE)?;

    let x = A::create_vector(&[0.0]);
    let result = executor
        .elem_op(ElemOp::Sigmoid, &x)
        .map_err(|e| format!("Sigmoid failed: {:?}", e))?;
    let output = A::tensor_to_vec(&result);
    assert_vec_close(&output, &[0.5], DEFAULT_TOLERANCE)?;

    Ok(())
}

/// Exercises binary element-wise ops (`Add`, `Multiply`, `Subtract`, `Divide`)
/// on small vectors with known results.
pub fn test_backend_elem_binary<A: BackendTestAdapter>() -> TestResult
where
    <A::Executor as TlExecutor>::Error: std::fmt::Debug,
{
    let mut executor = A::create_executor();

    let x = A::create_vector(&[1.0, 2.0, 3.0]);
    let y = A::create_vector(&[4.0, 5.0, 6.0]);
    let result = executor
        .elem_op_binary(ElemOp::Add, &x, &y)
        .map_err(|e| format!("Add failed: {:?}", e))?;
    let output = A::tensor_to_vec(&result);
    assert_vec_close(&output, &[5.0, 7.0, 9.0], DEFAULT_TOLERANCE)?;

    let result = executor
        .elem_op_binary(ElemOp::Multiply, &x, &y)
        .map_err(|e| format!("Multiply failed: {:?}", e))?;
    let output = A::tensor_to_vec(&result);
    assert_vec_close(&output, &[4.0, 10.0, 18.0], DEFAULT_TOLERANCE)?;

    let result = executor
        .elem_op_binary(ElemOp::Subtract, &x, &y)
        .map_err(|e| format!("Subtract failed: {:?}", e))?;
    let output = A::tensor_to_vec(&result);
    assert_vec_close(&output, &[-3.0, -3.0, -3.0], DEFAULT_TOLERANCE)?;

    let result = executor
        .elem_op_binary(ElemOp::Divide, &y, &x)
        .map_err(|e| format!("Divide failed: {:?}", e))?;
    let output = A::tensor_to_vec(&result);
    assert_vec_close(&output, &[4.0, 2.5, 2.0], DEFAULT_TOLERANCE)?;

    Ok(())
}

/// Exercises reductions: `Sum` along each axis of a matrix, plus `Max` and
/// `Mean` over a vector.
pub fn test_backend_reduce<A: BackendTestAdapter>() -> TestResult
where
    <A::Executor as TlExecutor>::Error: std::fmt::Debug,
{
    let mut executor = A::create_executor();

    let x = A::create_matrix(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], 2, 3);

    // Sum over axis 0 collapses the rows: [1+4, 2+5, 3+6].
    let result = executor
        .reduce(ReduceOp::Sum, &x, &[0])
        .map_err(|e| format!("Sum reduce failed: {:?}", e))?;
    let output = A::tensor_to_vec(&result);
    assert_vec_close(&output, &[5.0, 7.0, 9.0], DEFAULT_TOLERANCE)?;

    // Sum over axis 1 collapses the columns: [1+2+3, 4+5+6].
    let result = executor
        .reduce(ReduceOp::Sum, &x, &[1])
        .map_err(|e| format!("Sum reduce failed: {:?}", e))?;
    let output = A::tensor_to_vec(&result);
    assert_vec_close(&output, &[6.0, 15.0], DEFAULT_TOLERANCE)?;

    let x = A::create_vector(&[1.0, 5.0, 3.0, 2.0]);
    let result = executor
        .reduce(ReduceOp::Max, &x, &[0])
        .map_err(|e| format!("Max reduce failed: {:?}", e))?;
    let output = A::tensor_to_vec(&result);
    assert_vec_close(&output, &[5.0], DEFAULT_TOLERANCE)?;

    let x = A::create_vector(&[2.0, 4.0, 6.0, 8.0]);
    let result = executor
        .reduce(ReduceOp::Mean, &x, &[0])
        .map_err(|e| format!("Mean reduce failed: {:?}", e))?;
    let output = A::tensor_to_vec(&result);
    assert_vec_close(&output, &[5.0], DEFAULT_TOLERANCE)?;

    Ok(())
}

/// Exercises `einsum` for a dot product, a matrix-vector product, and a
/// matrix-matrix product.
pub fn test_backend_einsum<A: BackendTestAdapter>() -> TestResult
where
    <A::Executor as TlExecutor>::Error: std::fmt::Debug,
{
    let mut executor = A::create_executor();

    // "i,i->": dot product, 1*4 + 2*5 + 3*6 = 32.
    let a = A::create_vector(&[1.0, 2.0, 3.0]);
    let b = A::create_vector(&[4.0, 5.0, 6.0]);
    let result = executor
        .einsum("i,i->", &[a.clone(), b.clone()])
        .map_err(|e| format!("Einsum dot product failed: {:?}", e))?;
    let output = A::tensor_to_vec(&result);
    assert_vec_close(&output, &[32.0], DEFAULT_TOLERANCE)?;

    // "ij,j->i": matrix-vector product.
    let mat = A::create_matrix(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], 2, 3);
    let vec = A::create_vector(&[1.0, 2.0, 3.0]);
    let result = executor
        .einsum("ij,j->i", &[mat, vec])
        .map_err(|e| format!("Einsum matvec failed: {:?}", e))?;
    let output = A::tensor_to_vec(&result);
    assert_vec_close(&output, &[14.0, 32.0], DEFAULT_TOLERANCE)?;

    // "ij,jk->ik": matrix-matrix product.
    let a = A::create_matrix(&[1.0, 2.0, 3.0, 4.0], 2, 2);
    let b = A::create_matrix(&[5.0, 6.0, 7.0, 8.0], 2, 2);
    let result = executor
        .einsum("ij,jk->ik", &[a, b])
        .map_err(|e| format!("Einsum matmul failed: {:?}", e))?;
    let output = A::tensor_to_vec(&result);
    assert_vec_close(&output, &[19.0, 22.0, 43.0, 50.0], DEFAULT_TOLERANCE)?;

    Ok(())
}

/// Placeholder for forward-mode autodiff conformance checks on backends that
/// implement [`TlAutodiff`]. Currently a no-op that always passes.
pub fn test_backend_forward<A>() -> TestResult
where
    A: BackendTestAdapter,
    A::Executor: TlAutodiff<Tensor = A::Tensor>,
    <A::Executor as TlExecutor>::Error: std::fmt::Debug,
{
    Ok(())
}

/// Exercises edge cases: division by zero (which a backend may report either
/// as an `inf`/`NaN` value or as an error) and very large magnitudes.
pub fn test_backend_edge_cases<A: BackendTestAdapter>() -> TestResult
where
    <A::Executor as TlExecutor>::Error: std::fmt::Debug,
{
    let mut executor = A::create_executor();

    let x = A::create_vector(&[1.0, 2.0, 3.0]);
    let y = A::create_vector(&[1.0, 0.0, 3.0]);
    let result = executor.elem_op_binary(ElemOp::Divide, &x, &y);

    match result {
        Ok(tensor) => {
            let output = A::tensor_to_vec(&tensor);
            assert_eq!(output.len(), 3);
            assert!((output[0] - 1.0).abs() < DEFAULT_TOLERANCE);
            assert!(output[1].is_infinite() || output[1].is_nan());
        }
        Err(_) => {
            // Rejecting division by zero with an error is also acceptable.
        }
    }

    let x = A::create_vector(&[1e10, -1e10, 0.0]);
    let result = executor
        .elem_op(ElemOp::Relu, &x)
        .map_err(|e| format!("Relu with large values failed: {:?}", e))?;
    let output = A::tensor_to_vec(&result);
    // Loose absolute tolerance: at this magnitude backends may differ in the low bits.
    assert_vec_close(&output, &[1e10, 0.0, 0.0], 1e4)?;

    Ok(())
}

/// Exercises shape handling: scalar (rank-0) arithmetic, and reduction with an
/// empty axis list, which this suite expects to reduce over all axes.
pub fn test_backend_shapes<A: BackendTestAdapter>() -> TestResult
where
    <A::Executor as TlExecutor>::Error: std::fmt::Debug,
{
    let mut executor = A::create_executor();

    let scalar1 = A::create_scalar(5.0);
    let scalar2 = A::create_scalar(3.0);
    let result = executor
        .elem_op_binary(ElemOp::Add, &scalar1, &scalar2)
        .map_err(|e| format!("Scalar add failed: {:?}", e))?;
    let output = A::tensor_to_vec(&result);
    assert_vec_close(&output, &[8.0], DEFAULT_TOLERANCE)?;

    let x = A::create_vector(&[1.0, 2.0, 3.0]);
    let result = executor
        .reduce(ReduceOp::Sum, &x, &[])
        .map_err(|e| format!("Empty axes reduce failed: {:?}", e))?;
    let output = A::tensor_to_vec(&result);
    assert_vec_close(&output, &[6.0], DEFAULT_TOLERANCE)?;

    Ok(())
}

/// Smoke-tests a 10,000-element binary op and spot-checks the first few
/// results.
pub fn test_backend_large_tensors<A: BackendTestAdapter>() -> TestResult
where
    <A::Executor as TlExecutor>::Error: std::fmt::Debug,
{
    let mut executor = A::create_executor();

    let size = 10000;
    let data1: Vec<f64> = (0..size).map(|i| i as f64).collect();
    let data2: Vec<f64> = (0..size).map(|i| (size - i) as f64).collect();

    let x = A::create_vector(&data1);
    let y = A::create_vector(&data2);

    let result = executor
        .elem_op_binary(ElemOp::Add, &x, &y)
        .map_err(|e| format!("Large vector add failed: {:?}", e))?;

    let output = A::tensor_to_vec(&result);
    assert_eq!(output.len(), size);

    // Every element is i + (size - i) = size = 10000.
    assert_vec_close(
        &output[0..3],
        &[10000.0, 10000.0, 10000.0],
        DEFAULT_TOLERANCE,
    )?;

    Ok(())
}

/// Runs a chain of 100 broadcast scalar additions to check that repeatedly
/// creating and replacing tensors stays correct.
pub fn test_backend_memory_efficiency<A: BackendTestAdapter>() -> TestResult
where
    <A::Executor as TlExecutor>::Error: std::fmt::Debug,
{
    let mut executor = A::create_executor();

    let mut x = A::create_vector(&[1.0, 2.0, 3.0]);

    for i in 0..100 {
        let y = A::create_scalar((i + 1) as f64);
        x = executor
            .elem_op_binary(ElemOp::Add, &x, &y)
            .map_err(|e| format!("Memory efficiency test failed at iteration {}: {:?}", i, e))?;
    }

    // 1 + 2 + ... + 100 = 5050 has been added to each starting element.
    let output = A::tensor_to_vec(&x);
    assert_vec_close(&output, &[5051.0, 5052.0, 5053.0], DEFAULT_TOLERANCE)?;

    Ok(())
}
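
/// Estimates the gradient of `f` at `x` by central differences:
/// `grad[i] = (f(x + eps*e_i) - f(x - eps*e_i)) / (2 * eps)`. It assumes `f`
/// returns a single-element (scalar) tensor; only the first output element is
/// differentiated.
///
/// A usage sketch, again assuming the hypothetical `CpuAdapter` from the trait
/// example above (the `RefCell` is only there because the closure must be `Fn`
/// while the executor call needs `&mut`):
///
/// ```rust,ignore
/// use std::cell::RefCell;
///
/// let executor = RefCell::new(CpuAdapter::create_executor());
/// let x = CpuAdapter::create_vector(&[1.0, 2.0, 3.0]);
///
/// // d/dx_i of sum(x) is 1.0 for every element.
/// let grad = numerical_gradient::<CpuAdapter, _>(
///     |t| executor.borrow_mut().reduce(ReduceOp::Sum, t, &[0]).unwrap(),
///     &x,
///     1e-5,
/// );
/// assert_vec_close(&grad, &[1.0, 1.0, 1.0], 1e-4).unwrap();
/// ```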
pub fn numerical_gradient<A, F>(f: F, x: &A::Tensor, epsilon: f64) -> Vec<f64>
where
    A: BackendTestAdapter,
    F: Fn(&A::Tensor) -> A::Tensor,
{
    let x_vec = A::tensor_to_vec(x);
    let shape = A::tensor_shape(x);
    let mut grad = vec![0.0; x_vec.len()];

    for i in 0..x_vec.len() {
        let mut x_plus = x_vec.clone();
        x_plus[i] += epsilon;
        let x_plus_tensor = A::create_tensor_from_data(&x_plus, &shape);
        let f_plus = A::tensor_to_vec(&f(&x_plus_tensor));

        let mut x_minus = x_vec.clone();
        x_minus[i] -= epsilon;
        let x_minus_tensor = A::create_tensor_from_data(&x_minus, &shape);
        let f_minus = A::tensor_to_vec(&f(&x_minus_tensor));

        grad[i] = (f_plus[0] - f_minus[0]) / (2.0 * epsilon);
    }

    grad
}
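
/// Runs every basic conformance test for the adapter `A` and collects the
/// results under stable names, without short-circuiting on the first failure.
///
/// A sketch of how a backend's test module might drive it, assuming the
/// hypothetical `CpuAdapter`:
///
/// ```rust,ignore
/// #[test]
/// fn cpu_conformance() {
///     let results = run_all_basic_tests::<CpuAdapter>();
///     print_test_summary(&results);
///     assert!(results.iter().all(|(_, r)| r.is_ok()));
/// }
/// ```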
pub fn run_all_basic_tests<A: BackendTestAdapter>() -> Vec<(String, TestResult)>
where
    <A::Executor as TlExecutor>::Error: std::fmt::Debug,
{
    vec![
        ("elem_unary".to_string(), test_backend_elem_unary::<A>()),
        ("elem_binary".to_string(), test_backend_elem_binary::<A>()),
        ("reduce".to_string(), test_backend_reduce::<A>()),
        ("einsum".to_string(), test_backend_einsum::<A>()),
        ("edge_cases".to_string(), test_backend_edge_cases::<A>()),
        ("shapes".to_string(), test_backend_shapes::<A>()),
    ]
}

/// Runs the larger, performance-oriented tests and collects the results.
pub fn run_all_performance_tests<A: BackendTestAdapter>() -> Vec<(String, TestResult)>
where
    <A::Executor as TlExecutor>::Error: std::fmt::Debug,
{
    vec![
        (
            "large_tensors".to_string(),
            test_backend_large_tensors::<A>(),
        ),
        (
            "memory_efficiency".to_string(),
            test_backend_memory_efficiency::<A>(),
        ),
    ]
}

/// Prints a pass/fail line for each test followed by a summary count.
pub fn print_test_summary(results: &[(String, TestResult)]) {
    println!("\n=== Backend Test Results ===");
    let mut passed = 0;
    let mut failed = 0;

    for (name, result) in results {
        match result {
            Ok(()) => {
                println!("✓ {}", name);
                passed += 1;
            }
            Err(msg) => {
                println!("✗ {} - {}", name, msg);
                failed += 1;
            }
        }
    }

    println!("\nPassed: {}, Failed: {}", passed, failed);
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_assert_vec_close() {
        assert!(assert_vec_close(&[1.0, 2.0, 3.0], &[1.0, 2.0, 3.0], 1e-10).is_ok());
        assert!(assert_vec_close(&[1.0, 2.0, 3.0], &[1.0, 2.0, 3.1], 0.2).is_ok());
        assert!(assert_vec_close(&[1.0, 2.0, 3.0], &[1.0, 2.0, 3.1], 0.01).is_err());
        assert!(assert_vec_close(&[1.0, 2.0], &[1.0, 2.0, 3.0], 1e-10).is_err());
    }

    #[test]
    fn test_default_tolerance() {
        let tolerance = DEFAULT_TOLERANCE;
        let max_tolerance = 1e-5;
        assert!(tolerance > 0.0);
        assert!(tolerance < max_tolerance);
    }
}