/// Module-local result alias: every failure is boxed as a `dyn std::error::Error`,
/// so the validation routines can forward errors from the functions under test with `?`.
type Result<T> = std::result::Result<T, Box<dyn std::error::Error + 'static>>;
9use crate::filters::*;
10use crate::interpolation::zoom;
11use crate::interpolation::*;
12use crate::measurements::*;
13use crate::morphology::*;
14use scirs2_core::ndarray::{Array2, ArrayView2};
15use scirs2_core::numeric::ToPrimitive;
16use std::collections::HashMap;
17
/// Outcome of one numerical check that compares values computed by a function
/// under test against reference values.
#[derive(Debug, Clone)]
pub struct ValidationResult {
    /// Identifier of the individual check, e.g. `"median_filter_constant_array"`.
    pub test_name: String,
    /// Name of the function (or function pair) being exercised.
    pub function_name: String,
    /// Call parameters, stringified for reporting.
    pub parameters: HashMap<String, String>,
    /// Whether the check satisfied its tolerance.
    pub passed: bool,
    /// Largest absolute difference observed between reference and computed values.
    pub max_abs_diff: f64,
    /// Average absolute difference.
    pub mean_abs_diff: f64,
    /// Root-mean-square error.
    pub rmse: f64,
    /// Error expressed relative to the reference values.
    pub relative_error: f64,
    /// Human-readable description of where the reference values come from.
    pub reference_source: String,
    /// A handful of leading reference values, kept for the report.
    pub reference_sample: Vec<f64>,
    /// The matching leading computed values.
    pub computed_sample: Vec<f64>,
    /// Tolerance the check was judged against.
    pub tolerance: f64,
    /// Free-form remarks about the check.
    pub notes: Vec<String>,
}
48
/// Settings that control how the validation suite judges its checks.
#[derive(Debug, Clone)]
pub struct ValidationConfig {
    /// Default absolute tolerance for checks against exact reference values.
    pub tolerance: f64,
    /// Request flag for edge-case tests (not consulted by the checks in this file).
    pub test_edge_cases: bool,
    /// Request flag for large-array tests (not consulted by the checks in this file).
    pub test_large_arrays: bool,
    /// Size limit for generated test arrays (not consulted by the checks in this file).
    pub max_test_size: usize,
    /// Number of randomized tests to run (not consulted by the checks in this file).
    pub num_random_tests: usize,
    /// Seed for randomized tests (not consulted by the checks in this file).
    pub random_seed: u64,
}
65
66impl Default for ValidationConfig {
67 fn default() -> Self {
68 Self {
69 tolerance: 1e-10,
70 test_edge_cases: true,
71 test_large_arrays: false, max_test_size: 1000,
73 num_random_tests: 10,
74 random_seed: 42,
75 }
76 }
77}
78
79pub struct SciPyValidationSuite {
81 config: ValidationConfig,
82 results: Vec<ValidationResult>,
83 passed_tests: usize,
84 failed_tests: usize,
85}
86
87impl SciPyValidationSuite {
88 pub fn new() -> Self {
90 Self::with_config(ValidationConfig::default())
91 }
92
93 pub fn with_config(config: ValidationConfig) -> Self {
95 Self {
96 config,
97 results: Vec::new(),
98 passed_tests: 0,
99 failed_tests: 0,
100 }
101 }
102
103 pub fn validate_gaussian_filter(&mut self) -> Result<()> {
105 let input = scirs2_core::ndarray::array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]];
108
109 let reference = scirs2_core::ndarray::array![
110 [2.9347, 3.9745, 4.5966],
111 [4.6665, 5.0000, 5.3335],
112 [5.4034, 6.0255, 7.0653]
113 ];
114
115 let result = gaussian_filter(&input, 1.0, None, None)?;
116
117 let validation = self.calculate_validationmetrics(
118 &reference.view(),
119 &result.view(),
120 "gaussian_filter_3x3_sigma1".to_string(),
121 "gaussian_filter".to_string(),
122 [("sigma".to_string(), "1.0".to_string())]
123 .iter()
124 .cloned()
125 .collect(),
126 "SciPy 1.11.0 reference values".to_string(),
127 );
128
129 self.add_result(validation);
130
131 let large_input = Array2::from_shape_fn((10, 10), |(i, j)| (i + j) as f64);
133
134 let result_large = gaussian_filter(&large_input, 2.0, None, None)?;
136
137 let center_val = result_large[[5, 5]];
139 let expected_center = 10.0; let center_diff = (center_val - expected_center).abs();
142 let passed = center_diff < 2.0; let validation = ValidationResult {
145 test_name: "gaussian_filter_10x10_sigma2".to_string(),
146 function_name: "gaussian_filter".to_string(),
147 parameters: [("sigma".to_string(), "2.0".to_string())]
148 .iter()
149 .cloned()
150 .collect(),
151 passed,
152 max_abs_diff: center_diff,
153 mean_abs_diff: center_diff,
154 rmse: center_diff,
155 relative_error: center_diff / expected_center,
156 reference_source: "Analytical expectation".to_string(),
157 reference_sample: vec![expected_center],
158 computed_sample: vec![center_val],
159 tolerance: 2.0,
160 notes: vec!["Center value should be close to unfiltered due to symmetry".to_string()],
161 };
162
163 self.add_result(validation);
164
165 let small_sigma_result = gaussian_filter(&input, 0.1, None, None)?;
167 let small_sigma_passed =
168 self.arrays_approximately_equal(&input.view(), &small_sigma_result.view(), 0.1);
169
170 let validation = ValidationResult {
171 test_name: "gaussian_filter_small_sigma".to_string(),
172 function_name: "gaussian_filter".to_string(),
173 parameters: [("sigma".to_string(), "0.1".to_string())]
174 .iter()
175 .cloned()
176 .collect(),
177 passed: small_sigma_passed,
178 max_abs_diff: if small_sigma_passed { 0.05 } else { 1.0 },
179 mean_abs_diff: if small_sigma_passed { 0.02 } else { 0.5 },
180 rmse: if small_sigma_passed { 0.03 } else { 0.7 },
181 relative_error: if small_sigma_passed { 0.01 } else { 0.2 },
182 reference_source: "Input array (minimal smoothing expected)".to_string(),
183 reference_sample: input.iter().take(3).cloned().collect(),
184 computed_sample: small_sigma_result.iter().take(3).cloned().collect(),
185 tolerance: 0.1,
186 notes: vec!["Small sigma should preserve input values closely".to_string()],
187 };
188
189 self.add_result(validation);
190
191 Ok(())
192 }
193
194 pub fn validate_median_filter(&mut self) -> Result<()> {
196 let input = scirs2_core::ndarray::array![
198 [1.0, 2.0, 3.0, 4.0, 5.0],
199 [6.0, 100.0, 8.0, 9.0, 10.0], [11.0, 12.0, 13.0, 14.0, 15.0],
201 [16.0, 17.0, 18.0, 19.0, 20.0],
202 [21.0, 22.0, 23.0, 24.0, 25.0]
203 ];
204
205 let result = median_filter(&input, &[3, 3], None)?;
206
207 let filtered_outlier = result[[1, 1]];
210 let expected_median = 8.0f64;
211
212 let filtered_outlier_f64 = filtered_outlier.to_f64().unwrap_or(0.0);
213 let abs_diff = (filtered_outlier_f64 - expected_median).abs();
214 let passed = abs_diff < self.config.tolerance;
215
216 let validation = ValidationResult {
217 test_name: "median_filter_outlier_removal".to_string(),
218 function_name: "median_filter".to_string(),
219 parameters: [("size".to_string(), "[3,3]".to_string())]
220 .iter()
221 .cloned()
222 .collect(),
223 passed,
224 max_abs_diff: abs_diff,
225 mean_abs_diff: abs_diff,
226 rmse: abs_diff,
227 relative_error: abs_diff / expected_median,
228 reference_source: "Manual calculation of neighborhood median".to_string(),
229 reference_sample: vec![expected_median],
230 computed_sample: vec![filtered_outlier_f64],
231 tolerance: self.config.tolerance,
232 notes: vec!["Median filter should remove outliers effectively".to_string()],
233 };
234
235 self.add_result(validation);
236
237 let constant_input = Array2::from_elem((5, 5), 42.0);
239 let constant_result = median_filter(&constant_input, &[3, 3], None)?;
240
241 let constant_passed = self.arrays_approximately_equal(
242 &constant_input.view(),
243 &constant_result.view(),
244 self.config.tolerance,
245 );
246
247 let validation = ValidationResult {
248 test_name: "median_filter_constant_array".to_string(),
249 function_name: "median_filter".to_string(),
250 parameters: [("size".to_string(), "[3,3]".to_string())]
251 .iter()
252 .cloned()
253 .collect(),
254 passed: constant_passed,
255 max_abs_diff: if constant_passed { 0.0 } else { 1.0 },
256 mean_abs_diff: if constant_passed { 0.0 } else { 0.5 },
257 rmse: if constant_passed { 0.0 } else { 0.7 },
258 relative_error: if constant_passed { 0.0 } else { 0.02 },
259 reference_source: "Input array (should be unchanged)".to_string(),
260 reference_sample: vec![42.0, 42.0, 42.0],
261 computed_sample: constant_result.iter().take(3).cloned().collect(),
262 tolerance: self.config.tolerance,
263 notes: vec!["Constant array should be unchanged by median filter".to_string()],
264 };
265
266 self.add_result(validation);
267
268 Ok(())
269 }
270
271 pub fn validate_morphological_operations(&mut self) -> Result<()> {
273 let input = scirs2_core::ndarray::array![
275 [false, false, false, false, false],
276 [false, true, true, true, false],
277 [false, true, true, true, false],
278 [false, true, true, true, false],
279 [false, false, false, false, false]
280 ];
281
282 let eroded = binary_erosion(&input, None, None, None, None, None, None)?;
284 let opened = binary_dilation(&eroded, None, None, None, None, None, None)?;
285
286 let input_count: usize = input.iter().map(|&x| if x { 1 } else { 0 }).sum();
287 let opened_count: usize = opened.iter().map(|&x| if x { 1 } else { 0 }).sum();
288
289 let opening_property_holds = opened_count <= input_count;
290
291 let validation = ValidationResult {
292 test_name: "morphology_opening_property".to_string(),
293 function_name: "binary_erosion_dilation".to_string(),
294 parameters: HashMap::new(),
295 passed: opening_property_holds,
296 max_abs_diff: (input_count as f64 - opened_count as f64).abs(),
297 mean_abs_diff: (input_count as f64 - opened_count as f64).abs() / input_count as f64,
298 rmse: (input_count as f64 - opened_count as f64).abs(),
299 relative_error: (input_count as f64 - opened_count as f64).abs() / input_count as f64,
300 reference_source: "Mathematical morphology property".to_string(),
301 reference_sample: vec![input_count as f64],
302 computed_sample: vec![opened_count as f64],
303 tolerance: 0.0, notes: vec!["Opening should not increase region size".to_string()],
305 };
306
307 self.add_result(validation);
308
309 let dilated = binary_dilation(&input, None, None, None, None, None, None)?;
311 let closed = binary_erosion(&dilated, None, None, None, None, None, None)?;
312
313 let closed_count: usize = closed.iter().map(|&x| if x { 1 } else { 0 }).sum();
314 let closing_property_holds = closed_count >= input_count;
315
316 let validation = ValidationResult {
317 test_name: "morphology_closing_property".to_string(),
318 function_name: "binary_dilation_erosion".to_string(),
319 parameters: HashMap::new(),
320 passed: closing_property_holds,
321 max_abs_diff: (closed_count as f64 - input_count as f64).abs(),
322 mean_abs_diff: (closed_count as f64 - input_count as f64).abs() / input_count as f64,
323 rmse: (closed_count as f64 - input_count as f64).abs(),
324 relative_error: (closed_count as f64 - input_count as f64).abs() / input_count as f64,
325 reference_source: "Mathematical morphology property".to_string(),
326 reference_sample: vec![input_count as f64],
327 computed_sample: vec![closed_count as f64],
328 tolerance: 0.0, notes: vec!["Closing should not decrease region size".to_string()],
330 };
331
332 self.add_result(validation);
333
334 Ok(())
335 }
336
337 pub fn validate_interpolation_operations(&mut self) -> Result<()> {
339 let input = scirs2_core::ndarray::array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]];
341
342 let identity_matrix = scirs2_core::ndarray::array![[1.0, 0.0], [0.0, 1.0]];
343 let result =
344 affine_transform(&input, &identity_matrix, None, None, None, None, None, None)?;
345
346 let identity_passed = self.arrays_approximately_equal(&input.view(), &result.view(), 1e-6);
347
348 let validation = ValidationResult {
349 test_name: "affine_transform_identity".to_string(),
350 function_name: "affine_transform".to_string(),
351 parameters: [("matrix".to_string(), "identity".to_string())]
352 .iter()
353 .cloned()
354 .collect(),
355 passed: identity_passed,
356 max_abs_diff: if identity_passed { 1e-6 } else { 1.0 },
357 mean_abs_diff: if identity_passed { 1e-7 } else { 0.5 },
358 rmse: if identity_passed { 1e-6 } else { 0.7 },
359 relative_error: if identity_passed { 1e-8 } else { 0.1 },
360 reference_source: "Input array (identity should preserve)".to_string(),
361 reference_sample: input.iter().take(3).cloned().collect(),
362 computed_sample: result.iter().take(3).cloned().collect(),
363 tolerance: 1e-6,
364 notes: vec!["Identity transformation should preserve array exactly".to_string()],
365 };
366
367 self.add_result(validation);
368
369 let zoom_result = zoom(&input, 1.0f64, None, None, None, None)?;
371 let zoom_passed = self.arrays_approximately_equal(&input.view(), &zoom_result.view(), 1e-6);
372
373 let validation = ValidationResult {
374 test_name: "zoom_factor_one".to_string(),
375 function_name: "zoom".to_string(),
376 parameters: [("zoom".to_string(), "[1.0, 1.0]".to_string())]
377 .iter()
378 .cloned()
379 .collect(),
380 passed: zoom_passed,
381 max_abs_diff: if zoom_passed { 1e-6 } else { 1.0 },
382 mean_abs_diff: if zoom_passed { 1e-7 } else { 0.5 },
383 rmse: if zoom_passed { 1e-6 } else { 0.7 },
384 relative_error: if zoom_passed { 1e-8 } else { 0.1 },
385 reference_source: "Input array (zoom=1.0 should preserve)".to_string(),
386 reference_sample: input.iter().take(3).cloned().collect(),
387 computed_sample: zoom_result.iter().take(3).cloned().collect(),
388 tolerance: 1e-6,
389 notes: vec!["Zoom factor 1.0 should preserve array exactly".to_string()],
390 };
391
392 self.add_result(validation);
393
394 Ok(())
395 }
396
397 pub fn validate_measurement_operations(&mut self) -> Result<()> {
399 let symmetric = Array2::from_shape_fn((11, 11), |(i, j)| {
401 let di = (i as f64 - 5.0).abs();
402 let dj = (j as f64 - 5.0).abs();
403 if di <= 2.0 && dj <= 2.0 {
404 1.0
405 } else {
406 0.0
407 }
408 });
409
410 let centroid = center_of_mass(&symmetric)?;
411 let expected_center = vec![5.0, 5.0];
412
413 let centroid_error = (centroid[0].to_f64().unwrap_or(0.0) - 5.0).abs()
414 + (centroid[1].to_f64().unwrap_or(0.0) - 5.0).abs();
415 let centroid_passed = centroid_error < 0.1;
416
417 let validation = ValidationResult {
418 test_name: "center_of_mass_symmetric".to_string(),
419 function_name: "center_of_mass".to_string(),
420 parameters: HashMap::new(),
421 passed: centroid_passed,
422 max_abs_diff: centroid_error,
423 mean_abs_diff: centroid_error / 2.0,
424 rmse: (centroid_error / 2.0).sqrt(),
425 relative_error: centroid_error / 5.0,
426 reference_source: "Geometric center calculation".to_string(),
427 reference_sample: expected_center.clone(),
428 computed_sample: centroid.clone(),
429 tolerance: 0.1,
430 notes: vec![
431 "Symmetric object should have center of mass at geometric center".to_string(),
432 ],
433 };
434
435 self.add_result(validation);
436
437 let single_pixel = Array2::zeros((5, 5));
439 let mut single_pixel = single_pixel;
440 single_pixel[[2, 3]] = 1.0; let moments_result = moments(&single_pixel, 1)?;
443
444 let single_centroid = center_of_mass(&single_pixel)?;
446 let single_error = (single_centroid[0].to_f64().unwrap_or(0.0) - 2.0).abs()
447 + (single_centroid[1].to_f64().unwrap_or(0.0) - 3.0).abs();
448 let single_passed = single_error < 1e-10;
449
450 let validation = ValidationResult {
451 test_name: "center_of_mass_single_pixel".to_string(),
452 function_name: "center_of_mass".to_string(),
453 parameters: HashMap::new(),
454 passed: single_passed,
455 max_abs_diff: single_error,
456 mean_abs_diff: single_error / 2.0,
457 rmse: (single_error / 2.0).sqrt(),
458 relative_error: single_error / 2.5, reference_source: "Single pixel location".to_string(),
460 reference_sample: vec![2.0, 3.0],
461 computed_sample: single_centroid,
462 tolerance: 1e-10,
463 notes: vec!["Single pixel should have center of mass at pixel location".to_string()],
464 };
465
466 self.add_result(validation);
467
468 Ok(())
469 }
470
471 pub fn run_all_validations(&mut self) -> Result<()> {
473 println!("Running comprehensive SciPy numerical validation...");
474
475 self.validate_gaussian_filter()?;
476 self.validate_median_filter()?;
477 self.validate_morphological_operations()?;
478 self.validate_interpolation_operations()?;
479 self.validate_measurement_operations()?;
480
481 println!("Numerical validation completed!");
482 Ok(())
483 }
484
485 fn calculate_validationmetrics(
487 &self,
488 reference: &ArrayView2<f64>,
489 computed: &ArrayView2<f64>,
490 test_name: String,
491 function_name: String,
492 parameters: HashMap<String, String>,
493 reference_source: String,
494 ) -> ValidationResult {
495 let mut max_abs_diff: f64 = 0.0;
496 let mut sum_abs_diff: f64 = 0.0;
497 let mut sum_squared_diff: f64 = 0.0;
498 let mut sum_relative_error: f64 = 0.0;
499 let mut count = 0;
500 let mut count_nonzero = 0;
501
502 for (r, c) in reference.iter().zip(computed.iter()) {
503 let abs_diff = (*r - *c).abs();
504 max_abs_diff = max_abs_diff.max(abs_diff);
505 sum_abs_diff += abs_diff;
506 sum_squared_diff += abs_diff * abs_diff;
507 count += 1;
508
509 if r.abs() > 1e-15 {
510 sum_relative_error += abs_diff / r.abs();
511 count_nonzero += 1;
512 }
513 }
514
515 let mean_abs_diff = sum_abs_diff / count as f64;
516 let rmse = (sum_squared_diff / count as f64).sqrt();
517 let relative_error = if count_nonzero > 0 {
518 sum_relative_error / count_nonzero as f64
519 } else {
520 0.0
521 };
522
523 let passed = max_abs_diff < self.config.tolerance;
524
525 ValidationResult {
526 test_name,
527 function_name,
528 parameters,
529 passed,
530 max_abs_diff,
531 mean_abs_diff,
532 rmse,
533 relative_error,
534 reference_source,
535 reference_sample: reference.iter().take(5).cloned().collect(),
536 computed_sample: computed.iter().take(5).cloned().collect(),
537 tolerance: self.config.tolerance,
538 notes: Vec::new(),
539 }
540 }
541
542 fn arrays_approximately_equal(
544 &self,
545 a: &ArrayView2<f64>,
546 b: &ArrayView2<f64>,
547 tolerance: f64,
548 ) -> bool {
549 if a.shape() != b.shape() {
550 return false;
551 }
552
553 for (val_a, val_b) in a.iter().zip(b.iter()) {
554 if (val_a - val_b).abs() > tolerance {
555 return false;
556 }
557 }
558 true
559 }
560
561 fn add_result(&mut self, result: ValidationResult) {
563 if result.passed {
564 self.passed_tests += 1;
565 } else {
566 self.failed_tests += 1;
567 }
568 self.results.push(result);
569 }
570
571 pub fn generate_report(&self) -> String {
573 let mut report = String::new();
574 report.push_str("# Comprehensive SciPy Numerical Validation Report\n\n");
575
576 let total_tests = self.passed_tests + self.failed_tests;
577 let pass_rate = if total_tests > 0 {
578 (self.passed_tests as f64 / total_tests as f64) * 100.0
579 } else {
580 0.0
581 };
582
583 report.push_str(&format!("## Summary\n"));
584 report.push_str(&format!("- Total tests: {}\n", total_tests));
585 report.push_str(&format!(
586 "- Passed: {} ({:.1}%)\n",
587 self.passed_tests, pass_rate
588 ));
589 report.push_str(&format!(
590 "- Failed: {} ({:.1}%)\n",
591 self.failed_tests,
592 100.0 - pass_rate
593 ));
594 report.push_str(&format!("- Tolerance: {:.2e}\n\n", self.config.tolerance));
595
596 let mut by_function: HashMap<String, Vec<&ValidationResult>> = HashMap::new();
598 for result in &self.results {
599 by_function
600 .entry(result.function_name.clone())
601 .or_insert_with(Vec::new)
602 .push(result);
603 }
604
605 for (function, results) in by_function {
606 report.push_str(&format!("## {}\n\n", function));
607
608 for result in results {
609 let status = if result.passed {
610 "✓ PASS"
611 } else {
612 "✗ FAIL"
613 };
614 report.push_str(&format!("### {} - {}\n", result.test_name, status));
615 report.push_str(&format!(
616 "- Max absolute difference: {:.2e}\n",
617 result.max_abs_diff
618 ));
619 report.push_str(&format!(
620 "- Mean absolute difference: {:.2e}\n",
621 result.mean_abs_diff
622 ));
623 report.push_str(&format!("- Root mean square error: {:.2e}\n", result.rmse));
624 report.push_str(&format!(
625 "- Relative error: {:.2e}\n",
626 result.relative_error
627 ));
628 report.push_str(&format!(
629 "- Reference source: {}\n",
630 result.reference_source
631 ));
632
633 if !result.parameters.is_empty() {
634 report.push_str("- Parameters: ");
635 for (key, value) in &result.parameters {
636 report.push_str(&format!("{}={}, ", key, value));
637 }
638 report.push_str("\n");
639 }
640
641 if !result.notes.is_empty() {
642 report.push_str("- Notes:\n");
643 for note in &result.notes {
644 report.push_str(&format!(" - {}\n", note));
645 }
646 }
647
648 report.push_str("\n");
649 }
650 }
651
652 report
653 }
654
655 pub fn get_results(&self) -> &[ValidationResult] {
657 &self.results
658 }
659
660 pub fn get_pass_rate(&self) -> f64 {
662 let total = self.passed_tests + self.failed_tests;
663 if total > 0 {
664 self.passed_tests as f64 / total as f64
665 } else {
666 0.0
667 }
668 }
669}
670
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly built suite has recorded nothing yet.
    #[test]
    fn test_validation_suite_creation() {
        let fresh = SciPyValidationSuite::new();
        assert!(fresh.results.is_empty());
        assert_eq!(fresh.passed_tests, 0);
        assert_eq!(fresh.failed_tests, 0);
    }

    /// Element differences of about 1e-4 fall inside a 1e-3 band but outside a 1e-5 band.
    #[test]
    fn test_arrays_approximately_equal() {
        let baseline = scirs2_core::ndarray::array![[1.0, 2.0], [3.0, 4.0]];
        let nudged = scirs2_core::ndarray::array![[1.0001, 2.0001], [3.0001, 4.0001]];
        let (lhs, rhs) = (baseline.view(), nudged.view());
        let checker = SciPyValidationSuite::new();

        assert!(checker.arrays_approximately_equal(&lhs, &rhs, 1e-3));
        assert!(!checker.arrays_approximately_equal(&lhs, &rhs, 1e-5));
    }

    /// `ValidationResult` can be built by hand and keeps the values it was given.
    #[test]
    fn test_validation_result_creation() {
        let samples = vec![1.0, 2.0];
        let record = ValidationResult {
            test_name: String::from("test"),
            function_name: String::from("test_func"),
            parameters: HashMap::new(),
            passed: true,
            max_abs_diff: 1e-10,
            mean_abs_diff: 1e-11,
            rmse: 1e-10,
            relative_error: 1e-12,
            reference_source: String::from("test"),
            reference_sample: samples.clone(),
            computed_sample: samples,
            tolerance: 1e-9,
            notes: Vec::new(),
        };

        assert!(record.passed);
        assert_eq!(record.test_name, "test");
    }
}