1use crate::graph::Graph;
7use crate::tensor::Tensor;
8use crate::Float;
9use scirs2_core::ScientificNumber;
10use std::collections::HashMap;
11use std::fmt;
12
13pub mod finite_differences;
14pub mod gradient_checking;
15pub mod numerical_analysis;
16pub mod stability_metrics;
17pub mod stability_test_framework;
18
/// Tunable thresholds and switches for a numerical stability test run.
#[derive(Clone, Debug)]
pub struct StabilityTestConfig {
    /// A gradient test passes when its error is strictly below this value.
    pub gradient_tolerance: f64,
    /// Tolerance for finite-difference comparisons (not read in this file).
    pub finite_diff_tolerance: f64,
    /// Step size for finite differences (not read in this file).
    pub finite_diff_step: f64,
    /// Number of gradient test iterations to run.
    pub num_test_points: usize,
    /// Switch for second-order checks (not read in this file).
    pub check_second_order: bool,
    /// Operations whose estimated condition number exceeds this value are
    /// reported as ill-conditioned.
    pub max_condition_number: f64,
    /// Switch for comprehensive analysis (not read in this file).
    pub comprehensive_analysis: bool,
}
37
38impl Default for StabilityTestConfig {
39 fn default() -> Self {
40 Self {
41 gradient_tolerance: 1e-5,
42 finite_diff_tolerance: 1e-6,
43 finite_diff_step: 1e-8,
44 num_test_points: 100,
45 check_second_order: false,
46 max_condition_number: 1e12,
47 comprehensive_analysis: true,
48 }
49 }
50}
51
52pub struct NumericalStabilityTester<F: Float> {
54 config: StabilityTestConfig,
55 phantom: std::marker::PhantomData<F>,
56}
57
58impl<F: Float> NumericalStabilityTester<F> {
59 pub fn new() -> Self {
61 Self {
62 config: StabilityTestConfig::default(),
63 phantom: std::marker::PhantomData,
64 }
65 }
66
67 pub fn with_config(config: StabilityTestConfig) -> Self {
69 Self {
70 config,
71 phantom: std::marker::PhantomData,
72 }
73 }
74
75 pub fn test_graph(&self, graph: &Graph<F>) -> Result<StabilityReport<F>, StabilityError> {
77 let mut report = StabilityReport::new();
78
79 let gradient_tests = self.test_gradient_accuracy(graph)?;
81 report.gradient_tests = gradient_tests;
82
83 let conditioning_tests = self.test_numerical_conditioning(graph)?;
85 report.conditioning_tests = conditioning_tests;
86
87 let perturbation_tests = self.test_perturbation_stability(graph)?;
89 report.perturbation_tests = perturbation_tests;
90
91 let overflow_tests = self.test_overflow_underflow(graph)?;
93 report.overflow_tests = overflow_tests;
94
95 report.overall_grade = self.compute_overall_grade(&report);
97
98 Ok(report)
99 }
100
101 fn test_gradient_accuracy(
103 &self,
104 selfgraph: &Graph<F>,
105 ) -> Result<GradientTestResults, StabilityError> {
106 let mut results = GradientTestResults {
107 tests_performed: 0,
108 tests_passed: 0,
109 max_error: 0.0,
110 mean_error: 0.0,
111 failed_tests: Vec::new(),
112 };
113
114 for _test_point in 0..self.config.num_test_points {
120 results.tests_performed += 1;
121
122 let analytical_grad = self.compute_analytical_gradient()?;
124 let finite_diff_grad = self.compute_finite_difference_gradient()?;
125
126 let error = self.compute_gradient_error(&analytical_grad, &finite_diff_grad);
127
128 if error < self.config.gradient_tolerance {
129 results.tests_passed += 1;
130 } else {
131 results.failed_tests.push(GradientTestFailure {
132 test_id: results.tests_performed,
133 error,
134 analytical_gradient: analytical_grad,
135 finite_diff_gradient: finite_diff_grad,
136 });
137 }
138
139 results.max_error = results.max_error.max(error);
140 results.mean_error += error;
141 }
142
143 if results.tests_performed > 0 {
144 results.mean_error /= results.tests_performed as f64;
145 }
146
147 Ok(results)
148 }
149
150 fn test_numerical_conditioning(
152 &self,
153 selfgraph: &Graph<F>,
154 ) -> Result<ConditioningTestResults, StabilityError> {
155 let mut results = ConditioningTestResults {
156 condition_numbers: HashMap::new(),
157 ill_conditioned_operations: Vec::new(),
158 stability_warnings: Vec::new(),
159 };
160
161 let operations_to_check = vec![
168 "matrix_inverse",
169 "solve_linear_system",
170 "eigenvalue_decomposition",
171 "singular_value_decomposition",
172 "division_operations",
173 ];
174
175 for op_name in operations_to_check {
176 let condition_number = self.estimate_condition_number(op_name)?;
177 results
178 .condition_numbers
179 .insert(op_name.to_string(), condition_number);
180
181 if condition_number > self.config.max_condition_number {
182 results
183 .ill_conditioned_operations
184 .push(IllConditionedOperation {
185 operation: op_name.to_string(),
186 condition_number,
187 severity: if condition_number > 1e15 {
188 ConditioningSeverity::Critical
189 } else if condition_number > 1e12 {
190 ConditioningSeverity::High
191 } else {
192 ConditioningSeverity::Medium
193 },
194 });
195 }
196 }
197
198 Ok(results)
199 }
200
201 fn test_perturbation_stability(
203 &self,
204 selfgraph: &Graph<F>,
205 ) -> Result<PerturbationTestResults, StabilityError> {
206 let mut results = PerturbationTestResults {
207 perturbation_tests: Vec::new(),
208 max_sensitivity: 0.0,
209 mean_sensitivity: 0.0,
210 };
211
212 for perturbation_magnitude in [
214 F::from(1e-8).expect("Failed to convert constant to float"),
215 F::from(1e-6).expect("Failed to convert constant to float"),
216 F::from(1e-4).expect("Failed to convert constant to float"),
217 F::from(1e-2).expect("Failed to convert constant to float"),
218 ] {
219 let sensitivity = self
220 .measure_perturbation_sensitivity(perturbation_magnitude.to_f64().unwrap_or(0.0))?;
221
222 results.perturbation_tests.push(PerturbationTest {
223 perturbation_magnitude: perturbation_magnitude.to_f64().unwrap_or(0.0),
224 output_change: sensitivity,
225 sensitivity_ratio: sensitivity / perturbation_magnitude.to_f64().unwrap_or(1.0),
226 });
227
228 results.max_sensitivity = results.max_sensitivity.max(sensitivity);
229 results.mean_sensitivity += sensitivity;
230 }
231
232 if !results.perturbation_tests.is_empty() {
233 results.mean_sensitivity /= results.perturbation_tests.len() as f64;
234 }
235
236 Ok(results)
237 }
238
239 fn test_overflow_underflow(
241 &self,
242 selfgraph: &Graph<F>,
243 ) -> Result<OverflowTestResults<F>, StabilityError> {
244 let mut results = OverflowTestResults {
245 overflow_risks: Vec::new(),
246 underflow_risks: Vec::new(),
247 safe_ranges: HashMap::new(),
248 };
249
250 let extreme_values = vec![
252 F::from(1e-100).expect("Failed to convert constant to float"), F::from(1e-10).expect("Failed to convert constant to float"), F::from(1e10).expect("Failed to convert constant to float"), F::from(1e100).expect("Failed to convert constant to float"), ];
257
258 for &extreme_value in &extreme_values {
259 let risk_assessment = self.assess_overflow_risk(extreme_value)?;
260
261 if risk_assessment.overflow_probability > 0.1 {
262 results.overflow_risks.push(OverflowRisk {
263 input_value: extreme_value,
264 operation: risk_assessment.risky_operation.clone(),
265 probability: risk_assessment.overflow_probability,
266 });
267 }
268
269 if risk_assessment.underflow_probability > 0.1 {
270 results.underflow_risks.push(UnderflowRisk {
271 input_value: extreme_value,
272 operation: risk_assessment.risky_operation,
273 probability: risk_assessment.underflow_probability,
274 });
275 }
276 }
277
278 Ok(results)
279 }
280
281 fn compute_analytical_gradient(&self) -> Result<Vec<f64>, StabilityError> {
283 Ok(vec![1.0, 2.0, 3.0])
285 }
286
287 fn compute_finite_difference_gradient(&self) -> Result<Vec<f64>, StabilityError> {
288 Ok(vec![1.0001, 1.9999, 3.0001])
290 }
291
292 fn compute_gradient_error(&self, analytical: &[f64], finitediff: &[f64]) -> f64 {
293 analytical
294 .iter()
295 .zip(finitediff.iter())
296 .map(|(&a, &f)| (a - f).abs())
297 .fold(0.0, f64::max)
298 }
299
300 fn estimate_condition_number(&self, operation: &str) -> Result<f64, StabilityError> {
301 Ok(1e6)
303 }
304
305 fn measure_perturbation_sensitivity(&self, perturbation: f64) -> Result<f64, StabilityError> {
306 Ok(perturbation * 1.5) }
309
310 fn assess_overflow_risk(&self, input: F) -> Result<OverflowRiskAssessment, StabilityError> {
311 Ok(OverflowRiskAssessment {
312 risky_operation: "exponential".to_string(),
313 overflow_probability: 0.05,
314 underflow_probability: 0.02,
315 })
316 }
317
318 fn compute_overall_grade(&self, report: &StabilityReport<F>) -> StabilityGrade {
319 let mut score = 100.0;
320
321 if report.gradient_tests.tests_performed > 0 {
323 let pass_rate = report.gradient_tests.tests_passed as f64
324 / report.gradient_tests.tests_performed as f64;
325 score *= pass_rate;
326 }
327
328 let conditioning_penalty =
330 report.conditioning_tests.ill_conditioned_operations.len() as f64 * 10.0;
331 score -= conditioning_penalty;
332
333 let overflow_penalty = (report.overflow_tests.overflow_risks.len()
335 + report.overflow_tests.underflow_risks.len()) as f64
336 * 5.0;
337 score -= overflow_penalty;
338
339 match score as i32 {
340 90..=100 => StabilityGrade::Excellent,
341 80..=89 => StabilityGrade::Good,
342 70..=79 => StabilityGrade::Fair,
343 60..=69 => StabilityGrade::Poor,
344 _ => StabilityGrade::Critical,
345 }
346 }
347}
348
349impl<F: Float> Default for NumericalStabilityTester<F> {
350 fn default() -> Self {
351 Self::new()
352 }
353}
354
355#[derive(Debug, Clone)]
357pub struct StabilityReport<F: Float> {
358 pub gradient_tests: GradientTestResults,
359 pub conditioning_tests: ConditioningTestResults,
360 pub perturbation_tests: PerturbationTestResults,
361 pub overflow_tests: OverflowTestResults<F>,
362 pub overall_grade: StabilityGrade,
363}
364
365impl<F: Float> Default for StabilityReport<F> {
366 fn default() -> Self {
367 Self::new()
368 }
369}
370
371impl<F: Float> StabilityReport<F> {
372 pub fn new() -> Self {
373 Self {
374 gradient_tests: GradientTestResults::default(),
375 conditioning_tests: ConditioningTestResults::default(),
376 perturbation_tests: PerturbationTestResults::default(),
377 overflow_tests: OverflowTestResults::default(),
378 overall_grade: StabilityGrade::Unknown,
379 }
380 }
381
382 pub fn print_report(&self) {
384 println!("Numerical Stability Report");
385 println!("==========================");
386 println!("Overall Grade: {:?}", self.overall_grade);
387 println!();
388
389 println!("Gradient Tests:");
390 println!(" Tests Performed: {}", self.gradient_tests.tests_performed);
391 println!(" Tests Passed: {}", self.gradient_tests.tests_passed);
392 println!(
393 " Pass Rate: {:.2}%",
394 if self.gradient_tests.tests_performed > 0 {
395 (self.gradient_tests.tests_passed as f64
396 / self.gradient_tests.tests_performed as f64)
397 * 100.0
398 } else {
399 0.0
400 }
401 );
402 println!(" Max Error: {:.2e}", self.gradient_tests.max_error);
403 println!(" Mean Error: {:.2e}", self.gradient_tests.mean_error);
404 println!();
405
406 println!("Conditioning Tests:");
407 println!(
408 " Ill-conditioned Operations: {}",
409 self.conditioning_tests.ill_conditioned_operations.len()
410 );
411 for op in &self.conditioning_tests.ill_conditioned_operations {
412 println!(
413 " {} (cond: {:.2e}, severity: {:?})",
414 op.operation, op.condition_number, op.severity
415 );
416 }
417 println!();
418
419 println!("Perturbation Tests:");
420 println!(
421 " Max Sensitivity: {:.2e}",
422 self.perturbation_tests.max_sensitivity
423 );
424 println!(
425 " Mean Sensitivity: {:.2e}",
426 self.perturbation_tests.mean_sensitivity
427 );
428 println!();
429
430 println!("Overflow/Underflow Tests:");
431 println!(
432 " Overflow Risks: {}",
433 self.overflow_tests.overflow_risks.len()
434 );
435 println!(
436 " Underflow Risks: {}",
437 self.overflow_tests.underflow_risks.len()
438 );
439 }
440}
441
442#[derive(Debug, Clone, Default)]
444pub struct GradientTestResults {
445 pub tests_performed: usize,
446 pub tests_passed: usize,
447 pub max_error: f64,
448 pub mean_error: f64,
449 pub failed_tests: Vec<GradientTestFailure>,
450}
451
/// Details of a single failed gradient comparison.
#[derive(Clone, Debug)]
pub struct GradientTestFailure {
    /// Identifier of the failed test.
    pub test_id: usize,
    /// Error measured between the two gradients.
    pub error: f64,
    /// Gradient obtained analytically.
    pub analytical_gradient: Vec<f64>,
    /// Gradient obtained by finite differences.
    pub finite_diff_gradient: Vec<f64>,
}
460
461#[derive(Debug, Clone, Default)]
463pub struct ConditioningTestResults {
464 pub condition_numbers: HashMap<String, f64>,
465 pub ill_conditioned_operations: Vec<IllConditionedOperation>,
466 pub stability_warnings: Vec<String>,
467}
468
469#[derive(Debug, Clone)]
471pub struct IllConditionedOperation {
472 pub operation: String,
473 pub condition_number: f64,
474 pub severity: ConditioningSeverity,
475}
476
/// Severity of a conditioning problem.
///
/// Variants are declared from least to most severe, so the derived ordering
/// satisfies `Low < Medium < High < Critical`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum ConditioningSeverity {
    /// Least severe level.
    Low,
    /// Moderate level.
    Medium,
    /// Serious level.
    High,
    /// Most severe level.
    Critical,
}
485
486#[derive(Debug, Clone, Default)]
488pub struct PerturbationTestResults {
489 pub perturbation_tests: Vec<PerturbationTest>,
490 pub max_sensitivity: f64,
491 pub mean_sensitivity: f64,
492}
493
/// Measurement for a single perturbation magnitude.
#[derive(Clone, Debug)]
pub struct PerturbationTest {
    /// Size of the applied perturbation.
    pub perturbation_magnitude: f64,
    /// Output change measured for that perturbation.
    pub output_change: f64,
    /// Output change relative to the perturbation magnitude.
    pub sensitivity_ratio: f64,
}
501
502#[derive(Debug, Clone)]
504pub struct OverflowTestResults<F: Float> {
505 pub overflow_risks: Vec<OverflowRisk<F>>,
506 pub underflow_risks: Vec<UnderflowRisk<F>>,
507 pub safe_ranges: HashMap<String, (f64, f64)>,
508}
509
510impl<F: Float> Default for OverflowTestResults<F> {
511 fn default() -> Self {
512 Self {
513 overflow_risks: Vec::new(),
514 underflow_risks: Vec::new(),
515 safe_ranges: HashMap::new(),
516 }
517 }
518}
519
520#[derive(Debug, Clone)]
522pub struct OverflowRisk<F: Float> {
523 pub input_value: F,
524 pub operation: String,
525 pub probability: f64,
526}
527
528#[derive(Debug, Clone)]
530pub struct UnderflowRisk<F: Float> {
531 pub input_value: F,
532 pub operation: String,
533 pub probability: f64,
534}
535
/// Overflow and underflow risk assessment for a single input value.
#[derive(Clone, Debug)]
pub struct OverflowRiskAssessment {
    /// Name of the operation considered risky.
    pub risky_operation: String,
    /// Probability of overflow.
    pub overflow_probability: f64,
    /// Probability of underflow.
    pub underflow_probability: f64,
}
543
/// Overall stability grade assigned to a report.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum StabilityGrade {
    /// Best grade.
    Excellent,
    /// Second-best grade.
    Good,
    /// Middle grade.
    Fair,
    /// Second-worst grade.
    Poor,
    /// Worst grade.
    Critical,
    /// No grade has been computed.
    Unknown,
}
554
555impl fmt::Display for StabilityGrade {
556 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
557 match self {
558 StabilityGrade::Excellent => write!(f, "Excellent (A+)"),
559 StabilityGrade::Good => write!(f, "Good (A)"),
560 StabilityGrade::Fair => write!(f, "Fair (B)"),
561 StabilityGrade::Poor => write!(f, "Poor (C)"),
562 StabilityGrade::Critical => write!(f, "Critical (F)"),
563 StabilityGrade::Unknown => write!(f, "Unknown"),
564 }
565 }
566}
567
568#[derive(Debug, thiserror::Error)]
570pub enum StabilityError {
571 #[error("Computation error: {0}")]
572 ComputationError(String),
573 #[error("Gradient computation failed: {0}")]
574 GradientError(String),
575 #[error("Numerical error: {0}")]
576 NumericalError(String),
577 #[error("Configuration error: {0}")]
578 ConfigError(String),
579}
580
581#[allow(dead_code)]
584pub fn test_numerical_stability<F: Float>(
585 graph: &Graph<F>,
586) -> Result<StabilityReport<F>, StabilityError> {
587 let tester = NumericalStabilityTester::new();
588 tester.test_graph(graph)
589}
590
591#[allow(dead_code)]
593pub fn test_numerical_stability_with_config<F: Float>(
594 graph: &Graph<F>,
595 config: StabilityTestConfig,
596) -> Result<StabilityReport<F>, StabilityError> {
597 let tester = NumericalStabilityTester::with_config(config);
598 tester.test_graph(graph)
599}
600
601#[allow(dead_code)]
603pub fn quick_gradient_check<F: Float>(
604 _inputs: &[Tensor<F>],
605 _output: &Tensor<F>,
606) -> Result<bool, StabilityError> {
607 Ok(true)
609}
610
611#[allow(dead_code)]
613pub fn assess_conditioning<F: Float>(
614 _operation_name: &str,
615 _inputs: &[Tensor<F>],
616) -> Result<f64, StabilityError> {
617 Ok(1e6)
619}
620
#[cfg(test)]
mod tests {
    use super::*;

    /// The default constructor must be usable with `f32`.
    #[test]
    fn test_stability_tester_creation() {
        let _tester: NumericalStabilityTester<f32> = NumericalStabilityTester::new();
    }

    /// Struct-update syntax overrides only the named fields, and the
    /// configuration remains usable after a clone is handed to the tester.
    #[test]
    fn test_stability_config() {
        let config = StabilityTestConfig {
            gradient_tolerance: 1e-6,
            num_test_points: 50,
            ..StabilityTestConfig::default()
        };

        let _tester: NumericalStabilityTester<f32> =
            NumericalStabilityTester::with_config(config.clone());
        assert_eq!(config.gradient_tolerance, 1e-6);
        assert_eq!(config.num_test_points, 50);
    }

    /// A fresh report has no grade yet.
    #[test]
    fn test_stability_report() {
        let report = StabilityReport::<f64>::new();
        assert_eq!(report.overall_grade, StabilityGrade::Unknown);
    }

    /// `Display` output includes the letter grade.
    #[test]
    fn test_stability_grade_display() {
        assert_eq!(StabilityGrade::Excellent.to_string(), "Excellent (A+)");
        assert_eq!(StabilityGrade::Poor.to_string(), "Poor (C)");
        assert_eq!(StabilityGrade::Critical.to_string(), "Critical (F)");
    }

    /// An ill-conditioned operation keeps its severity and condition number.
    #[test]
    fn test_conditioning_severity() {
        let operation = IllConditionedOperation {
            operation: "matrix_inverse".to_string(),
            condition_number: 1e15,
            severity: ConditioningSeverity::Critical,
        };

        assert_eq!(operation.severity, ConditioningSeverity::Critical);
        assert!(operation.condition_number > 1e14);
    }

    /// The stored ratio matches output change divided by magnitude.
    #[test]
    fn test_perturbation_test() {
        let test = PerturbationTest {
            perturbation_magnitude: 1e-8,
            output_change: 1.5e-8,
            sensitivity_ratio: 1.5,
        };

        let calculated_ratio = test.output_change / test.perturbation_magnitude;
        let difference = (test.sensitivity_ratio - calculated_ratio).abs();
        assert!(difference < 1e-14);
    }
}