sklears_kernel_approximation/
numerical_stability.rs

1//! Numerical Stability Enhancements for Kernel Approximation Methods
2//!
3//! This module provides tools for monitoring and improving numerical stability
4//! including condition number monitoring, overflow/underflow protection,
5//! and numerically stable algorithms.
6
7use scirs2_core::ndarray::{Array1, Array2, Axis};
8use scirs2_core::random::thread_rng;
9use scirs2_core::random::Rng;
10use sklears_core::prelude::{Result, SklearsError};
11
/// Numerical stability monitor for kernel approximation methods.
///
/// Bundles a [`StabilityConfig`] with the [`StabilityMetrics`] and
/// [`StabilityWarning`]s that accumulate as matrices are passed through it.
#[derive(Debug, Clone)]
pub struct NumericalStabilityMonitor {
    // Thresholds and switches read by the monitoring and stabilization methods.
    config: StabilityConfig,
    // Measurements recorded so far; reset by `clear`.
    metrics: StabilityMetrics,
    // Warnings raised so far, in the order they were pushed; reset by `clear`.
    warnings: Vec<StabilityWarning>,
}
20
/// Configuration for numerical stability monitoring.
#[derive(Debug, Clone)]
pub struct StabilityConfig {
    /// Maximum allowed condition number; `monitor_matrix` records a
    /// `HighConditionNumber` warning when its estimate exceeds this value.
    pub max_condition_number: f64,

    /// Minimum eigenvalue threshold; a magnitude below it raises a
    /// `NearSingular` warning in `monitor_matrix` and ends the extraction
    /// loop in `stable_eigendecomposition`.
    pub min_eigenvalue: f64,

    /// Maximum eigenvalue threshold.
    /// NOTE(review): not read by any code visible in this file.
    pub max_eigenvalue: f64,

    /// Tolerance for numerical precision; used as the pivot cutoff in
    /// `stable_matrix_inverse`, the column-norm cutoff in rank estimation and
    /// the stopping norm in power iteration.
    pub numerical_tolerance: f64,

    /// Enable overflow/underflow protection; when set, `stabilize_matrix`
    /// clamps entries to the range `[-1e12, 1e12]`.
    pub enable_overflow_protection: bool,

    /// Enable high-precision arithmetic when needed.
    /// NOTE(review): not read by any code visible in this file.
    pub enable_high_precision: bool,

    /// Regularization parameter for ill-conditioned matrices; added to every
    /// diagonal entry of square matrices by the stabilization helpers.
    pub regularization: f64,
}
46
47impl Default for StabilityConfig {
48    fn default() -> Self {
49        Self {
50            max_condition_number: 1e12,
51            min_eigenvalue: 1e-12,
52            max_eigenvalue: 1e12,
53            numerical_tolerance: 1e-12,
54            enable_overflow_protection: true,
55            enable_high_precision: false,
56            regularization: 1e-8,
57        }
58    }
59}
60
/// Numerical stability metrics accumulated by a `NumericalStabilityMonitor`.
#[derive(Debug, Clone, Default)]
pub struct StabilityMetrics {
    /// Condition number estimates, one per successfully monitored matrix.
    pub condition_numbers: Vec<f64>,

    /// `(min, max)` pairs of the per-row eigenvalue estimates, one per
    /// monitored square matrix.
    pub eigenvalue_ranges: Vec<(f64, f64)>,

    /// Numerical errors detected.
    /// NOTE(review): never written by the code visible in this file.
    pub numerical_errors: Vec<f64>,

    /// Estimated ranks, one per successfully monitored matrix.
    pub matrix_ranks: Vec<usize>,

    /// Number of overflow events: infinite entries found by the finite-value
    /// check plus entries clamped down to `1e12`.
    pub overflow_count: usize,
    /// Number of underflow events; incremented when an entry below `-1e12`
    /// is clamped.
    pub underflow_count: usize,

    /// Precision loss estimates.
    /// NOTE(review): never written by the code visible in this file.
    pub precision_loss: Vec<f64>,
}
85
/// Types of numerical stability warnings.
///
/// Every variant carries a `location` string: the caller-supplied label of
/// the operation or matrix that triggered the warning.
#[derive(Debug, Clone)]
pub enum StabilityWarning {
    /// High condition number detected
    HighConditionNumber {
        // Estimate that exceeded `StabilityConfig::max_condition_number`.
        condition_number: f64,

        location: String,
    },

    /// Near-singular matrix detected
    NearSingular {
        // Smallest eigenvalue estimate; its magnitude was below
        // `StabilityConfig::min_eigenvalue`.
        smallest_eigenvalue: f64,

        location: String,
    },

    /// Overflow detected
    Overflow { value: f64, location: String },

    /// Underflow detected
    Underflow { value: f64, location: String },

    /// Significant precision loss
    PrecisionLoss {
        // Stored as a fraction; `format_warning` multiplies it by 100 to
        // print a percentage.
        estimated_loss: f64,
        location: String,
    },

    /// Rank deficiency detected
    RankDeficient {
        // `min(rows, cols)` of the monitored matrix.
        expected_rank: usize,
        // Rank estimate that fell short of `expected_rank`.
        actual_rank: usize,
        location: String,
    },
}
123
124impl NumericalStabilityMonitor {
125    /// Create a new stability monitor
126    pub fn new(config: StabilityConfig) -> Self {
127        Self {
128            config,
129            metrics: StabilityMetrics::default(),
130            warnings: Vec::new(),
131        }
132    }
133
134    /// Monitor matrix for numerical stability issues
135    pub fn monitor_matrix(&mut self, matrix: &Array2<f64>, location: &str) -> Result<()> {
136        // Check for NaN and infinite values
137        self.check_finite_values(matrix, location)?;
138
139        // Compute and monitor condition number
140        let condition_number = self.estimate_condition_number(matrix)?;
141        self.metrics.condition_numbers.push(condition_number);
142
143        if condition_number > self.config.max_condition_number {
144            self.warnings.push(StabilityWarning::HighConditionNumber {
145                condition_number,
146                location: location.to_string(),
147            });
148        }
149
150        // Check eigenvalues if matrix is square
151        if matrix.nrows() == matrix.ncols() {
152            let eigenvalues = self.estimate_eigenvalues(matrix)?;
153            let min_eigenval = eigenvalues.iter().fold(f64::INFINITY, |acc, &x| acc.min(x));
154            let max_eigenval = eigenvalues
155                .iter()
156                .fold(f64::NEG_INFINITY, |acc, &x| acc.max(x));
157
158            self.metrics
159                .eigenvalue_ranges
160                .push((min_eigenval, max_eigenval));
161
162            if min_eigenval.abs() < self.config.min_eigenvalue {
163                self.warnings.push(StabilityWarning::NearSingular {
164                    smallest_eigenvalue: min_eigenval,
165                    location: location.to_string(),
166                });
167            }
168        }
169
170        // Estimate matrix rank
171        let rank = self.estimate_rank(matrix)?;
172        self.metrics.matrix_ranks.push(rank);
173
174        let expected_rank = matrix.nrows().min(matrix.ncols());
175        if rank < expected_rank {
176            self.warnings.push(StabilityWarning::RankDeficient {
177                expected_rank,
178                actual_rank: rank,
179                location: location.to_string(),
180            });
181        }
182
183        Ok(())
184    }
185
186    /// Apply numerical stabilization to a matrix
187    pub fn stabilize_matrix(&mut self, matrix: &mut Array2<f64>) -> Result<()> {
188        // Apply regularization for ill-conditioned matrices
189        if matrix.nrows() == matrix.ncols() {
190            for i in 0..matrix.nrows() {
191                matrix[[i, i]] += self.config.regularization;
192            }
193        }
194
195        // Clamp extreme values if overflow protection is enabled
196        if self.config.enable_overflow_protection {
197            self.clamp_extreme_values(matrix)?;
198        }
199
200        Ok(())
201    }
202
203    /// Compute numerically stable eigendecomposition
204    pub fn stable_eigendecomposition(
205        &mut self,
206        matrix: &Array2<f64>,
207    ) -> Result<(Array1<f64>, Array2<f64>)> {
208        if matrix.nrows() != matrix.ncols() {
209            return Err(SklearsError::InvalidInput(
210                "Matrix must be square for eigendecomposition".to_string(),
211            ));
212        }
213
214        self.monitor_matrix(matrix, "eigendecomposition_input")?;
215
216        // Apply stabilization
217        let mut stabilized_matrix = matrix.clone();
218        self.stabilize_matrix(&mut stabilized_matrix)?;
219
220        // Simplified eigendecomposition using power iteration for largest eigenvalues
221        let n = matrix.nrows();
222        let max_eigenvalues = 10.min(n);
223
224        let mut eigenvalues = Array1::zeros(max_eigenvalues);
225        let mut eigenvectors = Array2::zeros((n, max_eigenvalues));
226
227        let mut current_matrix = stabilized_matrix.clone();
228
229        for i in 0..max_eigenvalues {
230            let (eigenvalue, eigenvector) = self.power_iteration(&current_matrix)?;
231            eigenvalues[i] = eigenvalue;
232            eigenvectors.column_mut(i).assign(&eigenvector);
233
234            // Deflate the matrix
235            let outer_product = self.outer_product(&eigenvector, &eigenvector);
236            current_matrix = &current_matrix - &(&outer_product * eigenvalue);
237
238            // Check for convergence
239            if eigenvalue.abs() < self.config.min_eigenvalue {
240                break;
241            }
242        }
243
244        Ok((eigenvalues, eigenvectors))
245    }
246
247    /// Compute numerically stable matrix inversion
248    pub fn stable_matrix_inverse(&mut self, matrix: &Array2<f64>) -> Result<Array2<f64>> {
249        self.monitor_matrix(matrix, "matrix_inverse_input")?;
250
251        if matrix.nrows() != matrix.ncols() {
252            return Err(SklearsError::InvalidInput(
253                "Matrix must be square for inversion".to_string(),
254            ));
255        }
256
257        let n = matrix.nrows();
258
259        // Use regularized pseudoinverse for stability
260        let regularized_matrix = self.regularize_matrix(matrix)?;
261
262        // Simplified inversion using Gauss-Jordan elimination with pivoting
263        let mut augmented = Array2::zeros((n, 2 * n));
264
265        // Set up augmented matrix [A | I]
266        for i in 0..n {
267            for j in 0..n {
268                augmented[[i, j]] = regularized_matrix[[i, j]];
269            }
270            augmented[[i, n + i]] = 1.0;
271        }
272
273        // Forward elimination with partial pivoting
274        for i in 0..n {
275            // Find pivot
276            let mut max_row = i;
277            for k in i + 1..n {
278                if augmented[[k, i]].abs() > augmented[[max_row, i]].abs() {
279                    max_row = k;
280                }
281            }
282
283            // Swap rows if needed
284            if max_row != i {
285                for j in 0..2 * n {
286                    let temp = augmented[[i, j]];
287                    augmented[[i, j]] = augmented[[max_row, j]];
288                    augmented[[max_row, j]] = temp;
289                }
290            }
291
292            // Check for near-zero pivot
293            if augmented[[i, i]].abs() < self.config.numerical_tolerance {
294                return Err(SklearsError::NumericalError(
295                    "Matrix is singular or near-singular".to_string(),
296                ));
297            }
298
299            // Scale pivot row
300            let pivot = augmented[[i, i]];
301            for j in 0..2 * n {
302                augmented[[i, j]] /= pivot;
303            }
304
305            // Eliminate column
306            for k in 0..n {
307                if k != i {
308                    let factor = augmented[[k, i]];
309                    for j in 0..2 * n {
310                        augmented[[k, j]] -= factor * augmented[[i, j]];
311                    }
312                }
313            }
314        }
315
316        // Extract inverse matrix
317        let mut inverse = Array2::zeros((n, n));
318        for i in 0..n {
319            for j in 0..n {
320                inverse[[i, j]] = augmented[[i, n + j]];
321            }
322        }
323
324        self.monitor_matrix(&inverse, "matrix_inverse_output")?;
325
326        Ok(inverse)
327    }
328
329    /// Compute numerically stable Cholesky decomposition
330    pub fn stable_cholesky(&mut self, matrix: &Array2<f64>) -> Result<Array2<f64>> {
331        self.monitor_matrix(matrix, "cholesky_input")?;
332
333        if matrix.nrows() != matrix.ncols() {
334            return Err(SklearsError::InvalidInput(
335                "Matrix must be square for Cholesky decomposition".to_string(),
336            ));
337        }
338
339        let n = matrix.nrows();
340
341        // Apply regularization for numerical stability
342        let regularized_matrix = self.regularize_matrix(matrix)?;
343
344        let mut L = Array2::zeros((n, n));
345
346        for i in 0..n {
347            for j in 0..=i {
348                if i == j {
349                    // Diagonal elements
350                    let mut sum = 0.0;
351                    for k in 0..j {
352                        sum += L[[j, k]] * L[[j, k]];
353                    }
354
355                    let diagonal_value = regularized_matrix[[j, j]] - sum;
356                    if diagonal_value <= 0.0 {
357                        return Err(SklearsError::NumericalError(
358                            "Matrix is not positive definite".to_string(),
359                        ));
360                    }
361
362                    L[[j, j]] = diagonal_value.sqrt();
363                } else {
364                    // Off-diagonal elements
365                    let mut sum = 0.0;
366                    for k in 0..j {
367                        sum += L[[i, k]] * L[[j, k]];
368                    }
369
370                    L[[i, j]] = (regularized_matrix[[i, j]] - sum) / L[[j, j]];
371                }
372            }
373        }
374
375        self.monitor_matrix(&L, "cholesky_output")?;
376
377        Ok(L)
378    }
379
380    /// Get stability warnings
381    pub fn get_warnings(&self) -> &[StabilityWarning] {
382        &self.warnings
383    }
384
385    /// Get stability metrics
386    pub fn get_metrics(&self) -> &StabilityMetrics {
387        &self.metrics
388    }
389
390    /// Clear warnings and metrics
391    pub fn clear(&mut self) {
392        self.warnings.clear();
393        self.metrics = StabilityMetrics::default();
394    }
395
396    /// Generate stability report
397    pub fn generate_report(&self) -> String {
398        let mut report = String::new();
399
400        report.push_str("=== Numerical Stability Report ===\n\n");
401
402        // Summary statistics
403        if !self.metrics.condition_numbers.is_empty() {
404            let mean_condition = self.metrics.condition_numbers.iter().sum::<f64>()
405                / self.metrics.condition_numbers.len() as f64;
406            let max_condition = self
407                .metrics
408                .condition_numbers
409                .iter()
410                .fold(f64::NEG_INFINITY, |acc, &x| acc.max(x));
411
412            report.push_str(&format!(
413                "Condition Numbers:\n  Mean: {:.2e}\n  Max: {:.2e}\n  Count: {}\n\n",
414                mean_condition,
415                max_condition,
416                self.metrics.condition_numbers.len()
417            ));
418        }
419
420        // Eigenvalue analysis
421        if !self.metrics.eigenvalue_ranges.is_empty() {
422            let min_eigenval = self
423                .metrics
424                .eigenvalue_ranges
425                .iter()
426                .map(|(min, _)| *min)
427                .fold(f64::INFINITY, f64::min);
428            let max_eigenval = self
429                .metrics
430                .eigenvalue_ranges
431                .iter()
432                .map(|(_, max)| *max)
433                .fold(f64::NEG_INFINITY, f64::max);
434
435            report.push_str(&format!(
436                "Eigenvalue Range:\n  Min: {:.2e}\n  Max: {:.2e}\n  Ratio: {:.2e}\n\n",
437                min_eigenval,
438                max_eigenval,
439                max_eigenval / min_eigenval.abs().max(1e-16)
440            ));
441        }
442
443        // Warnings
444        if !self.warnings.is_empty() {
445            report.push_str(&format!("Warnings ({}):\n", self.warnings.len()));
446            for (i, warning) in self.warnings.iter().enumerate() {
447                report.push_str(&format!("  {}: {}\n", i + 1, self.format_warning(warning)));
448            }
449            report.push('\n');
450        }
451
452        // Overflow/underflow statistics
453        if self.metrics.overflow_count > 0 || self.metrics.underflow_count > 0 {
454            report.push_str(&format!(
455                "Overflow/Underflow:\n  Overflows: {}\n  Underflows: {}\n\n",
456                self.metrics.overflow_count, self.metrics.underflow_count
457            ));
458        }
459
460        report.push_str("=== End Report ===\n");
461
462        report
463    }
464
465    // Helper methods
466
467    fn check_finite_values(&mut self, matrix: &Array2<f64>, location: &str) -> Result<()> {
468        for &value in matrix.iter() {
469            if !value.is_finite() {
470                if value.is_infinite() {
471                    self.metrics.overflow_count += 1;
472                    self.warnings.push(StabilityWarning::Overflow {
473                        value,
474                        location: location.to_string(),
475                    });
476                } else if value == 0.0 && value.is_sign_negative() {
477                    self.metrics.underflow_count += 1;
478                    self.warnings.push(StabilityWarning::Underflow {
479                        value,
480                        location: location.to_string(),
481                    });
482                }
483
484                return Err(SklearsError::NumericalError(format!(
485                    "Non-finite value detected: {} at {}",
486                    value, location
487                )));
488            }
489        }
490        Ok(())
491    }
492
493    fn estimate_condition_number(&self, matrix: &Array2<f64>) -> Result<f64> {
494        // Simplified condition number estimation using Frobenius norm
495        let norm = matrix.mapv(|x| x * x).sum().sqrt();
496
497        if matrix.nrows() == matrix.ncols() {
498            // For square matrices, estimate using diagonal dominance
499            let mut min_diag = f64::INFINITY;
500            let mut max_off_diag: f64 = 0.0;
501
502            for i in 0..matrix.nrows() {
503                min_diag = min_diag.min(matrix[[i, i]].abs());
504
505                for j in 0..matrix.ncols() {
506                    if i != j {
507                        max_off_diag = max_off_diag.max(matrix[[i, j]].abs());
508                    }
509                }
510            }
511
512            let condition_estimate = if min_diag > 0.0 {
513                (norm + max_off_diag) / min_diag
514            } else {
515                f64::INFINITY
516            };
517
518            Ok(condition_estimate)
519        } else {
520            // For non-square matrices, use norm-based estimate
521            let min_norm = matrix
522                .axis_iter(Axis(0))
523                .map(|row| row.mapv(|x| x * x).sum().sqrt())
524                .fold(f64::INFINITY, f64::min);
525
526            Ok(norm / min_norm.max(1e-16))
527        }
528    }
529
530    fn estimate_eigenvalues(&self, matrix: &Array2<f64>) -> Result<Array1<f64>> {
531        let n = matrix.nrows();
532
533        // Simplified eigenvalue estimation using Gershgorin circles
534        let mut eigenvalue_bounds = Array1::zeros(n);
535
536        for i in 0..n {
537            let center = matrix[[i, i]];
538            let radius = (0..n)
539                .filter(|&j| j != i)
540                .map(|j| matrix[[i, j]].abs())
541                .sum::<f64>();
542
543            eigenvalue_bounds[i] = center + radius; // Upper bound estimate
544        }
545
546        Ok(eigenvalue_bounds)
547    }
548
549    fn estimate_rank(&self, matrix: &Array2<f64>) -> Result<usize> {
550        // Simplified rank estimation using diagonal elements after regularization
551        let mut regularized = matrix.clone();
552
553        if matrix.nrows() == matrix.ncols() {
554            for i in 0..matrix.nrows() {
555                regularized[[i, i]] += self.config.regularization;
556            }
557        }
558
559        let mut rank = 0;
560        let min_dim = matrix.nrows().min(matrix.ncols());
561
562        for i in 0..min_dim {
563            let column_norm = regularized.column(i).mapv(|x| x * x).sum().sqrt();
564            if column_norm > self.config.numerical_tolerance {
565                rank += 1;
566            }
567        }
568
569        Ok(rank)
570    }
571
572    fn clamp_extreme_values(&mut self, matrix: &mut Array2<f64>) -> Result<()> {
573        let max_value = 1e12;
574        let min_value = -1e12;
575
576        for value in matrix.iter_mut() {
577            if *value > max_value {
578                *value = max_value;
579                self.metrics.overflow_count += 1;
580            } else if *value < min_value {
581                *value = min_value;
582                self.metrics.underflow_count += 1;
583            }
584        }
585
586        Ok(())
587    }
588
589    fn power_iteration(&self, matrix: &Array2<f64>) -> Result<(f64, Array1<f64>)> {
590        let n = matrix.nrows();
591        let mut vector = Array1::from_shape_fn(n, |_| thread_rng().gen::<f64>() - 0.5);
592
593        // Normalize initial vector
594        let norm = vector.mapv(|x| x * x).sum().sqrt();
595        vector /= norm;
596
597        let mut eigenvalue = 0.0;
598
599        for _ in 0..100 {
600            let new_vector = matrix.dot(&vector);
601            let new_norm = new_vector.mapv(|x| x * x).sum().sqrt();
602
603            if new_norm < self.config.numerical_tolerance {
604                break;
605            }
606
607            eigenvalue = vector.dot(&new_vector);
608            vector = new_vector / new_norm;
609        }
610
611        Ok((eigenvalue, vector))
612    }
613
614    fn outer_product(&self, v1: &Array1<f64>, v2: &Array1<f64>) -> Array2<f64> {
615        let mut result = Array2::zeros((v1.len(), v2.len()));
616
617        for i in 0..v1.len() {
618            for j in 0..v2.len() {
619                result[[i, j]] = v1[i] * v2[j];
620            }
621        }
622
623        result
624    }
625
626    fn regularize_matrix(&self, matrix: &Array2<f64>) -> Result<Array2<f64>> {
627        let mut regularized = matrix.clone();
628
629        if matrix.nrows() == matrix.ncols() {
630            for i in 0..matrix.nrows() {
631                regularized[[i, i]] += self.config.regularization;
632            }
633        }
634
635        Ok(regularized)
636    }
637
638    fn format_warning(&self, warning: &StabilityWarning) -> String {
639        match warning {
640            StabilityWarning::HighConditionNumber {
641                condition_number,
642                location,
643            } => {
644                format!(
645                    "High condition number {:.2e} at {}",
646                    condition_number, location
647                )
648            }
649            StabilityWarning::NearSingular {
650                smallest_eigenvalue,
651                location,
652            } => {
653                format!(
654                    "Near-singular matrix (λ_min = {:.2e}) at {}",
655                    smallest_eigenvalue, location
656                )
657            }
658            StabilityWarning::Overflow { value, location } => {
659                format!("Overflow (value = {:.2e}) at {}", value, location)
660            }
661            StabilityWarning::Underflow { value, location } => {
662                format!("Underflow (value = {:.2e}) at {}", value, location)
663            }
664            StabilityWarning::PrecisionLoss {
665                estimated_loss,
666                location,
667            } => {
668                format!(
669                    "Precision loss ({:.1}%) at {}",
670                    estimated_loss * 100.0,
671                    location
672                )
673            }
674            StabilityWarning::RankDeficient {
675                expected_rank,
676                actual_rank,
677                location,
678            } => {
679                format!(
680                    "Rank deficient ({}/{}) at {}",
681                    actual_rank, expected_rank, location
682                )
683            }
684        }
685    }
686}
687
688/// Numerically stable kernel matrix computation
689pub fn stable_kernel_matrix(
690    data1: &Array2<f64>,
691    data2: Option<&Array2<f64>>,
692    kernel_type: &str,
693    bandwidth: f64,
694    monitor: &mut NumericalStabilityMonitor,
695) -> Result<Array2<f64>> {
696    let data2 = data2.unwrap_or(data1);
697    let (n1, _n_features) = data1.dim();
698    let (n2, _) = data2.dim();
699
700    let mut kernel = Array2::zeros((n1, n2));
701
702    // Precompute squared norms for numerical stability
703    let norms1: Vec<f64> = data1
704        .axis_iter(Axis(0))
705        .map(|row| row.mapv(|x| x * x).sum())
706        .collect();
707
708    let norms2: Vec<f64> = data2
709        .axis_iter(Axis(0))
710        .map(|row| row.mapv(|x| x * x).sum())
711        .collect();
712
713    for i in 0..n1 {
714        for j in 0..n2 {
715            let similarity = match kernel_type {
716                "RBF" => {
717                    // Use numerically stable distance computation: ||x-y||² = ||x||² + ||y||² - 2⟨x,y⟩
718                    let dot_product = data1.row(i).dot(&data2.row(j));
719                    let dist_sq = norms1[i] + norms2[j] - 2.0 * dot_product;
720                    let dist_sq = dist_sq.max(0.0); // Ensure non-negative due to numerical errors
721
722                    let exponent = -bandwidth * dist_sq;
723
724                    // Clamp exponent to prevent underflow
725                    let clamped_exponent = exponent.max(-700.0); // e^(-700) ≈ 1e-304
726
727                    clamped_exponent.exp()
728                }
729                "Laplacian" => {
730                    let diff = &data1.row(i) - &data2.row(j);
731                    let dist = diff.mapv(|x| x.abs()).sum();
732
733                    let exponent = -bandwidth * dist;
734                    let clamped_exponent = exponent.max(-700.0);
735
736                    clamped_exponent.exp()
737                }
738                "Polynomial" => {
739                    let dot_product = data1.row(i).dot(&data2.row(j));
740                    let base = bandwidth * dot_product + 1.0;
741
742                    // Ensure positive base for polynomial kernel
743                    let clamped_base = base.max(1e-16);
744
745                    clamped_base.powi(2) // Degree 2 polynomial
746                }
747                "Linear" => data1.row(i).dot(&data2.row(j)),
748                _ => {
749                    return Err(SklearsError::InvalidOperation(format!(
750                        "Unsupported kernel type: {}",
751                        kernel_type
752                    )));
753                }
754            };
755
756            kernel[[i, j]] = similarity;
757        }
758    }
759
760    monitor.monitor_matrix(&kernel, &format!("{}_kernel_matrix", kernel_type))?;
761
762    Ok(kernel)
763}
764
// Unit tests for the stability monitor and the guarded kernel computation.
// NOTE(review): the reason for allowing `non_snake_case` here is not evident
// from the test bodies — confirm whether the attribute is still needed.
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    // A freshly built monitor has no warnings and no recorded metrics.
    #[test]
    fn test_stability_monitor_creation() {
        let config = StabilityConfig::default();
        let monitor = NumericalStabilityMonitor::new(config);

        assert!(monitor.get_warnings().is_empty());
        assert_eq!(monitor.get_metrics().condition_numbers.len(), 0);
    }

    // Monitoring succeeds on both matrices; only the well-conditioned one is
    // checked for the absence of warnings.
    #[test]
    fn test_condition_number_monitoring() {
        let mut monitor = NumericalStabilityMonitor::new(StabilityConfig::default());

        // Well-conditioned matrix
        let well_conditioned = array![[2.0, 1.0], [1.0, 2.0],];

        monitor
            .monitor_matrix(&well_conditioned, "test_well_conditioned")
            .unwrap();
        assert!(monitor.get_warnings().is_empty());

        // Ill-conditioned matrix
        let ill_conditioned = array![[1.0, 1.0], [1.0, 1.000001],];

        monitor
            .monitor_matrix(&ill_conditioned, "test_ill_conditioned")
            .unwrap();
        // Only checks that monitoring succeeds; no warning is asserted here.
        // NOTE(review): with the heuristic estimators in this module this input
        // appears to raise no warning (condition estimate is about 3) — confirm
        // before adding an assertion.
    }

    // Shape and finiteness checks on the partial eigendecomposition.
    #[test]
    fn test_stable_eigendecomposition() {
        let mut monitor = NumericalStabilityMonitor::new(StabilityConfig::default());

        let matrix = array![[4.0, 2.0], [2.0, 3.0],];

        let (eigenvalues, eigenvectors) = monitor.stable_eigendecomposition(&matrix).unwrap();

        assert!(eigenvalues.len() <= matrix.nrows());
        assert_eq!(eigenvectors.nrows(), matrix.nrows());
        assert!(eigenvalues.iter().all(|&x| x.is_finite()));
    }

    // The returned inverse must invert the regularized matrix.
    #[test]
    fn test_stable_matrix_inverse() {
        let mut monitor = NumericalStabilityMonitor::new(StabilityConfig::default());

        let matrix = array![[4.0, 2.0], [2.0, 3.0],];

        let inverse = monitor.stable_matrix_inverse(&matrix).unwrap();

        // Check that A_regularized * A^(-1) ≈ I (where A^(-1) is inverse of regularized matrix)
        let mut regularized_matrix = matrix.clone();
        for i in 0..matrix.nrows() {
            regularized_matrix[[i, i]] += 1e-8; // Default regularization
        }

        let product = regularized_matrix.dot(&inverse);
        let identity_error = (&product - &Array2::<f64>::eye(2)).mapv(|x| x.abs()).sum();

        assert!(identity_error < 1e-10);
    }

    // The Cholesky factor must reconstruct the regularized matrix.
    #[test]
    fn test_stable_cholesky() {
        let mut monitor = NumericalStabilityMonitor::new(StabilityConfig::default());

        // Positive definite matrix
        let matrix = array![[4.0, 2.0], [2.0, 3.0],];

        let cholesky = monitor.stable_cholesky(&matrix).unwrap();

        // Check that L * L^T = A_regularized (regularized matrix)
        let reconstructed = cholesky.dot(&cholesky.t());

        // Compute the regularized matrix for comparison
        let mut regularized_matrix = matrix.clone();
        for i in 0..matrix.nrows() {
            regularized_matrix[[i, i]] += 1e-8; // Default regularization
        }

        let reconstruction_error = (&regularized_matrix - &reconstructed)
            .mapv(|x| x.abs())
            .sum();

        assert!(reconstruction_error < 1e-10);
    }

    // RBF kernel of a data set with itself: finite, non-negative, symmetric,
    // with a unit diagonal.
    #[test]
    fn test_stable_kernel_matrix() {
        let mut monitor = NumericalStabilityMonitor::new(StabilityConfig::default());

        let data = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0],];

        let kernel = stable_kernel_matrix(&data, None, "RBF", 1.0, &mut monitor).unwrap();

        assert_eq!(kernel.shape(), &[3, 3]);
        assert!(kernel.iter().all(|&x| x.is_finite() && x >= 0.0));

        // Kernel matrix should be symmetric
        for i in 0..3 {
            for j in 0..3 {
                assert!((kernel[[i, j]] - kernel[[j, i]]).abs() < 1e-12);
            }
        }

        // Diagonal should be 1 for RBF kernel
        for i in 0..3 {
            assert!((kernel[[i, i]] - 1.0).abs() < 1e-12);
        }
    }

    // The report contains its header and the condition-number section after
    // one matrix has been monitored.
    #[test]
    fn test_stability_report() {
        let mut monitor = NumericalStabilityMonitor::new(StabilityConfig::default());

        let matrix = array![[1.0, 2.0], [3.0, 4.0],];

        monitor.monitor_matrix(&matrix, "test_matrix").unwrap();

        let report = monitor.generate_report();
        assert!(report.contains("Numerical Stability Report"));
        assert!(report.contains("Condition Numbers"));
    }
}