1use crate::error::StatsResult;
15use scirs2_core::ndarray::{Array1, ArrayView1};
16use serde::{Deserialize, Serialize};
17use std::collections::HashMap;
18
19#[derive(Debug)]
21pub struct NumericalStabilityAnalyzer {
22 config: StabilityConfig,
23 analysis_results: HashMap<String, StabilityAnalysisResult>,
24}
25
/// Tolerances and switches used by [`NumericalStabilityAnalyzer`].
#[derive(Clone, Debug)]
pub struct StabilityConfig {
    /// Magnitude at or below which a value is treated as zero in the
    /// conditioning and cancellation checks (default `1e-15`).
    pub zero_tolerance: f64,
    /// Result magnitude below which the precision check rates the loss as
    /// severe (default `1e-12`).
    pub precision_tolerance: f64,
    /// Largest condition number regarded as well conditioned (default `1e12`).
    pub max_condition_number: f64,
    /// Maximum number of leading input entries perturbed, one at a time, in the
    /// error-propagation check (default `100`).
    pub perturbation_tests: usize,
    /// Relative perturbation size; each entry `x` is shifted by
    /// `perturbation_magnitude * max(|x|, 1)` (default `1e-10`).
    pub perturbation_magnitude: f64,
    /// Whether the extreme-value probes (inf, NaN, zeros, huge and tiny values)
    /// are run (default `true`).
    pub test_extreme_values: bool,
    /// Flag reserved for singular-input testing (default `true`).
    /// NOTE(review): not read by the analyzer in this module — confirm intended use.
    pub test_singular_cases: bool,
}
44
45#[derive(Debug, Clone, Serialize, Deserialize)]
47pub struct StabilityAnalysisResult {
48 pub function_name: String,
50 pub stability_grade: StabilityGrade,
52 pub condition_analysis: ConditionNumberAnalysis,
54 pub error_propagation: ErrorPropagationAnalysis,
56 pub edge_case_robustness: EdgeCaseRobustness,
58 pub precision_analysis: PrecisionAnalysis,
60 pub recommendations: Vec<StabilityRecommendation>,
62 pub stability_score: f64,
64}
65
66#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
68pub enum StabilityGrade {
69 Excellent,
71 Good,
73 Acceptable,
75 Poor,
77 Unstable,
79}
80
81#[derive(Debug, Clone, Serialize, Deserialize)]
83pub struct ConditionNumberAnalysis {
84 pub condition_number: f64,
86 pub conditioning_class: ConditioningClass,
88 pub accuracy_loss_digits: f64,
90 pub input_sensitivity: f64,
92}
93
94#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
96pub enum ConditioningClass {
97 WellConditioned,
99 ModeratelyConditioned,
101 PoorlyConditioned,
103 NearlySingular,
105}
106
107#[derive(Debug, Clone, Serialize, Deserialize)]
109pub struct ErrorPropagationAnalysis {
110 pub forward_error_bound: f64,
112 pub backward_error_bound: f64,
114 pub error_amplification: f64,
116 pub rounding_error_stability: f64,
118}
119
120#[derive(Debug, Clone, Serialize, Deserialize)]
122pub struct EdgeCaseRobustness {
123 pub handles_infinity: bool,
125 pub handles_nan: bool,
127 pub handles_zero: bool,
129 pub handles_large_values: bool,
131 pub handles_small_values: bool,
133 pub edge_case_success_rate: f64,
135}
136
137#[derive(Debug, Clone, Serialize, Deserialize)]
139pub struct PrecisionAnalysis {
140 pub precision_loss_bits: f64,
142 pub relative_precision: f64,
144 pub cancellation_errors: Vec<CancellationError>,
146 pub overflow_underflow_risk: OverflowRisk,
148}
149
150#[derive(Debug, Clone, Serialize, Deserialize)]
152pub struct CancellationError {
153 pub location: String,
155 pub precision_loss: f64,
157 pub mitigation: String,
159}
160
161#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
163pub enum OverflowRisk {
164 None,
166 Low,
168 Moderate,
170 High,
172 Certain,
174}
175
176#[derive(Debug, Clone, Serialize, Deserialize)]
178pub struct StabilityRecommendation {
179 pub recommendation_type: RecommendationType,
181 pub description: String,
183 pub suggestion: String,
185 pub priority: RecommendationPriority,
187 pub expected_improvement: f64,
189}
190
191#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
193pub enum RecommendationType {
194 Algorithm,
196 Numerical,
198 InputValidation,
200 Precision,
202 ErrorHandling,
204}
205
206#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, PartialOrd, Ord)]
208pub enum RecommendationPriority {
209 Critical,
211 High,
213 Medium,
215 Low,
217}
218
219impl Default for StabilityConfig {
220 fn default() -> Self {
221 Self {
222 zero_tolerance: 1e-15,
223 precision_tolerance: 1e-12,
224 max_condition_number: 1e12,
225 perturbation_tests: 100,
226 perturbation_magnitude: 1e-10,
227 test_extreme_values: true,
228 test_singular_cases: true,
229 }
230 }
231}
232
233impl NumericalStabilityAnalyzer {
234 pub fn new(config: StabilityConfig) -> Self {
236 Self {
237 config,
238 analysis_results: HashMap::new(),
239 }
240 }
241
242 pub fn default() -> Self {
244 Self::new(StabilityConfig::default())
245 }
246
247 pub fn analyze_function<F>(
249 &mut self,
250 function_name: &str,
251 function: F,
252 testdata: &ArrayView1<f64>,
253 ) -> StatsResult<StabilityAnalysisResult>
254 where
255 F: Fn(&ArrayView1<f64>) -> StatsResult<f64>,
256 {
257 let condition_analysis = self.analyze_condition_number(testdata)?;
259
260 let error_propagation = self.analyze_error_propagation(&function, testdata)?;
262
263 let edge_case_robustness = self.test_edge_case_robustness(&function)?;
265
266 let precision_analysis = self.analyze_precision_loss(&function, testdata)?;
268
269 let recommendations = self.generate_recommendations(
271 &condition_analysis,
272 &error_propagation,
273 &edge_case_robustness,
274 &precision_analysis,
275 );
276
277 let stability_score = self.calculate_stability_score(
279 &condition_analysis,
280 &error_propagation,
281 &edge_case_robustness,
282 &precision_analysis,
283 );
284
285 let stability_grade = self.grade_stability(stability_score);
287
288 let result = StabilityAnalysisResult {
289 function_name: function_name.to_string(),
290 stability_grade,
291 condition_analysis,
292 error_propagation,
293 edge_case_robustness,
294 precision_analysis,
295 recommendations,
296 stability_score,
297 };
298
299 self.analysis_results
300 .insert(function_name.to_string(), result.clone());
301 Ok(result)
302 }
303
304 fn analyze_condition_number(
306 &self,
307 data: &ArrayView1<f64>,
308 ) -> StatsResult<ConditionNumberAnalysis> {
309 let data_range = data.iter().fold(0.0f64, |acc, &x| acc.max(x.abs()));
311 let data_min = data.iter().fold(f64::INFINITY, |acc, &x| acc.min(x.abs()));
312
313 let condition_number = if data_min > self.config.zero_tolerance {
315 data_range / data_min
316 } else {
317 f64::INFINITY
318 };
319
320 let conditioning_class = if condition_number < 1e12 {
321 ConditioningClass::WellConditioned
322 } else if condition_number < 1e14 {
323 ConditioningClass::ModeratelyConditioned
324 } else if condition_number < 1e16 {
325 ConditioningClass::PoorlyConditioned
326 } else {
327 ConditioningClass::NearlySingular
328 };
329
330 let accuracy_loss_digits = condition_number.log10().max(0.0);
332
333 let input_sensitivity = condition_number / 1e16;
335
336 Ok(ConditionNumberAnalysis {
337 condition_number,
338 conditioning_class,
339 accuracy_loss_digits,
340 input_sensitivity,
341 })
342 }
343
344 fn analyze_error_propagation<F>(
346 &self,
347 function: &F,
348 data: &ArrayView1<f64>,
349 ) -> StatsResult<ErrorPropagationAnalysis>
350 where
351 F: Fn(&ArrayView1<f64>) -> StatsResult<f64>,
352 {
353 let reference_result = function(data)?;
355
356 let mut forward_errors = Vec::new();
357 let mut backward_errors = Vec::new();
358
359 for i in 0..self.config.perturbation_tests.min(data.len()) {
361 let mut perturbeddata = data.to_owned();
362 let perturbation = self.config.perturbation_magnitude * perturbeddata[i].abs().max(1.0);
363 perturbeddata[i] += perturbation;
364
365 if let Ok(perturbed_result) = function(&perturbeddata.view()) {
366 let forward_error = (perturbed_result - reference_result).abs();
367 let backward_error = perturbation.abs();
368
369 forward_errors.push(forward_error);
370 backward_errors.push(backward_error);
371 }
372 }
373
374 let forward_error_bound = forward_errors.iter().fold(0.0f64, |acc, &x| acc.max(x));
375 let backward_error_bound = backward_errors.iter().fold(0.0f64, |acc, &x| acc.max(x));
376
377 let error_amplification = if backward_error_bound > 0.0 {
379 forward_error_bound / backward_error_bound
380 } else {
381 1.0
382 };
383
384 let rounding_error_stability = 1.0 / (1.0 + error_amplification);
386
387 Ok(ErrorPropagationAnalysis {
388 forward_error_bound,
389 backward_error_bound,
390 error_amplification,
391 rounding_error_stability,
392 })
393 }
394
395 fn test_edge_case_robustness<F>(&self, function: &F) -> StatsResult<EdgeCaseRobustness>
397 where
398 F: Fn(&ArrayView1<f64>) -> StatsResult<f64>,
399 {
400 let mut tests_passed = 0;
401 let mut total_tests = 0;
402
403 let mut handles_infinity = false;
404 let mut handles_nan = false;
405 let mut handles_zero = false;
406 let mut handles_large_values = false;
407 let mut handles_small_values = false;
408
409 if self.config.test_extreme_values {
411 total_tests += 1;
412 let infdata = Array1::from_vec(vec![f64::INFINITY, 1.0, 2.0]);
413 if let Ok(result) = function(&infdata.view()) {
414 if result.is_finite() || result.is_infinite() {
415 handles_infinity = true;
416 tests_passed += 1;
417 }
418 }
419
420 total_tests += 1;
422 let nandata = Array1::from_vec(vec![f64::NAN, 1.0, 2.0]);
423 if let Ok(result) = function(&nandata.view()) {
424 if result.is_nan() || result.is_finite() {
425 handles_nan = true;
426 tests_passed += 1;
427 }
428 }
429
430 total_tests += 1;
432 let zerodata = Array1::from_vec(vec![0.0, 0.0, 0.0]);
433 if function(&zerodata.view()).is_ok() {
434 handles_zero = true;
435 tests_passed += 1;
436 }
437
438 total_tests += 1;
440 let largedata = Array1::from_vec(vec![1e100, 1e200, 1e300]);
441 if function(&largedata.view()).is_ok() {
442 handles_large_values = true;
443 tests_passed += 1;
444 }
445
446 total_tests += 1;
448 let smalldata = Array1::from_vec(vec![1e-100, 1e-200, 1e-300]);
449 if function(&smalldata.view()).is_ok() {
450 handles_small_values = true;
451 tests_passed += 1;
452 }
453 }
454
455 let edge_case_success_rate = if total_tests > 0 {
456 tests_passed as f64 / total_tests as f64
457 } else {
458 1.0
459 };
460
461 Ok(EdgeCaseRobustness {
462 handles_infinity,
463 handles_nan,
464 handles_zero,
465 handles_large_values,
466 handles_small_values,
467 edge_case_success_rate,
468 })
469 }
470
471 fn analyze_precision_loss<F>(
473 &self,
474 function: &F,
475 data: &ArrayView1<f64>,
476 ) -> StatsResult<PrecisionAnalysis>
477 where
478 F: Fn(&ArrayView1<f64>) -> StatsResult<f64>,
479 {
480 let result = function(data)?;
482
483 let precision_loss_bits = if result.abs() < self.config.precision_tolerance {
485 16.0 } else if result.abs() < 1e-10 {
487 8.0 } else {
489 2.0 };
491
492 let relative_precision = 1.0 - (precision_loss_bits / 64.0);
493
494 let mut cancellation_errors = Vec::new();
496 if data.iter().any(|&x| x.abs() < self.config.zero_tolerance) {
497 cancellation_errors.push(CancellationError {
498 location: "inputdata".to_string(),
499 precision_loss: precision_loss_bits,
500 mitigation: "Use higher precision arithmetic or alternative algorithm".to_string(),
501 });
502 }
503
504 let max_val = data.iter().fold(0.0f64, |acc, &x| acc.max(x.abs()));
506 let overflow_underflow_risk = if max_val > 1e100 {
507 OverflowRisk::High
508 } else if max_val > 1e50 {
509 OverflowRisk::Moderate
510 } else if max_val < 1e-100 {
511 OverflowRisk::Moderate
512 } else {
513 OverflowRisk::Low
514 };
515
516 Ok(PrecisionAnalysis {
517 precision_loss_bits,
518 relative_precision,
519 cancellation_errors,
520 overflow_underflow_risk,
521 })
522 }
523
524 fn generate_recommendations(
526 &self,
527 condition_analysis: &ConditionNumberAnalysis,
528 error_propagation: &ErrorPropagationAnalysis,
529 edge_case_robustness: &EdgeCaseRobustness,
530 precision_analysis: &PrecisionAnalysis,
531 ) -> Vec<StabilityRecommendation> {
532 let mut recommendations = Vec::new();
533
534 if matches!(
536 condition_analysis.conditioning_class,
537 ConditioningClass::PoorlyConditioned | ConditioningClass::NearlySingular
538 ) {
539 recommendations.push(StabilityRecommendation {
540 recommendation_type: RecommendationType::Algorithm,
541 description: "Poor conditioning detected".to_string(),
542 suggestion: "Consider using regularization or alternative algorithms for ill-conditioned problems".to_string(),
543 priority: RecommendationPriority::High,
544 expected_improvement: 30.0,
545 });
546 }
547
548 if error_propagation.error_amplification > 100.0 {
550 recommendations.push(StabilityRecommendation {
551 recommendation_type: RecommendationType::Numerical,
552 description: "High error amplification detected".to_string(),
553 suggestion: "Implement error _analysis and use more stable numerical methods"
554 .to_string(),
555 priority: RecommendationPriority::High,
556 expected_improvement: 25.0,
557 });
558 }
559
560 if edge_case_robustness.edge_case_success_rate < 0.8 {
562 recommendations.push(StabilityRecommendation {
563 recommendation_type: RecommendationType::InputValidation,
564 description: "Poor edge case handling".to_string(),
565 suggestion:
566 "Improve input validation and add special case handling for extreme values"
567 .to_string(),
568 priority: RecommendationPriority::Medium,
569 expected_improvement: 20.0,
570 });
571 }
572
573 if precision_analysis.precision_loss_bits > 10.0 {
575 recommendations.push(StabilityRecommendation {
576 recommendation_type: RecommendationType::Precision,
577 description: "Significant precision loss detected".to_string(),
578 suggestion:
579 "Consider using higher precision arithmetic or numerically stable algorithms"
580 .to_string(),
581 priority: RecommendationPriority::High,
582 expected_improvement: 35.0,
583 });
584 }
585
586 recommendations
587 }
588
589 fn calculate_stability_score(
591 &self,
592 condition_analysis: &ConditionNumberAnalysis,
593 error_propagation: &ErrorPropagationAnalysis,
594 edge_case_robustness: &EdgeCaseRobustness,
595 precision_analysis: &PrecisionAnalysis,
596 ) -> f64 {
597 let mut score = 100.0;
598
599 score -= match condition_analysis.conditioning_class {
601 ConditioningClass::WellConditioned => 0.0,
602 ConditioningClass::ModeratelyConditioned => 10.0,
603 ConditioningClass::PoorlyConditioned => 25.0,
604 ConditioningClass::NearlySingular => 40.0,
605 };
606
607 score -= (error_propagation.error_amplification.log10() * 5.0).min(30.0);
609
610 score -= (1.0 - edge_case_robustness.edge_case_success_rate) * 20.0;
612
613 score -= (precision_analysis.precision_loss_bits / 64.0) * 30.0;
615
616 score.max(0.0)
617 }
618
619 fn grade_stability(&self, score: f64) -> StabilityGrade {
621 if score >= 90.0 {
622 StabilityGrade::Excellent
623 } else if score >= 75.0 {
624 StabilityGrade::Good
625 } else if score >= 60.0 {
626 StabilityGrade::Acceptable
627 } else if score >= 40.0 {
628 StabilityGrade::Poor
629 } else {
630 StabilityGrade::Unstable
631 }
632 }
633
634 pub fn generate_stability_report(&self) -> StabilityReport {
636 let results: Vec<_> = self.analysis_results.values().cloned().collect();
637
638 let total_functions = results.len();
639 let excellent_count = results
640 .iter()
641 .filter(|r| r.stability_grade == StabilityGrade::Excellent)
642 .count();
643 let good_count = results
644 .iter()
645 .filter(|r| r.stability_grade == StabilityGrade::Good)
646 .count();
647 let acceptable_count = results
648 .iter()
649 .filter(|r| r.stability_grade == StabilityGrade::Acceptable)
650 .count();
651 let poor_count = results
652 .iter()
653 .filter(|r| r.stability_grade == StabilityGrade::Poor)
654 .count();
655 let unstable_count = results
656 .iter()
657 .filter(|r| r.stability_grade == StabilityGrade::Unstable)
658 .count();
659
660 let average_score = if total_functions > 0 {
661 results.iter().map(|r| r.stability_score).sum::<f64>() / total_functions as f64
662 } else {
663 0.0
664 };
665
666 StabilityReport {
667 total_functions,
668 excellent_count,
669 good_count,
670 acceptable_count,
671 poor_count,
672 unstable_count,
673 average_score,
674 function_results: results,
675 generated_at: chrono::Utc::now(),
676 }
677 }
678}
679
680#[derive(Debug, Clone, Serialize, Deserialize)]
682pub struct StabilityReport {
683 pub total_functions: usize,
685 pub excellent_count: usize,
687 pub good_count: usize,
689 pub acceptable_count: usize,
691 pub poor_count: usize,
693 pub unstable_count: usize,
695 pub average_score: f64,
697 pub function_results: Vec<StabilityAnalysisResult>,
699 pub generated_at: chrono::DateTime<chrono::Utc>,
701}
702
#[cfg(test)]
mod tests {
    use super::*;
    use crate::descriptive::mean;

    /// The default analyzer carries the default tolerances.
    #[test]
    fn test_stability_analyzer_creation() {
        let cfg = NumericalStabilityAnalyzer::default().config;
        assert_eq!(cfg.zero_tolerance, 1e-15);
        assert_eq!(cfg.precision_tolerance, 1e-12);
    }

    /// Small positive integers have a tiny magnitude spread.
    #[test]
    fn test_condition_number_analysis() {
        let samples = Array1::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0]);
        let outcome = NumericalStabilityAnalyzer::default()
            .analyze_condition_number(&samples.view())
            .expect("Operation failed");

        assert!(outcome.condition_number > 0.0);
        assert_eq!(
            outcome.conditioning_class,
            ConditioningClass::WellConditioned
        );
    }

    /// One representative score per grade band.
    #[test]
    fn test_stability_grading() {
        let analyzer = NumericalStabilityAnalyzer::default();
        let cases = [
            (95.0, StabilityGrade::Excellent),
            (80.0, StabilityGrade::Good),
            (65.0, StabilityGrade::Acceptable),
            (45.0, StabilityGrade::Poor),
            (20.0, StabilityGrade::Unstable),
        ];
        for &(score, expected) in cases.iter() {
            assert_eq!(analyzer.grade_stability(score), expected);
        }
    }

    /// The arithmetic mean of well-scaled data should analyze as stable.
    #[test]
    fn test_mean_stability_analysis() {
        let samples = Array1::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0]);
        let mut analyzer = NumericalStabilityAnalyzer::default();

        let outcome = analyzer
            .analyze_function("mean", |x| mean(x), &samples.view())
            .expect("Operation failed");

        assert_eq!(outcome.function_name, "mean");
        assert!(outcome.stability_score > 50.0);
        assert!(matches!(
            outcome.stability_grade,
            StabilityGrade::Excellent | StabilityGrade::Good
        ));
    }
}