Skip to main content

sklears_simd/
validation.rs

1//! Validation framework for SIMD operations
2//!
3//! This module provides comprehensive validation tools for ensuring numerical accuracy,
4//! correctness, and performance of SIMD implementations against reference implementations.
5
6#[cfg(not(feature = "no-std"))]
7use std::collections::HashMap;
8#[cfg(not(feature = "no-std"))]
9use std::string::{String, ToString};
10#[cfg(not(feature = "no-std"))]
11use std::time::Instant;
12#[cfg(not(feature = "no-std"))]
13use std::vec::Vec;
14
15#[cfg(feature = "no-std")]
16use alloc::collections::BTreeMap as HashMap;
17#[cfg(feature = "no-std")]
18use alloc::string::{String, ToString};
19#[cfg(feature = "no-std")]
20use alloc::vec::Vec;
21#[cfg(feature = "no-std")]
22use alloc::{format, vec};
23
// Stand-ins for the `std::time` types, compiled only when the `no-std` feature is on,
// so that timing code elsewhere in this file keeps compiling without `std`.

/// Placeholder for `std::time::Instant`; it carries no clock reading.
#[cfg(feature = "no-std")]
#[derive(Debug, Clone, Copy)]
pub struct Instant;

/// Placeholder for `std::time::Duration`; it carries no length.
#[cfg(feature = "no-std")]
#[derive(Debug, Clone, Copy)]
pub struct Duration;

#[cfg(feature = "no-std")]
impl Instant {
    /// Returns the placeholder instant; no clock is consulted.
    pub fn now() -> Self {
        Self
    }

    /// Returns the placeholder duration; nothing is measured.
    pub fn elapsed(&self) -> Duration {
        Duration
    }
}

#[cfg(feature = "no-std")]
impl Duration {
    /// Always reports zero nanoseconds.
    pub fn as_nanos(&self) -> u128 {
        0
    }
}
50
51/// Numerical precision validation with configurable tolerances
52pub mod precision {
53    use super::*;
54
55    /// Tolerance levels for different types of operations
56    #[derive(Debug, Clone, Copy)]
57    pub struct Tolerance {
58        pub absolute: f64,
59        pub relative: f64,
60    }
61
62    impl Tolerance {
63        pub const STRICT: Self = Self {
64            absolute: 1e-15,
65            relative: 1e-14,
66        };
67
68        pub const NORMAL: Self = Self {
69            absolute: 1e-12,
70            relative: 1e-11,
71        };
72
73        pub const RELAXED: Self = Self {
74            absolute: 1e-9,
75            relative: 1e-8,
76        };
77
78        pub const VERY_RELAXED: Self = Self {
79            absolute: 1e-6,
80            relative: 1e-5,
81        };
82    }
83
84    /// Compare two floating-point values with given tolerance
85    pub fn compare_f32(a: f32, b: f32, tolerance: Tolerance) -> bool {
86        let abs_diff = (a - b).abs() as f64;
87        let rel_diff = if b != 0.0 {
88            abs_diff / (b.abs() as f64)
89        } else {
90            abs_diff
91        };
92
93        abs_diff <= tolerance.absolute || rel_diff <= tolerance.relative
94    }
95
96    /// Compare two f64 values with given tolerance
97    pub fn compare_f64(a: f64, b: f64, tolerance: Tolerance) -> bool {
98        let abs_diff = (a - b).abs();
99        let rel_diff = if b != 0.0 {
100            abs_diff / b.abs()
101        } else {
102            abs_diff
103        };
104
105        abs_diff <= tolerance.absolute || rel_diff <= tolerance.relative
106    }
107
108    /// Compare two slices of f32 values
109    pub fn compare_f32_slice(a: &[f32], b: &[f32], tolerance: Tolerance) -> ValidationResult {
110        if a.len() != b.len() {
111            return ValidationResult::error("Length mismatch");
112        }
113
114        let mut mismatches = Vec::new();
115        let mut max_abs_error = 0.0f64;
116        let mut max_rel_error = 0.0f64;
117
118        for (i, (&val_a, &val_b)) in a.iter().zip(b.iter()).enumerate() {
119            if !compare_f32(val_a, val_b, tolerance) {
120                let abs_error = (val_a - val_b).abs() as f64;
121                let rel_error = if val_b != 0.0 {
122                    abs_error / (val_b.abs() as f64)
123                } else {
124                    abs_error
125                };
126
127                max_abs_error = max_abs_error.max(abs_error);
128                max_rel_error = max_rel_error.max(rel_error);
129
130                mismatches.push(ValidationError {
131                    index: Some(i),
132                    expected: val_b as f64,
133                    actual: val_a as f64,
134                    abs_error,
135                    rel_error,
136                    description: format!("Mismatch at index {}", i),
137                });
138
139                if mismatches.len() >= 10 {
140                    break; // Limit reported errors
141                }
142            }
143        }
144
145        if mismatches.is_empty() {
146            ValidationResult::success()
147        } else {
148            let failed_count = mismatches.len();
149            ValidationResult {
150                passed: false,
151                errors: mismatches,
152                statistics: Some(ValidationStatistics {
153                    max_abs_error,
154                    max_rel_error,
155                    total_comparisons: a.len(),
156                    failed_comparisons: failed_count,
157                }),
158            }
159        }
160    }
161
162    /// Compare two slices of f64 values
163    pub fn compare_f64_slice(a: &[f64], b: &[f64], tolerance: Tolerance) -> ValidationResult {
164        if a.len() != b.len() {
165            return ValidationResult::error("Length mismatch");
166        }
167
168        let mut mismatches = Vec::new();
169        let mut max_abs_error = 0.0f64;
170        let mut max_rel_error = 0.0f64;
171
172        for (i, (&val_a, &val_b)) in a.iter().zip(b.iter()).enumerate() {
173            if !compare_f64(val_a, val_b, tolerance) {
174                let abs_error = (val_a - val_b).abs();
175                let rel_error = if val_b != 0.0 {
176                    abs_error / val_b.abs()
177                } else {
178                    abs_error
179                };
180
181                max_abs_error = max_abs_error.max(abs_error);
182                max_rel_error = max_rel_error.max(rel_error);
183
184                mismatches.push(ValidationError {
185                    index: Some(i),
186                    expected: val_b,
187                    actual: val_a,
188                    abs_error,
189                    rel_error,
190                    description: format!("Mismatch at index {}", i),
191                });
192
193                if mismatches.len() >= 10 {
194                    break;
195                }
196            }
197        }
198
199        if mismatches.is_empty() {
200            ValidationResult::success()
201        } else {
202            let failed_count = mismatches.len();
203            ValidationResult {
204                passed: false,
205                errors: mismatches,
206                statistics: Some(ValidationStatistics {
207                    max_abs_error,
208                    max_rel_error,
209                    total_comparisons: a.len(),
210                    failed_comparisons: failed_count,
211                }),
212            }
213        }
214    }
215}
216
217/// Edge case testing for special values
218pub mod edge_cases {
219    use super::*;
220
221    /// Special floating-point test values
222    pub fn get_special_f32_values() -> Vec<f32> {
223        vec![
224            0.0,
225            -0.0,
226            1.0,
227            -1.0,
228            f32::INFINITY,
229            f32::NEG_INFINITY,
230            f32::NAN,
231            f32::MIN,
232            f32::MAX,
233            f32::MIN_POSITIVE,
234            f32::EPSILON,
235            core::f32::consts::PI,
236            core::f32::consts::E,
237            1e-30,
238            1e30,
239            -1e-30,
240            -1e30,
241        ]
242    }
243
244    /// Special floating-point test values for f64
245    pub fn get_special_f64_values() -> Vec<f64> {
246        vec![
247            0.0,
248            -0.0,
249            1.0,
250            -1.0,
251            f64::INFINITY,
252            f64::NEG_INFINITY,
253            f64::NAN,
254            f64::MIN,
255            f64::MAX,
256            f64::MIN_POSITIVE,
257            f64::EPSILON,
258            core::f64::consts::PI,
259            core::f64::consts::E,
260            1e-100,
261            1e100,
262            -1e-100,
263            -1e100,
264        ]
265    }
266
267    /// Test a function with edge case values
268    pub fn test_unary_f32<F>(
269        func: F,
270        reference_func: F,
271        tolerance: precision::Tolerance,
272    ) -> ValidationResult
273    where
274        F: Fn(f32) -> f32,
275    {
276        let test_values = get_special_f32_values();
277        let mut errors = Vec::new();
278
279        for &val in &test_values {
280            let result = func(val);
281            let expected = reference_func(val);
282
283            if !are_equal_with_nan_handling_f32(result, expected, tolerance) {
284                errors.push(ValidationError {
285                    index: None,
286                    expected: expected as f64,
287                    actual: result as f64,
288                    abs_error: (result - expected).abs() as f64,
289                    rel_error: if expected != 0.0 {
290                        ((result - expected) / expected).abs() as f64
291                    } else {
292                        (result - expected).abs() as f64
293                    },
294                    description: format!("Edge case failure for input: {}", val),
295                });
296            }
297        }
298
299        if errors.is_empty() {
300            ValidationResult::success()
301        } else {
302            ValidationResult {
303                passed: false,
304                errors,
305                statistics: None,
306            }
307        }
308    }
309
310    /// Test a binary function with edge case combinations
311    pub fn test_binary_f32<F>(
312        func: F,
313        reference_func: F,
314        tolerance: precision::Tolerance,
315    ) -> ValidationResult
316    where
317        F: Fn(f32, f32) -> f32,
318    {
319        let test_values = get_special_f32_values();
320        let mut errors = Vec::new();
321
322        for &a in &test_values {
323            for &b in &test_values {
324                let result = func(a, b);
325                let expected = reference_func(a, b);
326
327                if !are_equal_with_nan_handling_f32(result, expected, tolerance) {
328                    errors.push(ValidationError {
329                        index: None,
330                        expected: expected as f64,
331                        actual: result as f64,
332                        abs_error: (result - expected).abs() as f64,
333                        rel_error: if expected != 0.0 {
334                            ((result - expected) / expected).abs() as f64
335                        } else {
336                            (result - expected).abs() as f64
337                        },
338                        description: format!("Edge case failure for inputs: {}, {}", a, b),
339                    });
340
341                    if errors.len() >= 20 {
342                        break;
343                    }
344                }
345            }
346            if errors.len() >= 20 {
347                break;
348            }
349        }
350
351        if errors.is_empty() {
352            ValidationResult::success()
353        } else {
354            ValidationResult {
355                passed: false,
356                errors,
357                statistics: None,
358            }
359        }
360    }
361
362    fn are_equal_with_nan_handling_f32(a: f32, b: f32, tolerance: precision::Tolerance) -> bool {
363        if a.is_nan() && b.is_nan() {
364            true
365        } else if a.is_infinite() && b.is_infinite() {
366            a.signum() == b.signum()
367        } else {
368            precision::compare_f32(a, b, tolerance)
369        }
370    }
371}
372
373/// Correctness verification against reference implementations
374pub mod correctness {
375    use super::*;
376
377    /// Verify SIMD implementation against scalar reference
378    pub fn verify_against_scalar<F1, F2, T, R>(
379        simd_func: F1,
380        scalar_func: F2,
381        test_data: &[T],
382        _tolerance: precision::Tolerance,
383        operation_name: &str,
384    ) -> ValidationResult
385    where
386        F1: Fn(&[T]) -> R,
387        F2: Fn(&[T]) -> R,
388        R: PartialEq + core::fmt::Debug + Clone,
389    {
390        let simd_result = simd_func(test_data);
391        let scalar_result = scalar_func(test_data);
392
393        if simd_result == scalar_result {
394            ValidationResult::success()
395        } else {
396            ValidationResult::error(&format!(
397                "SIMD result {:?} does not match scalar result {:?} for operation: {}",
398                simd_result, scalar_result, operation_name
399            ))
400        }
401    }
402
403    /// Verify SIMD f32 slice operations
404    pub fn verify_f32_slice_operation<F1, F2>(
405        simd_func: F1,
406        scalar_func: F2,
407        test_data: &[f32],
408        tolerance: precision::Tolerance,
409        operation_name: &str,
410    ) -> ValidationResult
411    where
412        F1: Fn(&[f32]) -> Vec<f32>,
413        F2: Fn(&[f32]) -> Vec<f32>,
414    {
415        let simd_result = simd_func(test_data);
416        let scalar_result = scalar_func(test_data);
417
418        let mut validation_result =
419            precision::compare_f32_slice(&simd_result, &scalar_result, tolerance);
420
421        if !validation_result.passed {
422            for error in &mut validation_result.errors {
423                error.description = format!("{}: {}", operation_name, error.description);
424            }
425        }
426
427        validation_result
428    }
429
430    /// Verify SIMD f64 slice operations
431    pub fn verify_f64_slice_operation<F1, F2>(
432        simd_func: F1,
433        scalar_func: F2,
434        test_data: &[f64],
435        tolerance: precision::Tolerance,
436        operation_name: &str,
437    ) -> ValidationResult
438    where
439        F1: Fn(&[f64]) -> Vec<f64>,
440        F2: Fn(&[f64]) -> Vec<f64>,
441    {
442        let simd_result = simd_func(test_data);
443        let scalar_result = scalar_func(test_data);
444
445        let mut validation_result =
446            precision::compare_f64_slice(&simd_result, &scalar_result, tolerance);
447
448        if !validation_result.passed {
449            for error in &mut validation_result.errors {
450                error.description = format!("{}: {}", operation_name, error.description);
451            }
452        }
453
454        validation_result
455    }
456
457    /// Generate comprehensive test datasets for validation
458    pub fn generate_test_datasets_f32() -> Vec<Vec<f32>> {
459        vec![
460            // Empty
461            vec![],
462            // Single element
463            vec![1.0],
464            // Small arrays
465            vec![1.0, 2.0, 3.0],
466            vec![-1.0, 0.0, 1.0],
467            // Power-of-2 sizes for SIMD alignment
468            (0..4).map(|i| i as f32).collect(),
469            (0..8).map(|i| i as f32).collect(),
470            (0..16).map(|i| i as f32).collect(),
471            (0..32).map(|i| i as f32).collect(),
472            // Non-power-of-2 sizes
473            (0..7).map(|i| i as f32).collect(),
474            (0..15).map(|i| i as f32).collect(),
475            (0..31).map(|i| i as f32).collect(),
476            // Large arrays
477            (0..1000).map(|i| (i as f32) * 0.1).collect(),
478            // Random-like data
479            vec![
480                0.1, -2.3, 4.7, -0.9, 8.2, -3.1, 5.6, -7.4, 1.8, -6.5, 9.3, -4.7, 2.1, -8.9, 3.4,
481                -1.2,
482            ],
483            // Large values
484            vec![1e10, -1e10, 1e20, -1e20],
485            // Small values
486            vec![1e-10, -1e-10, 1e-20, -1e-20],
487            // Mixed scales
488            vec![1e-10, 1.0, 1e10, -1e-10, -1.0, -1e10],
489        ]
490    }
491
492    /// Generate comprehensive test datasets for f64
493    pub fn generate_test_datasets_f64() -> Vec<Vec<f64>> {
494        vec![
495            // Empty
496            vec![],
497            // Single element
498            vec![1.0],
499            // Small arrays
500            vec![1.0, 2.0, 3.0],
501            vec![-1.0, 0.0, 1.0],
502            // Power-of-2 sizes
503            (0..4).map(|i| i as f64).collect(),
504            (0..8).map(|i| i as f64).collect(),
505            (0..16).map(|i| i as f64).collect(),
506            // Large arrays
507            (0..1000).map(|i| (i as f64) * 0.1).collect(),
508            // High precision values
509            vec![
510                core::f64::consts::PI,
511                core::f64::consts::E,
512                core::f64::consts::SQRT_2,
513                core::f64::consts::LN_2,
514            ],
515            // Extreme values
516            vec![f64::MIN, f64::MAX, f64::MIN_POSITIVE],
517        ]
518    }
519}
520
521/// Performance regression detection
522pub mod performance {
523    use super::*;
524
525    /// Performance measurement result
526    #[derive(Debug, Clone)]
527    pub struct PerformanceResult {
528        pub operation_name: String,
529        pub duration_ns: u64,
530        pub throughput_ops_per_sec: f64,
531        pub data_size: usize,
532    }
533
534    /// Benchmark a function and return performance metrics
535    pub fn benchmark_function<F, T, R>(
536        func: F,
537        data: &[T],
538        operation_name: &str,
539        iterations: usize,
540    ) -> PerformanceResult
541    where
542        F: Fn(&[T]) -> R,
543        T: Clone,
544    {
545        let start = Instant::now();
546
547        for _ in 0..iterations {
548            let _ = func(data);
549        }
550
551        let duration = start.elapsed();
552        let duration_ns = duration.as_nanos() as u64;
553        let avg_duration_ns = duration_ns / iterations as u64;
554        let throughput = if avg_duration_ns > 0 {
555            1_000_000_000.0 / (avg_duration_ns as f64)
556        } else {
557            f64::INFINITY
558        };
559
560        PerformanceResult {
561            operation_name: operation_name.to_string(),
562            duration_ns: avg_duration_ns,
563            throughput_ops_per_sec: throughput,
564            data_size: data.len(),
565        }
566    }
567
568    /// Compare SIMD vs scalar performance
569    pub fn compare_simd_vs_scalar<F1, F2, T, R>(
570        simd_func: F1,
571        scalar_func: F2,
572        data: &[T],
573        operation_name: &str,
574        iterations: usize,
575    ) -> PerformanceComparison
576    where
577        F1: Fn(&[T]) -> R,
578        F2: Fn(&[T]) -> R,
579        T: Clone,
580    {
581        let simd_result = benchmark_function(
582            simd_func,
583            data,
584            &format!("{operation_name}_simd"),
585            iterations,
586        );
587
588        let scalar_result = benchmark_function(
589            scalar_func,
590            data,
591            &format!("{operation_name}_scalar"),
592            iterations,
593        );
594
595        let speedup = if scalar_result.duration_ns > 0 {
596            scalar_result.duration_ns as f64 / simd_result.duration_ns as f64
597        } else {
598            1.0
599        };
600
601        PerformanceComparison {
602            operation_name: operation_name.to_string(),
603            simd_result,
604            scalar_result,
605            speedup,
606        }
607    }
608
609    /// Performance regression threshold check
610    pub fn check_performance_regression(
611        current: &PerformanceResult,
612        baseline: &PerformanceResult,
613        max_regression_percent: f64,
614    ) -> ValidationResult {
615        if baseline.duration_ns == 0 {
616            return ValidationResult::error("Baseline duration is zero");
617        }
618
619        let regression_ratio = current.duration_ns as f64 / baseline.duration_ns as f64;
620        let regression_percent = (regression_ratio - 1.0) * 100.0;
621
622        if regression_percent > max_regression_percent {
623            ValidationResult::error(&format!(
624                "Performance regression detected: {regression_percent:.2}% slower than baseline (max allowed: {max_regression_percent:.2}%)"
625            ))
626        } else {
627            ValidationResult::success()
628        }
629    }
630
631    #[derive(Debug, Clone)]
632    pub struct PerformanceComparison {
633        pub operation_name: String,
634        pub simd_result: PerformanceResult,
635        pub scalar_result: PerformanceResult,
636        pub speedup: f64,
637    }
638}
639
/// Core validation types and utilities
///
/// A single recorded discrepancy between an actual and an expected value.
#[derive(Debug, Clone)]
pub struct ValidationError {
    /// Position of the failing element, when the error refers to a slice element.
    pub index: Option<usize>,
    /// Reference value.
    pub expected: f64,
    /// Value produced by the implementation under test.
    pub actual: f64,
    /// Absolute error recorded for this discrepancy.
    pub abs_error: f64,
    /// Relative error recorded for this discrepancy.
    pub rel_error: f64,
    /// Human-readable explanation of the failure.
    pub description: String,
}

/// Aggregate figures for a batch of comparisons.
#[derive(Debug, Clone)]
pub struct ValidationStatistics {
    /// Largest absolute error recorded.
    pub max_abs_error: f64,
    /// Largest relative error recorded.
    pub max_rel_error: f64,
    /// Number of element pairs compared.
    pub total_comparisons: usize,
    /// Number of element pairs counted as failed.
    pub failed_comparisons: usize,
}

/// Outcome of a validation step: overall verdict, detailed errors and optional
/// aggregate statistics.
#[derive(Debug, Clone)]
pub struct ValidationResult {
    /// `true` when the step succeeded.
    pub passed: bool,
    /// Recorded failures; empty for a result built with [`ValidationResult::success`].
    pub errors: Vec<ValidationError>,
    /// Aggregate figures, when the producing step computed them.
    pub statistics: Option<ValidationStatistics>,
}

impl ValidationResult {
    /// Creates a passing result with no errors and no statistics.
    pub fn success() -> Self {
        Self {
            passed: true,
            errors: Vec::new(),
            statistics: None,
        }
    }

    /// Creates a failing result carrying a single error whose description is
    /// `message`; the numeric fields of that error are zero and it has no index.
    pub fn error(message: &str) -> Self {
        Self {
            passed: false,
            errors: vec![ValidationError {
                index: None,
                expected: 0.0,
                actual: 0.0,
                abs_error: 0.0,
                rel_error: 0.0,
                description: message.to_string(),
            }],
            statistics: None,
        }
    }

    /// Merges `other` into `self`.
    ///
    /// The combined result passes only if both inputs passed, and its error list is
    /// `self`'s errors followed by `other`'s. Statistics are merged as well: when
    /// both sides carry them, the maxima are taken and the counters are summed;
    /// when only one side carries them, they are kept as is.
    pub fn combine(mut self, other: ValidationResult) -> Self {
        self.passed = self.passed && other.passed;
        self.errors.extend(other.errors);
        self.statistics = match (self.statistics.take(), other.statistics) {
            (Some(mine), Some(theirs)) => Some(ValidationStatistics {
                max_abs_error: mine.max_abs_error.max(theirs.max_abs_error),
                max_rel_error: mine.max_rel_error.max(theirs.max_rel_error),
                total_comparisons: mine.total_comparisons + theirs.total_comparisons,
                failed_comparisons: mine.failed_comparisons + theirs.failed_comparisons,
            }),
            // At most one side has statistics; keep whichever exists.
            (mine, theirs) => mine.or(theirs),
        };
        self
    }
}
696
697/// Comprehensive validation suite
698pub struct ValidationSuite {
699    pub results: HashMap<String, ValidationResult>,
700    pub performance_results: HashMap<String, performance::PerformanceResult>,
701}
702
703impl Default for ValidationSuite {
704    fn default() -> Self {
705        Self::new()
706    }
707}
708
709impl ValidationSuite {
710    pub fn new() -> Self {
711        Self {
712            results: HashMap::new(),
713            performance_results: HashMap::new(),
714        }
715    }
716
717    pub fn add_result(&mut self, name: String, result: ValidationResult) {
718        self.results.insert(name, result);
719    }
720
721    pub fn add_performance_result(&mut self, name: String, result: performance::PerformanceResult) {
722        self.performance_results.insert(name, result);
723    }
724
725    pub fn all_passed(&self) -> bool {
726        self.results.values().all(|r| r.passed)
727    }
728
729    pub fn print_summary(&self) {
730        #[cfg(not(feature = "no-std"))]
731        {
732            let total_tests = self.results.len();
733            let passed_tests = self.results.values().filter(|r| r.passed).count();
734
735            println!("Validation Summary:");
736            println!("  Total tests: {total_tests}");
737            println!("  Passed: {passed_tests}");
738            println!("  Failed: {}", total_tests - passed_tests);
739
740            for (name, result) in &self.results {
741                if !result.passed {
742                    println!("  FAILED: {name}");
743                    for error in &result.errors {
744                        println!("    {}", error.description);
745                    }
746                }
747            }
748
749            if !self.performance_results.is_empty() {
750                println!("\nPerformance Results:");
751                for (name, perf) in &self.performance_results {
752                    println!(
753                        "  {}: {:.2} ns/op ({:.2e} ops/sec)",
754                        name, perf.duration_ns, perf.throughput_ops_per_sec
755                    );
756                }
757            }
758        }
759    }
760}
761
// Unit tests for the validation helpers. They rely on `std` (timing, `println!`),
// so the whole module is skipped when the `no-std` feature is enabled.
#[cfg(all(test, not(feature = "no-std")))]
mod tests {
    use super::*;

    /// Scalar comparison accepts equal and nearly equal values and rejects a clear
    /// difference.
    #[test]
    fn test_precision_comparison() {
        assert!(precision::compare_f32(
            1.0,
            1.0,
            precision::Tolerance::STRICT
        ));
        assert!(precision::compare_f32(
            1.0,
            1.0 + 1e-12,
            precision::Tolerance::NORMAL
        ));
        assert!(!precision::compare_f32(
            1.0,
            1.1,
            precision::Tolerance::STRICT
        ));
    }

    /// The special-value list contains NaN, infinity and zero.
    #[test]
    fn test_edge_cases() {
        let special_values = edge_cases::get_special_f32_values();
        // NaN never equals itself, so `contains` cannot be used to find it.
        assert!(special_values.iter().any(|x| x.is_nan()));
        assert!(special_values.contains(&f32::INFINITY));
        assert!(special_values.contains(&0.0));
    }

    /// Slice comparison passes for identical data and fails when one element differs.
    #[test]
    fn test_slice_comparison() {
        let a = vec![1.0, 2.0, 3.0];
        let b = vec![1.0, 2.0, 3.0];
        let result = precision::compare_f32_slice(&a, &b, precision::Tolerance::NORMAL);
        assert!(result.passed);

        let c = vec![1.0, 2.1, 3.0];
        let result2 = precision::compare_f32_slice(&a, &c, precision::Tolerance::STRICT);
        assert!(!result2.passed);
    }

    /// A suite with one failing entry does not report overall success.
    #[test]
    fn test_validation_suite() {
        let mut suite = ValidationSuite::new();
        suite.add_result("test1".to_string(), ValidationResult::success());
        suite.add_result("test2".to_string(), ValidationResult::error("Test error"));

        assert!(!suite.all_passed());
        assert_eq!(suite.results.len(), 2);
    }

    /// Benchmarking returns the supplied name, the input size and a positive
    /// throughput.
    #[test]
    fn test_performance_measurement() {
        let data = vec![1.0f32; 1000];
        let result = performance::benchmark_function(
            |slice| slice.iter().sum::<f32>(),
            &data,
            "sum_test",
            100,
        );

        assert_eq!(result.operation_name, "sum_test");
        assert_eq!(result.data_size, 1000);
        // The averaged duration is not asserted: it is truncated to whole nanoseconds
        // and may legitimately be zero on a fast or heavily optimised run, in which
        // case the throughput is reported as infinity.
        assert!(result.throughput_ops_per_sec > 0.0);
    }

    /// The generated f32 datasets include an empty, a single-element and a large one.
    #[test]
    fn test_test_data_generation() {
        let datasets = correctness::generate_test_datasets_f32();
        assert!(!datasets.is_empty());
        assert!(datasets.iter().any(|d| d.is_empty()));
        assert!(datasets.iter().any(|d| d.len() == 1));
        assert!(datasets.iter().any(|d| d.len() > 100));
    }
}