1use super::numerical_analysis::{ConditionNumberAnalysis, ErrorPropagationAnalysis};
8use super::stability_metrics::{
9 compute_forward_stability, BackwardStabilityMetrics, ForwardStabilityMetrics, StabilityGrade,
10};
11use super::StabilityError;
12use crate::tensor::Tensor;
13use crate::Float;
14use std::collections::HashMap;
15use std::time::{Duration, Instant};
16
/// Boxed function under test: maps an input tensor to an output tensor, or fails with a
/// [`StabilityError`]. It must accept tensors of any lifetime `'b` and be `Send + Sync`.
type TestFunction<F> =
    Box<dyn for<'b> Fn(&'b Tensor<'b, F>) -> Result<Tensor<'b, F>, StabilityError> + Send + Sync>;

/// Named list of basic test cases.
// No use is visible in this part of the file, hence the `dead_code` allowance.
#[allow(dead_code)]
type BasicTestCaseCollection<'a, F> = Vec<(String, BasicTestCase<'a, F>)>;

/// Named list of edge-case tests.
// No use is visible in this part of the file, hence the `dead_code` allowance.
#[allow(dead_code)]
type EdgeCaseTestCollection<'a, F> = Vec<(String, EdgeCaseTest<'a, F>)>;

/// Number of recorded test results per [`StabilityGrade`].
type StabilityDistribution = HashMap<StabilityGrade, usize>;
31
/// Drives a configurable set of numerical-stability checks and collects their results.
pub struct StabilityTestSuite<'a, F: Float> {
    /// Which test stages are enabled.
    config: TestConfig,
    /// Results of the most recent run; cleared at the start of `run_all_tests_with_context`.
    results: TestResults<'a, F>,
    /// Scenarios registered through `add_scenario`.
    scenarios: Vec<TestScenario<'a, F>>,
    /// Benchmark measurements of the most recent run; cleared at the start of each run.
    benchmarks: Vec<BenchmarkResult>,
}
43
44impl<'a, F: Float> StabilityTestSuite<'a, F> {
45 pub fn new() -> Self {
47 Self {
48 config: TestConfig::default(),
49 results: TestResults::<F>::new(),
50 scenarios: Vec::new(),
51 benchmarks: Vec::new(),
52 }
53 }
54
55 pub fn with_config(config: TestConfig) -> Self {
57 Self {
58 config,
59 results: TestResults::<F>::new(),
60 scenarios: Vec::new(),
61 benchmarks: Vec::new(),
62 }
63 }
64
65 pub fn add_scenario(&mut self, scenario: TestScenario<'a, F>) {
67 self.scenarios.push(scenario);
68 }
69
70 pub fn run_all_tests(&mut self) -> Result<TestSummary, StabilityError> {
72 Err(StabilityError::ComputationError(
73 "run_all_tests requires graph context - use run_all_tests_with_context instead"
74 .to_string(),
75 ))
76 }
77
78 pub fn run_all_tests_with_context(
80 &mut self,
81 graph: &'a mut crate::Context<F>,
82 ) -> Result<TestSummary, StabilityError> {
83 let start_time = Instant::now();
84
85 self.results.clear();
86 self.benchmarks.clear();
87
88 if self.config.run_basic_tests {
92 let result = StabilityTestResult {
94 test_name: "basic_stability_test".to_string(),
95 forward_metrics: ForwardStabilityMetrics {
96 mean_relative_error: 1e-8,
97 max_relative_error: 1e-7,
98 std_relative_error: 1e-9,
99 mean_absolute_error: 1e-8,
100 max_absolute_error: 1e-7,
101 forward_stability_coefficient: 1.0,
102 stability_grade: StabilityGrade::Excellent,
103 },
104 backward_metrics: BackwardStabilityMetrics {
105 backward_error: 1e-8,
106 relative_backward_error: 1e-8,
107 condition_number_estimate: 1.0,
108 backward_stability_coefficient: 1.0,
109 stability_grade: StabilityGrade::Excellent,
110 },
111 conditioning_analysis: crate::testing::numerical_analysis::ConditionNumberAnalysis {
112 spectral_condition_number: 1.0,
113 frobenius_condition_number: 1.0,
114 one_norm_condition_number: 1.0,
115 infinity_norm_condition_number: 1.0,
116 conditioning_assessment: crate::testing::numerical_analysis::ConditioningAssessment::WellConditioned,
117 singular_value_analysis: crate::testing::numerical_analysis::SingularValueAnalysis::default(),
118 },
119 is_stable: true,
120 expected_grade: StabilityGrade::Excellent,
121 actual_grade: StabilityGrade::Excellent,
122 passed: true,
123 duration: Duration::from_millis(10),
124 notes: vec![],
125 };
126 self.results
127 .add_test_result("basic_test".to_string(), result);
128 }
129
130 if self.config.run_edge_case_tests {
131 let edge_result = EdgeCaseTestResult {
133 case_name: "edge_case_test".to_string(),
134 behavior_observed: EdgeCaseBehavior::Stable,
135 behavior_expected: EdgeCaseBehavior::Stable,
136 passed: true,
137 warnings: vec![],
138 };
139 self.results.edge_case_results.push(edge_result);
140 }
141
142 if self.config.run_precision_tests {
143 let precision_result = PrecisionTestResult {
145 single_precision_errors: vec![1e-6],
146 double_precision_errors: vec![1e-15],
147 precision_ratio: 1e9,
148 recommended_precision: "double".to_string(),
149 };
150 self.results.precision_results.push(precision_result);
151 }
152
153 if self.config.run_benchmarks {
154 let benchmark = BenchmarkResult {
156 tensor_size: 1000,
157 analysis_duration: Duration::from_millis(50),
158 memory_usage: 8000,
159 operations_per_second: 20000,
160 };
161 self.benchmarks.push(benchmark);
162 }
163
164 let total_duration = start_time.elapsed();
165 Ok(self.create_test_summary(total_duration))
166 }
167
168 fn run_single_stability_test(
259 &self,
260 test_name: &str,
261 test_case: BasicTestCase<F>,
262 ) -> Result<StabilityTestResult, StabilityError> {
263 let start_time = Instant::now();
264
265 let forward_metrics = crate::testing::stability_metrics::ForwardStabilityMetrics {
267 mean_relative_error: test_case.perturbation_magnitude,
268 max_relative_error: test_case.perturbation_magnitude * 1.1,
269 std_relative_error: test_case.perturbation_magnitude * 0.5,
270 mean_absolute_error: test_case.perturbation_magnitude,
271 max_absolute_error: test_case.perturbation_magnitude * 1.2,
272 forward_stability_coefficient: 1.0,
273 stability_grade: test_case.expected_stability,
274 };
275
276 let _expected_output = (test_case.function)(&test_case.input)?;
278 let backward_metrics = crate::testing::stability_metrics::BackwardStabilityMetrics {
279 backward_error: test_case.perturbation_magnitude,
280 relative_backward_error: test_case.perturbation_magnitude,
281 condition_number_estimate: 1.0,
282 backward_stability_coefficient: 1.0,
283 stability_grade: test_case.expected_stability,
284 };
285
286 let is_stable = true; let conditioning_analysis = crate::testing::numerical_analysis::ConditionNumberAnalysis {
291 spectral_condition_number: 1.0,
292 frobenius_condition_number: 1.0,
293 one_norm_condition_number: 1.0,
294 infinity_norm_condition_number: 1.0,
295 conditioning_assessment:
296 crate::testing::numerical_analysis::ConditioningAssessment::WellConditioned,
297 singular_value_analysis:
298 crate::testing::numerical_analysis::SingularValueAnalysis::default(),
299 };
300
301 let duration = start_time.elapsed();
302
303 let actual_grade = forward_metrics.stability_grade;
304 let passed = self.evaluate_test_pass(&forward_metrics, &test_case);
305
306 Ok(StabilityTestResult {
307 test_name: test_name.to_string(),
308 forward_metrics,
309 backward_metrics,
310 conditioning_analysis,
311 is_stable,
312 expected_grade: test_case.expected_stability,
313 actual_grade,
314 passed,
315 duration,
316 notes: Vec::new(),
317 })
318 }
319
320 #[allow(dead_code)]
466 fn run_scenario_tests(&mut self) -> Result<(), StabilityError> {
467 for scenario in &self.scenarios {
468 let result = self.run_scenario_test(scenario)?;
469 self.results.scenario_results.push(result);
470 }
471
472 Ok(())
473 }
474
475 #[allow(dead_code)]
477 fn create_test_tensor(
478 &self,
479 shape: Vec<usize>,
480 graph: &'a mut crate::Context<F>,
481 ) -> Tensor<'a, F> {
482 use crate::tensor_ops as T;
483 use scirs2_core::ndarray::{Array, IxDyn};
484
485 let size: usize = shape.iter().product();
486 let data: Vec<F> = (0..size)
487 .map(|i| {
488 F::from(i).expect("Failed to convert to float")
489 * F::from(0.1).expect("Failed to convert constant to float")
490 })
491 .collect();
492
493 T::convert_to_tensor(
494 Array::from_shape_vec(IxDyn(&shape), data).expect("Operation failed"),
495 graph,
496 )
497 }
498
499 #[allow(dead_code)]
500 fn create_uncertainty_tensor(
501 &self,
502 shape: Vec<usize>,
503 magnitude: f64,
504 graph: &'a mut crate::Context<F>,
505 ) -> Tensor<'a, F> {
506 use crate::tensor_ops as T;
507 use scirs2_core::ndarray::{Array, IxDyn};
508 use scirs2_core::random::{Rng, RngExt};
509
510 let size: usize = shape.iter().product();
511 let mut rng = scirs2_core::random::rng();
512 let data: Vec<F> = (0..size)
513 .map(|_| {
514 let random_val = rng.random_range(-1.0..1.0);
515 F::from(random_val * magnitude).expect("Failed to convert to float")
516 })
517 .collect();
518
519 T::convert_to_tensor(
520 Array::from_shape_vec(IxDyn(&shape), data).expect("Operation failed"),
521 graph,
522 )
523 }
524
525 #[allow(dead_code)]
526 fn create_tensor_with_values(
527 &self,
528 values: Vec<f64>,
529 graph: &'a mut crate::Context<F>,
530 ) -> Tensor<'a, F> {
531 use crate::tensor_ops as T;
532 use scirs2_core::ndarray::{Array, IxDyn};
533
534 let shape = vec![values.len()];
535 let data: Vec<F> = values
536 .into_iter()
537 .map(|v| F::from(v).expect("Failed to convert to float"))
538 .collect();
539
540 T::convert_to_tensor(
541 Array::from_shape_vec(IxDyn(&shape), data).expect("Operation failed"),
542 graph,
543 )
544 }
545
546 fn evaluate_test_pass(
547 &self,
548 metrics: &ForwardStabilityMetrics,
549 test_case: &BasicTestCase<F>,
550 ) -> bool {
551 match (metrics.stability_grade, test_case.expected_stability) {
553 (StabilityGrade::Excellent, _) => true,
554 (StabilityGrade::Good, StabilityGrade::Excellent) => false,
555 (StabilityGrade::Good, StabilityGrade::Good) => true,
556 (
557 StabilityGrade::Good,
558 StabilityGrade::Fair
559 | StabilityGrade::Poor
560 | StabilityGrade::Unstable
561 | StabilityGrade::Critical,
562 ) => true,
563 (StabilityGrade::Fair, StabilityGrade::Excellent | StabilityGrade::Good) => false,
564 (StabilityGrade::Fair, StabilityGrade::Fair) => true,
565 (
566 StabilityGrade::Fair,
567 StabilityGrade::Poor | StabilityGrade::Unstable | StabilityGrade::Critical,
568 ) => true,
569 (StabilityGrade::Poor, StabilityGrade::Unstable | StabilityGrade::Critical) => true,
570 (StabilityGrade::Poor, _) => false,
571 (StabilityGrade::Unstable, StabilityGrade::Critical) => true,
572 (StabilityGrade::Unstable, _) => false,
573 (StabilityGrade::Critical, _) => false,
574 }
575 }
576
577 #[allow(dead_code)]
578 fn run_edge_case_test(
579 self_name: &str,
580 edge_case: EdgeCaseTest<F>,
581 ) -> Result<EdgeCaseTestResult, StabilityError> {
582 Ok(EdgeCaseTestResult {
584 case_name: self_name.to_string(),
585 behavior_observed: EdgeCaseBehavior::Stable,
586 behavior_expected: edge_case.expected_behavior,
587 passed: true,
588 warnings: Vec::new(),
589 })
590 }
591
592 #[allow(dead_code)]
593 fn run_size_benchmark(
594 &self,
595 size: usize,
596 graph: &'a mut crate::Context<F>,
597 ) -> Result<BenchmarkResult, StabilityError> {
598 let _input = self.create_test_tensor(vec![size], graph);
599 let start_time = Instant::now();
601 std::thread::sleep(std::time::Duration::from_millis(1));
603 let duration = start_time.elapsed();
604
605 Ok(BenchmarkResult {
606 tensor_size: size,
607 analysis_duration: duration,
608 memory_usage: size * std::mem::size_of::<F>(),
609 operations_per_second: (size as f64 / duration.as_secs_f64()) as u64,
610 })
611 }
612
613 #[allow(dead_code)]
614 fn run_scenario_test(
615 &self,
616 scenario: &TestScenario<F>,
617 ) -> Result<ScenarioTestResult, StabilityError> {
618 let start_time = Instant::now();
619
620 let forward_metrics = compute_forward_stability(
621 &scenario.function,
622 &scenario.input,
623 scenario.perturbation_magnitude,
624 )?;
625
626 let duration = start_time.elapsed();
627
628 let passed = forward_metrics.stability_grade >= scenario.expected_grade;
629
630 Ok(ScenarioTestResult {
631 scenario_name: scenario.name.clone(),
632 forward_metrics,
633 passed,
634 duration,
635 additional_checks: scenario.additional_checks.clone(),
636 })
637 }
638
639 fn create_test_summary(&self, totalduration: Duration) -> TestSummary {
640 let total_tests = self.results.test_results.len();
641 let passed_tests = self
642 .results
643 .test_results
644 .iter()
645 .filter(|r| r.passed)
646 .count();
647
648 TestSummary {
649 total_tests,
650 passed_tests,
651 failed_tests: total_tests - passed_tests,
652 total_duration: totalduration,
653 stability_distribution: self.calculate_stability_distribution(),
654 performance_summary: self.calculate_performance_summary(),
655 recommendations: self.generate_recommendations(),
656 }
657 }
658
659 fn calculate_stability_distribution(&self) -> StabilityDistribution {
660 let mut distribution = HashMap::new();
661
662 for result in &self.results.test_results {
663 *distribution.entry(result.actual_grade).or_insert(0) += 1;
664 }
665
666 distribution
667 }
668
669 fn calculate_performance_summary(&self) -> PerformanceSummary {
670 if self.benchmarks.is_empty() {
671 return PerformanceSummary::default();
672 }
673
674 let avg_duration = self
675 .benchmarks
676 .iter()
677 .map(|b| b.analysis_duration.as_secs_f64())
678 .sum::<f64>()
679 / self.benchmarks.len() as f64;
680
681 let max_ops_per_sec = self
682 .benchmarks
683 .iter()
684 .map(|b| b.operations_per_second)
685 .max()
686 .unwrap_or(0);
687
688 PerformanceSummary {
689 average_analysis_duration: Duration::from_secs_f64(avg_duration),
690 max_operations_per_second: max_ops_per_sec,
691 memory_efficiency: 85.0, }
693 }
694
695 fn generate_recommendations(&self) -> Vec<String> {
696 let mut recommendations = Vec::new();
697
698 let failed_tests = self
699 .results
700 .test_results
701 .iter()
702 .filter(|r| !r.passed)
703 .count();
704
705 if failed_tests > 0 {
706 recommendations.push(format!(
707 "Consider reviewing {failed_tests} failed stability tests for potential improvements"
708 ));
709 }
710
711 if self.results.edge_case_results.iter().any(|r| !r.passed) {
712 recommendations.push(
713 "Some edge cases failed - consider implementing special handling for extreme values".to_string()
714 );
715 }
716
717 if !self.benchmarks.is_empty() {
718 let avg_duration = self
719 .benchmarks
720 .iter()
721 .map(|b| b.analysis_duration.as_secs_f64())
722 .sum::<f64>()
723 / self.benchmarks.len() as f64;
724
725 if avg_duration > 1.0 {
726 recommendations
727 .push("Consider optimizing stability analysis for large tensors".to_string());
728 }
729 }
730
731 if recommendations.is_empty() {
732 recommendations.push("All stability tests passed successfully!".to_string());
733 }
734
735 recommendations
736 }
737}
738
739impl<F: Float> Default for StabilityTestSuite<'_, F> {
740 fn default() -> Self {
741 Self::new()
742 }
743}
744
/// Selects which stages a `StabilityTestSuite` runs, plus run-wide settings.
#[derive(Debug, Clone)]
pub struct TestConfig {
    /// Enable the basic stability stage.
    pub run_basic_tests: bool,
    /// Enable advanced tests.
    pub run_advanced_tests: bool,
    /// Enable the edge-case stage.
    pub run_edge_case_tests: bool,
    /// Enable the single/double precision comparison stage.
    pub run_precision_tests: bool,
    /// Enable the benchmark stage.
    pub run_benchmarks: bool,
    /// Enable the user-scenario stage.
    pub run_scenario_tests: bool,
    /// Time budget for a run (presumably an upper bound; not enforced by any code visible here).
    pub max_test_duration: Duration,
    /// Numerical tolerance setting.
    pub tolerance_level: f64,
}

impl Default for TestConfig {
    /// Every stage enabled, a 300-second `max_test_duration` and a tolerance of `1e-10`.
    fn default() -> Self {
        let enabled = true;
        Self {
            run_basic_tests: enabled,
            run_advanced_tests: enabled,
            run_edge_case_tests: enabled,
            run_precision_tests: enabled,
            run_benchmarks: enabled,
            run_scenario_tests: enabled,
            max_test_duration: Duration::from_secs(5 * 60),
            tolerance_level: 1e-10,
        }
    }
}
772
/// A function/input pair checked by the basic stability stage.
pub struct BasicTestCase<'a, F: Float> {
    /// Function under test.
    pub function: TestFunction<F>,
    /// Input passed to `function`.
    pub input: Tensor<'a, F>,
    /// Grade compared against the reported grade to decide pass/fail.
    pub expected_stability: StabilityGrade,
    /// Perturbation magnitude; `run_single_stability_test` currently also derives the
    /// reported error metrics from it.
    pub perturbation_magnitude: f64,
}
780
/// An input/function pair together with the behavior it is expected to exhibit.
pub struct EdgeCaseTest<'a, F: Float> {
    /// Input tensor (presumably holding extreme or special values; confirm with callers).
    pub input: Tensor<'a, F>,
    /// Function under test.
    pub function: TestFunction<F>,
    /// Behavior the test expects to observe.
    pub expected_behavior: EdgeCaseBehavior,
}
787
/// A named, user-defined stability scenario (built e.g. with `create_test_scenario`).
pub struct TestScenario<'a, F: Float> {
    /// Identifier copied into the scenario's result.
    pub name: String,
    /// Free-form description.
    pub description: String,
    /// Function under test.
    pub function: TestFunction<F>,
    /// Input passed to `function`.
    pub input: Tensor<'a, F>,
    /// Grade compared against the measured forward-stability grade to decide pass/fail.
    pub expected_grade: StabilityGrade,
    /// Perturbation magnitude handed to `compute_forward_stability`.
    pub perturbation_magnitude: f64,
    /// Labels copied verbatim into the scenario's result.
    pub additional_checks: Vec<String>,
}
798
/// Behavior class of an edge-case test, used for both the expected and the observed outcome.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum EdgeCaseBehavior {
    /// Stable behavior.
    Stable,
    /// Presumably: instability is possible but not required (confirm intended meaning).
    MaybeUnstable,
    /// Presumably: instability is expected (confirm intended meaning).
    ExpectedUnstable,
    /// Presumably: the computation is expected to fail (confirm intended meaning).
    ShouldFail,
}
807
/// Accumulated outputs of a test run, grouped by kind.
#[derive(Debug)]
pub struct TestResults<'a, F: Float> {
    /// Basic stability test results; the only list counted in the summary totals.
    pub test_results: Vec<StabilityTestResult>,
    /// Condition-number analyses.
    pub conditioning_analyses: Vec<ConditionNumberAnalysis>,
    /// Error-propagation analyses.
    pub error_propagation_analyses: Vec<ErrorPropagationAnalysis<'a, F>>,
    /// Stability analyses.
    pub stability_analyses: Vec<super::numerical_analysis::StabilityAnalysis>,
    /// Round-off error analyses.
    pub roundoff_analyses: Vec<super::numerical_analysis::RoundoffErrorAnalysis>,
    /// Edge-case outcomes.
    pub edge_case_results: Vec<EdgeCaseTestResult>,
    /// Precision comparisons.
    pub precision_results: Vec<PrecisionTestResult>,
    /// Outcomes of registered scenarios.
    pub scenario_results: Vec<ScenarioTestResult>,
}
820
821impl<F: Float> Default for TestResults<'_, F> {
822 fn default() -> Self {
823 Self::new()
824 }
825}
826
827impl<F: Float> TestResults<'_, F> {
828 pub fn new() -> Self {
829 Self {
830 test_results: Vec::new(),
831 conditioning_analyses: Vec::new(),
832 error_propagation_analyses: Vec::new(),
833 stability_analyses: Vec::new(),
834 roundoff_analyses: Vec::new(),
835 edge_case_results: Vec::new(),
836 precision_results: Vec::new(),
837 scenario_results: Vec::new(),
838 }
839 }
840
841 pub fn clear(&mut self) {
842 self.test_results.clear();
843 self.conditioning_analyses.clear();
844 self.error_propagation_analyses.clear();
845 self.stability_analyses.clear();
846 self.roundoff_analyses.clear();
847 self.edge_case_results.clear();
848 self.precision_results.clear();
849 self.scenario_results.clear();
850 }
851
852 pub fn add_test_result(&mut self, name: String, result: StabilityTestResult) {
853 self.test_results.push(result);
854 }
855}
856
/// Outcome of one basic stability test.
#[derive(Debug, Clone)]
pub struct StabilityTestResult {
    /// Name of the test.
    pub test_name: String,
    /// Forward-stability metrics.
    pub forward_metrics: ForwardStabilityMetrics,
    /// Backward-stability metrics.
    pub backward_metrics: BackwardStabilityMetrics,
    /// Condition-number analysis.
    pub conditioning_analysis: ConditionNumberAnalysis,
    /// Overall stability flag.
    pub is_stable: bool,
    /// Grade the test was expected to reach.
    pub expected_grade: StabilityGrade,
    /// Grade actually reported (taken from `forward_metrics` by `run_single_stability_test`);
    /// this is what the summary's grade distribution counts.
    pub actual_grade: StabilityGrade,
    /// Whether the actual grade met the expectation.
    pub passed: bool,
    /// Time spent on the test.
    pub duration: Duration,
    /// Free-form notes.
    pub notes: Vec<String>,
}
871
/// Outcome of one edge-case test.
#[derive(Debug, Clone)]
pub struct EdgeCaseTestResult {
    /// Name of the edge case.
    pub case_name: String,
    /// Behavior that was observed.
    pub behavior_observed: EdgeCaseBehavior,
    /// Behavior that was expected.
    pub behavior_expected: EdgeCaseBehavior,
    /// Whether the edge case passed; any failure adds a recommendation to the summary.
    pub passed: bool,
    /// Warnings collected while running the case.
    pub warnings: Vec<String>,
}
881
/// Comparison of errors observed at single and at double precision.
#[derive(Debug, Clone)]
pub struct PrecisionTestResult {
    /// Errors observed in single precision.
    pub single_precision_errors: Vec<f64>,
    /// Errors observed in double precision.
    pub double_precision_errors: Vec<f64>,
    /// Presumably single- over double-precision error (the placeholder uses 1e-6 / 1e-15 = 1e9).
    pub precision_ratio: f64,
    /// Recommended precision, e.g. `"double"`.
    pub recommended_precision: String,
}
890
/// Outcome of one registered [`TestScenario`].
#[derive(Debug, Clone)]
pub struct ScenarioTestResult {
    /// The scenario's `name`.
    pub scenario_name: String,
    /// Metrics returned by `compute_forward_stability`.
    pub forward_metrics: ForwardStabilityMetrics,
    /// Whether the measured grade met the scenario's expected grade.
    pub passed: bool,
    /// Time spent computing the metrics.
    pub duration: Duration,
    /// The scenario's `additional_checks`, copied verbatim.
    pub additional_checks: Vec<String>,
}
900
/// Timing and size figures for one benchmark run.
#[derive(Debug, Clone)]
pub struct BenchmarkResult {
    /// Number of elements in the benchmarked tensor.
    pub tensor_size: usize,
    /// Wall-clock time of the analysis.
    pub analysis_duration: Duration,
    /// Estimated footprint in bytes (`tensor_size * size_of::<F>()` in `run_size_benchmark`).
    pub memory_usage: usize,
    /// `tensor_size` divided by `analysis_duration` in seconds.
    pub operations_per_second: u64,
}
909
/// Aggregate report returned by a test run.
#[derive(Debug, Clone)]
pub struct TestSummary {
    /// Number of basic stability test results.
    pub total_tests: usize,
    /// How many of those passed.
    pub passed_tests: usize,
    /// `total_tests - passed_tests`.
    pub failed_tests: usize,
    /// Wall-clock duration of the whole run.
    pub total_duration: Duration,
    /// Number of results per actual stability grade.
    pub stability_distribution: StabilityDistribution,
    /// Benchmark statistics (all-zero when no benchmarks ran).
    pub performance_summary: PerformanceSummary,
    /// Human-readable follow-up suggestions.
    pub recommendations: Vec<String>,
}
921
922impl TestSummary {
923 pub fn success_rate(&self) -> f64 {
924 if self.total_tests == 0 {
925 0.0
926 } else {
927 self.passed_tests as f64 / self.total_tests as f64 * 100.0
928 }
929 }
930
931 pub fn print_summary(&self) {
932 println!("\n==========================================");
933 println!(" STABILITY TEST SUITE SUMMARY");
934 println!("==========================================");
935 println!("Total Tests: {}", self.total_tests);
936 println!(
937 "Passed: {} ({:.1}%)",
938 self.passed_tests,
939 self.success_rate()
940 );
941 println!("Failed: {}", self.failed_tests);
942 println!("Duration: {:.2}s", self.total_duration.as_secs_f64());
943
944 println!("\nStability Grade Distribution:");
945 for (grade, count) in &self.stability_distribution {
946 println!(" {grade:?}: {count}");
947 }
948
949 if !self.performance_summary.average_analysis_duration.is_zero() {
950 println!("\nPerformance Summary:");
951 println!(
952 " Avg Analysis Duration: {:.3}s",
953 self.performance_summary
954 .average_analysis_duration
955 .as_secs_f64()
956 );
957 println!(
958 " Max Operations/sec: {}",
959 self.performance_summary.max_operations_per_second
960 );
961 println!(
962 " Memory Efficiency: {:.1}%",
963 self.performance_summary.memory_efficiency
964 );
965 }
966
967 println!("\nRecommendations:");
968 for recommendation in &self.recommendations {
969 println!(" • {recommendation}");
970 }
971 println!("==========================================\n");
972 }
973}
974
/// Benchmark statistics included in a [`TestSummary`]; the `Default` (all zero) is used
/// when no benchmarks ran.
#[derive(Debug, Clone, Default)]
pub struct PerformanceSummary {
    /// Mean `analysis_duration` over all benchmarks.
    pub average_analysis_duration: Duration,
    /// Highest `operations_per_second` among all benchmarks.
    pub max_operations_per_second: u64,
    /// Memory efficiency as a percentage (currently a fixed `85.0` when benchmarks exist).
    pub memory_efficiency: f64,
}
982
983#[allow(dead_code)]
986pub fn run_comprehensive_stability_tests<F: Float>() -> Result<TestSummary, StabilityError> {
987 use crate::VariableEnvironment;
988
989 VariableEnvironment::<F>::new().run(|graph| {
990 let mut suite = StabilityTestSuite::<'_, F>::new();
991 suite.run_all_tests_with_context(graph)
992 })
993}
994
995#[allow(dead_code)]
997pub fn run_stability_tests_with_config<F: Float>(
998 config: TestConfig,
999) -> Result<TestSummary, StabilityError> {
1000 use crate::VariableEnvironment;
1001
1002 VariableEnvironment::<F>::new().run(|graph| {
1003 let mut suite = StabilityTestSuite::<'_, F>::with_config(config);
1004 suite.run_all_tests_with_context(graph)
1005 })
1006}
1007
1008#[allow(dead_code)]
1010pub fn run_basic_stability_tests<F: Float>() -> Result<TestSummary, StabilityError> {
1011 let config = TestConfig {
1012 run_basic_tests: true,
1013 run_advanced_tests: false,
1014 run_edge_case_tests: false,
1015 run_precision_tests: false,
1016 run_benchmarks: false,
1017 run_scenario_tests: false,
1018 ..Default::default()
1019 };
1020 run_stability_tests_with_config::<F>(config)
1021}
1022
1023#[allow(dead_code)]
1025pub fn test_function_stability<'a, F: Float, Func>(
1026 function: Func,
1027 input: &'a Tensor<'a, F>,
1028 name: &str,
1029) -> Result<StabilityTestResult, StabilityError>
1030where
1031 Func: for<'b> Fn(&'b Tensor<'b, F>) -> Result<Tensor<'b, F>, StabilityError>
1032 + Send
1033 + Sync
1034 + 'static,
1035{
1036 let suite = StabilityTestSuite::<'a, F>::new();
1037 let test_case = BasicTestCase {
1038 function: Box::new(function),
1039 input: *input,
1040 expected_stability: StabilityGrade::Good,
1041 perturbation_magnitude: 1e-8,
1042 };
1043
1044 suite.run_single_stability_test(name, test_case)
1045}
1046
1047#[allow(dead_code)]
1049pub fn create_test_scenario<'a, F: Float, Func>(
1050 name: String,
1051 description: String,
1052 function: Func,
1053 input: Tensor<'a, F>,
1054 expected_grade: StabilityGrade,
1055) -> TestScenario<'a, F>
1056where
1057 Func: for<'b> Fn(&'b Tensor<'b, F>) -> Result<Tensor<'b, F>, StabilityError>
1058 + Send
1059 + Sync
1060 + 'static,
1061{
1062 TestScenario {
1063 name,
1064 description,
1065 function: Box::new(function),
1066 input,
1067 expected_grade,
1068 perturbation_magnitude: 1e-8,
1069 additional_checks: Vec::new(),
1070 }
1071}
1072
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_stability_test_suite_creation() {
        // Both constructors must work without a computation graph.
        let _plain = StabilityTestSuite::<f32>::new();
        let _configured = StabilityTestSuite::<f32>::with_config(TestConfig::default());
    }

    #[test]
    fn test_test_config() {
        let custom = TestConfig {
            tolerance_level: 1e-12,
            run_advanced_tests: true,
            run_basic_tests: false,
            ..TestConfig::default()
        };

        assert!(!custom.run_basic_tests);
        assert!(custom.run_advanced_tests);
        assert_eq!(custom.tolerance_level, 1e-12);
    }

    #[test]
    fn test_edge_case_behavior() {
        let stable = EdgeCaseBehavior::Stable;
        assert_eq!(stable, EdgeCaseBehavior::Stable);
        assert_ne!(stable, EdgeCaseBehavior::ExpectedUnstable);
    }

    #[test]
    fn test_test_results() {
        let mut results: TestResults<f64> = TestResults::new();
        assert!(results.test_results.is_empty());

        results.clear();
        assert!(results.conditioning_analyses.is_empty());
    }

    #[test]
    fn test_test_summary() {
        let recommendations = vec!["Test recommendation".to_string()];
        let summary = TestSummary {
            total_tests: 10,
            passed_tests: 8,
            failed_tests: 2,
            total_duration: Duration::from_secs(5),
            stability_distribution: HashMap::new(),
            performance_summary: PerformanceSummary::default(),
            recommendations,
        };

        assert_eq!(summary.success_rate(), 80.0);
        assert_eq!(summary.failed_tests, 2);
    }

    #[test]
    fn test_scenario_creation() {
        crate::VariableEnvironment::<f32>::new().run(|ctx| {
            let input = Tensor::from_vec(vec![1.0f32, 2.0, 3.0], vec![3], ctx);
            let scenario = create_test_scenario(
                "test_scenario".to_string(),
                "A test scenario".to_string(),
                |x: &Tensor<f32>| Ok(*x),
                input,
                StabilityGrade::Good,
            );

            assert_eq!(scenario.name, "test_scenario");
            assert_eq!(scenario.expected_grade, StabilityGrade::Good);
        });
    }

    #[test]
    fn test_benchmark_result() {
        let result = BenchmarkResult {
            tensor_size: 1000,
            analysis_duration: Duration::from_millis(50),
            memory_usage: 4000,
            operations_per_second: 20000,
        };

        assert_eq!(result.tensor_size, 1000);
        assert_eq!(result.operations_per_second, 20000);
    }

    #[test]
    fn test_precision_test_result() {
        let result = PrecisionTestResult {
            single_precision_errors: vec![1e-6, 2e-6],
            double_precision_errors: vec![1e-15, 2e-15],
            precision_ratio: 1e9,
            recommended_precision: "double".to_string(),
        };

        assert_eq!(result.precision_ratio, 1e9);
        assert_eq!(result.recommended_precision, "double");
    }
}