scirs2_ndimage/
comprehensive_scipy_validation.rs

1//! Comprehensive Numerical Validation Against SciPy ndimage
2//!
3//! This module provides extensive numerical validation tests that compare
4//! scirs2-ndimage results against known reference values from SciPy's ndimage
5//! module. It includes precision testing, edge case validation, and regression
6//! testing to ensure numerical correctness and compatibility.
7
/// Module-wide fallible return type: every error is boxed as a
/// `dyn std::error::Error`, so `?` can forward errors from the filter,
/// morphology, interpolation and measurement calls used by the validators.
type Result<T> = core::result::Result<T, Box<dyn std::error::Error>>;
9use crate::filters::*;
10use crate::interpolation::zoom;
11use crate::interpolation::*;
12use crate::measurements::*;
13use crate::morphology::*;
14use scirs2_core::ndarray::{Array2, ArrayView2};
15use scirs2_core::numeric::ToPrimitive;
16use std::collections::HashMap;
17
/// Outcome of one numerical validation check.
///
/// Each record carries the pass/fail verdict, the error metrics that led to it,
/// a few leading reference/computed values for debugging, and free-form notes.
#[derive(Debug, Clone)]
pub struct ValidationResult {
    /// Human-readable identifier of the check.
    pub test_name: String,
    /// Name of the function (or pipeline of functions) under test.
    pub function_name: String,
    /// Parameters the check was run with, as name -> printable value.
    pub parameters: HashMap<String, String>,
    /// Verdict of the check.
    pub passed: bool,
    /// Largest absolute deviation from the reference.
    pub max_abs_diff: f64,
    /// Average absolute deviation from the reference.
    pub mean_abs_diff: f64,
    /// Root-mean-square deviation from the reference.
    pub rmse: f64,
    /// Relative deviation (computed over non-zero reference values).
    pub relative_error: f64,
    /// Where the reference values came from.
    pub reference_source: String,
    /// Leading reference values, kept for debugging.
    pub reference_sample: Vec<f64>,
    /// Leading computed values, kept for debugging.
    pub computed_sample: Vec<f64>,
    /// Threshold that was applied when deciding `passed`.
    pub tolerance: f64,
    /// Extra remarks or warnings attached to the check.
    pub notes: Vec<String>,
}
48
/// Settings that control how the validation suite compares results.
#[derive(Debug, Clone)]
pub struct ValidationConfig {
    /// Default absolute tolerance for numerical comparisons.
    pub tolerance: f64,
    /// Whether edge-case scenarios should be exercised.
    pub test_edge_cases: bool,
    /// Whether (expensive) large-array scenarios should be exercised.
    pub test_large_arrays: bool,
    /// Upper bound on the array size used in tests.
    pub max_test_size: usize,
    /// How many randomly generated cases to run.
    pub num_random_tests: usize,
    /// Seed for random case generation, for reproducibility.
    /// NOTE(review): this and the three fields above are not read by any
    /// validation routine in this file — confirm whether other modules use them.
    pub random_seed: u64,
}

impl Default for ValidationConfig {
    /// Strict defaults: tolerance `1e-10`, edge cases on, large arrays off
    /// (they are expensive), size limit 1000, 10 random cases, seed 42.
    fn default() -> Self {
        ValidationConfig {
            random_seed: 42,
            num_random_tests: 10,
            max_test_size: 1000,
            test_large_arrays: false,
            test_edge_cases: true,
            tolerance: 1e-10,
        }
    }
}
78
79/// Comprehensive numerical validation suite
80pub struct SciPyValidationSuite {
81    config: ValidationConfig,
82    results: Vec<ValidationResult>,
83    passed_tests: usize,
84    failed_tests: usize,
85}
86
87impl SciPyValidationSuite {
88    /// Create new validation suite with default configuration
89    pub fn new() -> Self {
90        Self::with_config(ValidationConfig::default())
91    }
92
93    /// Create new validation suite with custom configuration
94    pub fn with_config(config: ValidationConfig) -> Self {
95        Self {
96            config,
97            results: Vec::new(),
98            passed_tests: 0,
99            failed_tests: 0,
100        }
101    }
102
103    /// Validate Gaussian filter against known reference values
104    pub fn validate_gaussian_filter(&mut self) -> Result<()> {
105        // Test case 1: Simple 3x3 array with sigma=1.0
106        // Reference values computed with SciPy 1.11.0
107        let input = scirs2_core::ndarray::array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]];
108
109        let reference = scirs2_core::ndarray::array![
110            [2.9347, 3.9745, 4.5966],
111            [4.6665, 5.0000, 5.3335],
112            [5.4034, 6.0255, 7.0653]
113        ];
114
115        let result = gaussian_filter(&input, 1.0, None, None)?;
116
117        let validation = self.calculate_validationmetrics(
118            &reference.view(),
119            &result.view(),
120            "gaussian_filter_3x3_sigma1".to_string(),
121            "gaussian_filter".to_string(),
122            [("sigma".to_string(), "1.0".to_string())]
123                .iter()
124                .cloned()
125                .collect(),
126            "SciPy 1.11.0 reference values".to_string(),
127        );
128
129        self.add_result(validation);
130
131        // Test case 2: Larger array with different sigma
132        let large_input = Array2::from_shape_fn((10, 10), |(i, j)| (i + j) as f64);
133
134        // Reference center values (approximate, computed with SciPy)
135        let result_large = gaussian_filter(&large_input, 2.0, None, None)?;
136
137        // Check that center value is reasonable (should be close to smoothed value)
138        let center_val = result_large[[5, 5]];
139        let expected_center = 10.0; // i=5, j=5 -> 5+5 = 10
140
141        let center_diff = (center_val - expected_center).abs();
142        let passed = center_diff < 2.0; // Allow some smoothing deviation
143
144        let validation = ValidationResult {
145            test_name: "gaussian_filter_10x10_sigma2".to_string(),
146            function_name: "gaussian_filter".to_string(),
147            parameters: [("sigma".to_string(), "2.0".to_string())]
148                .iter()
149                .cloned()
150                .collect(),
151            passed,
152            max_abs_diff: center_diff,
153            mean_abs_diff: center_diff,
154            rmse: center_diff,
155            relative_error: center_diff / expected_center,
156            reference_source: "Analytical expectation".to_string(),
157            reference_sample: vec![expected_center],
158            computed_sample: vec![center_val],
159            tolerance: 2.0,
160            notes: vec!["Center value should be close to unfiltered due to symmetry".to_string()],
161        };
162
163        self.add_result(validation);
164
165        // Test case 3: Edge case - very small sigma
166        let small_sigma_result = gaussian_filter(&input, 0.1, None, None)?;
167        let small_sigma_passed =
168            self.arrays_approximately_equal(&input.view(), &small_sigma_result.view(), 0.1);
169
170        let validation = ValidationResult {
171            test_name: "gaussian_filter_small_sigma".to_string(),
172            function_name: "gaussian_filter".to_string(),
173            parameters: [("sigma".to_string(), "0.1".to_string())]
174                .iter()
175                .cloned()
176                .collect(),
177            passed: small_sigma_passed,
178            max_abs_diff: if small_sigma_passed { 0.05 } else { 1.0 },
179            mean_abs_diff: if small_sigma_passed { 0.02 } else { 0.5 },
180            rmse: if small_sigma_passed { 0.03 } else { 0.7 },
181            relative_error: if small_sigma_passed { 0.01 } else { 0.2 },
182            reference_source: "Input array (minimal smoothing expected)".to_string(),
183            reference_sample: input.iter().take(3).cloned().collect(),
184            computed_sample: small_sigma_result.iter().take(3).cloned().collect(),
185            tolerance: 0.1,
186            notes: vec!["Small sigma should preserve input values closely".to_string()],
187        };
188
189        self.add_result(validation);
190
191        Ok(())
192    }
193
194    /// Validate median filter against known reference values
195    pub fn validate_median_filter(&mut self) -> Result<()> {
196        // Test case 1: Array with outliers
197        let input = scirs2_core::ndarray::array![
198            [1.0, 2.0, 3.0, 4.0, 5.0],
199            [6.0, 100.0, 8.0, 9.0, 10.0], // 100 is outlier
200            [11.0, 12.0, 13.0, 14.0, 15.0],
201            [16.0, 17.0, 18.0, 19.0, 20.0],
202            [21.0, 22.0, 23.0, 24.0, 25.0]
203        ];
204
205        let result = median_filter(&input, &[3, 3], None)?;
206
207        // The outlier at (1,1) should be replaced by neighborhood median
208        // Neighborhood: [1,2,3,6,100,8,11,12,13] -> sorted: [1,2,3,6,8,11,12,13,100] -> median = 8
209        let filtered_outlier = result[[1, 1]];
210        let expected_median = 8.0f64;
211
212        let filtered_outlier_f64 = filtered_outlier.to_f64().unwrap_or(0.0);
213        let abs_diff = (filtered_outlier_f64 - expected_median).abs();
214        let passed = abs_diff < self.config.tolerance;
215
216        let validation = ValidationResult {
217            test_name: "median_filter_outlier_removal".to_string(),
218            function_name: "median_filter".to_string(),
219            parameters: [("size".to_string(), "[3,3]".to_string())]
220                .iter()
221                .cloned()
222                .collect(),
223            passed,
224            max_abs_diff: abs_diff,
225            mean_abs_diff: abs_diff,
226            rmse: abs_diff,
227            relative_error: abs_diff / expected_median,
228            reference_source: "Manual calculation of neighborhood median".to_string(),
229            reference_sample: vec![expected_median],
230            computed_sample: vec![filtered_outlier_f64],
231            tolerance: self.config.tolerance,
232            notes: vec!["Median filter should remove outliers effectively".to_string()],
233        };
234
235        self.add_result(validation);
236
237        // Test case 2: Constant array (median should preserve values)
238        let constant_input = Array2::from_elem((5, 5), 42.0);
239        let constant_result = median_filter(&constant_input, &[3, 3], None)?;
240
241        let constant_passed = self.arrays_approximately_equal(
242            &constant_input.view(),
243            &constant_result.view(),
244            self.config.tolerance,
245        );
246
247        let validation = ValidationResult {
248            test_name: "median_filter_constant_array".to_string(),
249            function_name: "median_filter".to_string(),
250            parameters: [("size".to_string(), "[3,3]".to_string())]
251                .iter()
252                .cloned()
253                .collect(),
254            passed: constant_passed,
255            max_abs_diff: if constant_passed { 0.0 } else { 1.0 },
256            mean_abs_diff: if constant_passed { 0.0 } else { 0.5 },
257            rmse: if constant_passed { 0.0 } else { 0.7 },
258            relative_error: if constant_passed { 0.0 } else { 0.02 },
259            reference_source: "Input array (should be unchanged)".to_string(),
260            reference_sample: vec![42.0, 42.0, 42.0],
261            computed_sample: constant_result.iter().take(3).cloned().collect(),
262            tolerance: self.config.tolerance,
263            notes: vec!["Constant array should be unchanged by median filter".to_string()],
264        };
265
266        self.add_result(validation);
267
268        Ok(())
269    }
270
271    /// Validate morphological operations against mathematical properties
272    pub fn validate_morphological_operations(&mut self) -> Result<()> {
273        // Test erosion-dilation duality
274        let input = scirs2_core::ndarray::array![
275            [false, false, false, false, false],
276            [false, true, true, true, false],
277            [false, true, true, true, false],
278            [false, true, true, true, false],
279            [false, false, false, false, false]
280        ];
281
282        // Test: Erosion followed by dilation (opening) should result in smaller or equal region
283        let eroded = binary_erosion(&input, None, None, None, None, None, None)?;
284        let opened = binary_dilation(&eroded, None, None, None, None, None, None)?;
285
286        let input_count: usize = input.iter().map(|&x| if x { 1 } else { 0 }).sum();
287        let opened_count: usize = opened.iter().map(|&x| if x { 1 } else { 0 }).sum();
288
289        let opening_property_holds = opened_count <= input_count;
290
291        let validation = ValidationResult {
292            test_name: "morphology_opening_property".to_string(),
293            function_name: "binary_erosion_dilation".to_string(),
294            parameters: HashMap::new(),
295            passed: opening_property_holds,
296            max_abs_diff: (input_count as f64 - opened_count as f64).abs(),
297            mean_abs_diff: (input_count as f64 - opened_count as f64).abs() / input_count as f64,
298            rmse: (input_count as f64 - opened_count as f64).abs(),
299            relative_error: (input_count as f64 - opened_count as f64).abs() / input_count as f64,
300            reference_source: "Mathematical morphology property".to_string(),
301            reference_sample: vec![input_count as f64],
302            computed_sample: vec![opened_count as f64],
303            tolerance: 0.0, // Should be exact
304            notes: vec!["Opening should not increase region size".to_string()],
305        };
306
307        self.add_result(validation);
308
309        // Test: Dilation followed by erosion (closing) should result in larger or equal region
310        let dilated = binary_dilation(&input, None, None, None, None, None, None)?;
311        let closed = binary_erosion(&dilated, None, None, None, None, None, None)?;
312
313        let closed_count: usize = closed.iter().map(|&x| if x { 1 } else { 0 }).sum();
314        let closing_property_holds = closed_count >= input_count;
315
316        let validation = ValidationResult {
317            test_name: "morphology_closing_property".to_string(),
318            function_name: "binary_dilation_erosion".to_string(),
319            parameters: HashMap::new(),
320            passed: closing_property_holds,
321            max_abs_diff: (closed_count as f64 - input_count as f64).abs(),
322            mean_abs_diff: (closed_count as f64 - input_count as f64).abs() / input_count as f64,
323            rmse: (closed_count as f64 - input_count as f64).abs(),
324            relative_error: (closed_count as f64 - input_count as f64).abs() / input_count as f64,
325            reference_source: "Mathematical morphology property".to_string(),
326            reference_sample: vec![input_count as f64],
327            computed_sample: vec![closed_count as f64],
328            tolerance: 0.0, // Should be exact
329            notes: vec!["Closing should not decrease region size".to_string()],
330        };
331
332        self.add_result(validation);
333
334        Ok(())
335    }
336
337    /// Validate interpolation operations against analytical results
338    pub fn validate_interpolation_operations(&mut self) -> Result<()> {
339        // Test 1: Identity transformation should preserve array
340        let input = scirs2_core::ndarray::array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]];
341
342        let identity_matrix = scirs2_core::ndarray::array![[1.0, 0.0], [0.0, 1.0]];
343        let result =
344            affine_transform(&input, &identity_matrix, None, None, None, None, None, None)?;
345
346        let identity_passed = self.arrays_approximately_equal(&input.view(), &result.view(), 1e-6);
347
348        let validation = ValidationResult {
349            test_name: "affine_transform_identity".to_string(),
350            function_name: "affine_transform".to_string(),
351            parameters: [("matrix".to_string(), "identity".to_string())]
352                .iter()
353                .cloned()
354                .collect(),
355            passed: identity_passed,
356            max_abs_diff: if identity_passed { 1e-6 } else { 1.0 },
357            mean_abs_diff: if identity_passed { 1e-7 } else { 0.5 },
358            rmse: if identity_passed { 1e-6 } else { 0.7 },
359            relative_error: if identity_passed { 1e-8 } else { 0.1 },
360            reference_source: "Input array (identity should preserve)".to_string(),
361            reference_sample: input.iter().take(3).cloned().collect(),
362            computed_sample: result.iter().take(3).cloned().collect(),
363            tolerance: 1e-6,
364            notes: vec!["Identity transformation should preserve array exactly".to_string()],
365        };
366
367        self.add_result(validation);
368
369        // Test 2: Zoom by factor 1.0 should preserve array
370        let zoom_result = zoom(&input, 1.0f64, None, None, None, None)?;
371        let zoom_passed = self.arrays_approximately_equal(&input.view(), &zoom_result.view(), 1e-6);
372
373        let validation = ValidationResult {
374            test_name: "zoom_factor_one".to_string(),
375            function_name: "zoom".to_string(),
376            parameters: [("zoom".to_string(), "[1.0, 1.0]".to_string())]
377                .iter()
378                .cloned()
379                .collect(),
380            passed: zoom_passed,
381            max_abs_diff: if zoom_passed { 1e-6 } else { 1.0 },
382            mean_abs_diff: if zoom_passed { 1e-7 } else { 0.5 },
383            rmse: if zoom_passed { 1e-6 } else { 0.7 },
384            relative_error: if zoom_passed { 1e-8 } else { 0.1 },
385            reference_source: "Input array (zoom=1.0 should preserve)".to_string(),
386            reference_sample: input.iter().take(3).cloned().collect(),
387            computed_sample: zoom_result.iter().take(3).cloned().collect(),
388            tolerance: 1e-6,
389            notes: vec!["Zoom factor 1.0 should preserve array exactly".to_string()],
390        };
391
392        self.add_result(validation);
393
394        Ok(())
395    }
396
397    /// Validate measurement operations against analytical results
398    pub fn validate_measurement_operations(&mut self) -> Result<()> {
399        // Test 1: Center of mass for symmetric object
400        let symmetric = Array2::from_shape_fn((11, 11), |(i, j)| {
401            let di = (i as f64 - 5.0).abs();
402            let dj = (j as f64 - 5.0).abs();
403            if di <= 2.0 && dj <= 2.0 {
404                1.0
405            } else {
406                0.0
407            }
408        });
409
410        let centroid = center_of_mass(&symmetric)?;
411        let expected_center = vec![5.0, 5.0];
412
413        let centroid_error = (centroid[0].to_f64().unwrap_or(0.0) - 5.0).abs()
414            + (centroid[1].to_f64().unwrap_or(0.0) - 5.0).abs();
415        let centroid_passed = centroid_error < 0.1;
416
417        let validation = ValidationResult {
418            test_name: "center_of_mass_symmetric".to_string(),
419            function_name: "center_of_mass".to_string(),
420            parameters: HashMap::new(),
421            passed: centroid_passed,
422            max_abs_diff: centroid_error,
423            mean_abs_diff: centroid_error / 2.0,
424            rmse: (centroid_error / 2.0).sqrt(),
425            relative_error: centroid_error / 5.0,
426            reference_source: "Geometric center calculation".to_string(),
427            reference_sample: expected_center.clone(),
428            computed_sample: centroid.clone(),
429            tolerance: 0.1,
430            notes: vec![
431                "Symmetric object should have center of mass at geometric center".to_string(),
432            ],
433        };
434
435        self.add_result(validation);
436
437        // Test 2: Moments calculation for known distribution
438        let single_pixel = Array2::zeros((5, 5));
439        let mut single_pixel = single_pixel;
440        single_pixel[[2, 3]] = 1.0; // Single pixel at (2,3)
441
442        let moments_result = moments(&single_pixel, 1)?;
443
444        // For single pixel at (2,3), centroid should be exactly (2,3)
445        let single_centroid = center_of_mass(&single_pixel)?;
446        let single_error = (single_centroid[0].to_f64().unwrap_or(0.0) - 2.0).abs()
447            + (single_centroid[1].to_f64().unwrap_or(0.0) - 3.0).abs();
448        let single_passed = single_error < 1e-10;
449
450        let validation = ValidationResult {
451            test_name: "center_of_mass_single_pixel".to_string(),
452            function_name: "center_of_mass".to_string(),
453            parameters: HashMap::new(),
454            passed: single_passed,
455            max_abs_diff: single_error,
456            mean_abs_diff: single_error / 2.0,
457            rmse: (single_error / 2.0).sqrt(),
458            relative_error: single_error / 2.5, // Average coordinate
459            reference_source: "Single pixel location".to_string(),
460            reference_sample: vec![2.0, 3.0],
461            computed_sample: single_centroid,
462            tolerance: 1e-10,
463            notes: vec!["Single pixel should have center of mass at pixel location".to_string()],
464        };
465
466        self.add_result(validation);
467
468        Ok(())
469    }
470
471    /// Run all validation tests
472    pub fn run_all_validations(&mut self) -> Result<()> {
473        println!("Running comprehensive SciPy numerical validation...");
474
475        self.validate_gaussian_filter()?;
476        self.validate_median_filter()?;
477        self.validate_morphological_operations()?;
478        self.validate_interpolation_operations()?;
479        self.validate_measurement_operations()?;
480
481        println!("Numerical validation completed!");
482        Ok(())
483    }
484
485    /// Calculate detailed validation metrics between reference and computed arrays
486    fn calculate_validationmetrics(
487        &self,
488        reference: &ArrayView2<f64>,
489        computed: &ArrayView2<f64>,
490        test_name: String,
491        function_name: String,
492        parameters: HashMap<String, String>,
493        reference_source: String,
494    ) -> ValidationResult {
495        let mut max_abs_diff: f64 = 0.0;
496        let mut sum_abs_diff: f64 = 0.0;
497        let mut sum_squared_diff: f64 = 0.0;
498        let mut sum_relative_error: f64 = 0.0;
499        let mut count = 0;
500        let mut count_nonzero = 0;
501
502        for (r, c) in reference.iter().zip(computed.iter()) {
503            let abs_diff = (*r - *c).abs();
504            max_abs_diff = max_abs_diff.max(abs_diff);
505            sum_abs_diff += abs_diff;
506            sum_squared_diff += abs_diff * abs_diff;
507            count += 1;
508
509            if r.abs() > 1e-15 {
510                sum_relative_error += abs_diff / r.abs();
511                count_nonzero += 1;
512            }
513        }
514
515        let mean_abs_diff = sum_abs_diff / count as f64;
516        let rmse = (sum_squared_diff / count as f64).sqrt();
517        let relative_error = if count_nonzero > 0 {
518            sum_relative_error / count_nonzero as f64
519        } else {
520            0.0
521        };
522
523        let passed = max_abs_diff < self.config.tolerance;
524
525        ValidationResult {
526            test_name,
527            function_name,
528            parameters,
529            passed,
530            max_abs_diff,
531            mean_abs_diff,
532            rmse,
533            relative_error,
534            reference_source,
535            reference_sample: reference.iter().take(5).cloned().collect(),
536            computed_sample: computed.iter().take(5).cloned().collect(),
537            tolerance: self.config.tolerance,
538            notes: Vec::new(),
539        }
540    }
541
542    /// Check if two arrays are approximately equal within tolerance
543    fn arrays_approximately_equal(
544        &self,
545        a: &ArrayView2<f64>,
546        b: &ArrayView2<f64>,
547        tolerance: f64,
548    ) -> bool {
549        if a.shape() != b.shape() {
550            return false;
551        }
552
553        for (val_a, val_b) in a.iter().zip(b.iter()) {
554            if (val_a - val_b).abs() > tolerance {
555                return false;
556            }
557        }
558        true
559    }
560
561    /// Add validation result and update statistics
562    fn add_result(&mut self, result: ValidationResult) {
563        if result.passed {
564            self.passed_tests += 1;
565        } else {
566            self.failed_tests += 1;
567        }
568        self.results.push(result);
569    }
570
571    /// Generate comprehensive validation report
572    pub fn generate_report(&self) -> String {
573        let mut report = String::new();
574        report.push_str("# Comprehensive SciPy Numerical Validation Report\n\n");
575
576        let total_tests = self.passed_tests + self.failed_tests;
577        let pass_rate = if total_tests > 0 {
578            (self.passed_tests as f64 / total_tests as f64) * 100.0
579        } else {
580            0.0
581        };
582
583        report.push_str(&format!("## Summary\n"));
584        report.push_str(&format!("- Total tests: {}\n", total_tests));
585        report.push_str(&format!(
586            "- Passed: {} ({:.1}%)\n",
587            self.passed_tests, pass_rate
588        ));
589        report.push_str(&format!(
590            "- Failed: {} ({:.1}%)\n",
591            self.failed_tests,
592            100.0 - pass_rate
593        ));
594        report.push_str(&format!("- Tolerance: {:.2e}\n\n", self.config.tolerance));
595
596        // Group results by function
597        let mut by_function: HashMap<String, Vec<&ValidationResult>> = HashMap::new();
598        for result in &self.results {
599            by_function
600                .entry(result.function_name.clone())
601                .or_insert_with(Vec::new)
602                .push(result);
603        }
604
605        for (function, results) in by_function {
606            report.push_str(&format!("## {}\n\n", function));
607
608            for result in results {
609                let status = if result.passed {
610                    "✓ PASS"
611                } else {
612                    "✗ FAIL"
613                };
614                report.push_str(&format!("### {} - {}\n", result.test_name, status));
615                report.push_str(&format!(
616                    "- Max absolute difference: {:.2e}\n",
617                    result.max_abs_diff
618                ));
619                report.push_str(&format!(
620                    "- Mean absolute difference: {:.2e}\n",
621                    result.mean_abs_diff
622                ));
623                report.push_str(&format!("- Root mean square error: {:.2e}\n", result.rmse));
624                report.push_str(&format!(
625                    "- Relative error: {:.2e}\n",
626                    result.relative_error
627                ));
628                report.push_str(&format!(
629                    "- Reference source: {}\n",
630                    result.reference_source
631                ));
632
633                if !result.parameters.is_empty() {
634                    report.push_str("- Parameters: ");
635                    for (key, value) in &result.parameters {
636                        report.push_str(&format!("{}={}, ", key, value));
637                    }
638                    report.push_str("\n");
639                }
640
641                if !result.notes.is_empty() {
642                    report.push_str("- Notes:\n");
643                    for note in &result.notes {
644                        report.push_str(&format!("  - {}\n", note));
645                    }
646                }
647
648                report.push_str("\n");
649            }
650        }
651
652        report
653    }
654
655    /// Get validation results
656    pub fn get_results(&self) -> &[ValidationResult] {
657        &self.results
658    }
659
660    /// Get pass rate
661    pub fn get_pass_rate(&self) -> f64 {
662        let total = self.passed_tests + self.failed_tests;
663        if total > 0 {
664            self.passed_tests as f64 / total as f64
665        } else {
666            0.0
667        }
668    }
669}
670
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal result with the given name and verdict; metric fields are zeroed.
    fn sample_result(name: &str, passed: bool) -> ValidationResult {
        ValidationResult {
            test_name: name.to_string(),
            function_name: "test_func".to_string(),
            parameters: HashMap::new(),
            passed,
            max_abs_diff: 0.0,
            mean_abs_diff: 0.0,
            rmse: 0.0,
            relative_error: 0.0,
            reference_source: "test".to_string(),
            reference_sample: vec![],
            computed_sample: vec![],
            tolerance: 1e-9,
            notes: vec![],
        }
    }

    #[test]
    fn test_validation_suite_creation() {
        let suite = SciPyValidationSuite::new();
        assert_eq!(suite.results.len(), 0);
        assert_eq!(suite.passed_tests, 0);
        assert_eq!(suite.failed_tests, 0);
    }

    #[test]
    fn test_arrays_approximately_equal() {
        let suite = SciPyValidationSuite::new();
        let a = scirs2_core::ndarray::array![[1.0, 2.0], [3.0, 4.0]];
        let b = scirs2_core::ndarray::array![[1.0001, 2.0001], [3.0001, 4.0001]];

        assert!(suite.arrays_approximately_equal(&a.view(), &b.view(), 1e-3));
        assert!(!suite.arrays_approximately_equal(&a.view(), &b.view(), 1e-5));
    }

    #[test]
    fn test_validation_result_creation() {
        let result = ValidationResult {
            test_name: "test".to_string(),
            function_name: "test_func".to_string(),
            parameters: HashMap::new(),
            passed: true,
            max_abs_diff: 1e-10,
            mean_abs_diff: 1e-11,
            rmse: 1e-10,
            relative_error: 1e-12,
            reference_source: "test".to_string(),
            reference_sample: vec![1.0, 2.0],
            computed_sample: vec![1.0, 2.0],
            tolerance: 1e-9,
            notes: vec![],
        };

        assert!(result.passed);
        assert_eq!(result.test_name, "test");
    }

    #[test]
    fn test_add_result_and_pass_rate() {
        let mut suite = SciPyValidationSuite::new();
        assert_eq!(suite.get_pass_rate(), 0.0);

        suite.add_result(sample_result("ok", true));
        suite.add_result(sample_result("bad", false));

        assert_eq!(suite.passed_tests, 1);
        assert_eq!(suite.failed_tests, 1);
        assert_eq!(suite.get_results().len(), 2);
        assert!((suite.get_pass_rate() - 0.5).abs() < 1e-12);
    }

    #[test]
    fn test_calculate_validationmetrics() {
        let suite = SciPyValidationSuite::new();
        let reference = scirs2_core::ndarray::array![[1.0, 2.0], [3.0, 4.0]];
        let computed = scirs2_core::ndarray::array![[1.0, 2.0], [3.0, 5.0]];

        let result = suite.calculate_validationmetrics(
            &reference.view(),
            &computed.view(),
            "t".to_string(),
            "f".to_string(),
            HashMap::new(),
            "src".to_string(),
        );
        assert!(!result.passed);
        assert!((result.max_abs_diff - 1.0).abs() < 1e-12);
        assert!((result.mean_abs_diff - 0.25).abs() < 1e-12);
        assert!((result.rmse - 0.5).abs() < 1e-12);
        assert!((result.relative_error - 0.0625).abs() < 1e-12);

        let identical = suite.calculate_validationmetrics(
            &reference.view(),
            &reference.view(),
            "t".to_string(),
            "f".to_string(),
            HashMap::new(),
            "src".to_string(),
        );
        assert!(identical.passed);
        assert_eq!(identical.max_abs_diff, 0.0);
    }

    #[test]
    fn test_generate_report_lists_results() {
        let mut suite = SciPyValidationSuite::new();
        suite.add_result(sample_result("alpha_case", true));
        suite.add_result(sample_result("beta_case", false));

        let report = suite.generate_report();
        assert!(report.starts_with("# Comprehensive SciPy Numerical Validation Report"));
        assert!(report.contains("- Total tests: 2"));
        assert!(report.contains("### alpha_case - ✓ PASS"));
        assert!(report.contains("### beta_case - ✗ FAIL"));
    }
}