scirs2_core/numeric/
stability.rs

1//! # Numerical Stability Improvements
2//!
3//! This module provides numerically stable implementations of common algorithms that
4//! are prone to precision loss, overflow, or catastrophic cancellation.
5//!
6//! ## Features
7//!
8//! - Stable summation algorithms (Kahan, pairwise, compensated)
9//! - Robust variance and standard deviation calculations
10//! - Overflow-resistant multiplication and exponentiation
11//! - Stable trigonometric range reduction
12//! - Robust normalization techniques
13//! - Improved root finding with bracketing
14//! - Stable matrix condition number estimation
15
16use crate::{
17    error::{CoreError, CoreResult, ErrorContext},
18    validation::check_positive,
19};
20use ndarray::{Array1, ArrayView2, Axis};
21use num_traits::{cast, Float, Zero};
22use std::fmt::Debug;
23
24/// Trait for numerically stable computations
25pub trait StableComputation: Float + Debug {
26    /// Machine epsilon for this type
27    fn machine_epsilon() -> Self;
28
29    /// Safe reciprocal that handles near-zero values
30    fn safe_recip(self) -> Self;
31
32    /// Check if the value is effectively zero (within epsilon)
33    fn is_effectively_zero(self) -> bool;
34}
35
36impl StableComputation for f32 {
37    fn machine_epsilon() -> Self {
38        f32::EPSILON
39    }
40
41    fn safe_recip(self) -> Self {
42        if self.abs() < Self::machine_epsilon() * cast::<f64, Self>(10.0).unwrap_or(Self::zero()) {
43            Self::zero()
44        } else {
45            self.recip()
46        }
47    }
48
49    fn is_effectively_zero(self) -> bool {
50        self.abs() < Self::machine_epsilon() * cast::<f64, Self>(10.0).unwrap_or(Self::zero())
51    }
52}
53
54impl StableComputation for f64 {
55    fn machine_epsilon() -> Self {
56        f64::EPSILON
57    }
58
59    fn safe_recip(self) -> Self {
60        if self.abs() < Self::machine_epsilon() * cast::<f64, Self>(10.0).unwrap_or(Self::zero()) {
61            Self::zero()
62        } else {
63            self.recip()
64        }
65    }
66
67    fn is_effectively_zero(self) -> bool {
68        self.abs() < Self::machine_epsilon() * cast::<f64, Self>(10.0).unwrap_or(Self::zero())
69    }
70}
71
72/// Kahan summation algorithm for accurate floating-point summation
73///
74/// This algorithm reduces numerical error in the total obtained by adding a sequence
75/// of finite-precision floating-point numbers compared to the obvious approach.
76pub struct KahanSum<T: Float> {
77    sum: T,
78    compensation: T,
79}
80
81impl<T: Float> KahanSum<T> {
82    /// Create a new Kahan sum accumulator
83    pub fn new() -> Self {
84        Self {
85            sum: T::zero(),
86            compensation: T::zero(),
87        }
88    }
89
90    /// Add a value to the sum
91    pub fn add(&mut self, value: T) {
92        let y = value - self.compensation;
93        let t = self.sum + y;
94        self.compensation = (t - self.sum) - y;
95        self.sum = t;
96    }
97
98    /// Get the accumulated sum
99    pub fn sum(&self) -> T {
100        self.sum
101    }
102
103    /// Reset the accumulator
104    pub fn reset(&mut self) {
105        self.sum = T::zero();
106        self.compensation = T::zero();
107    }
108}
109
110impl<T: Float> Default for KahanSum<T> {
111    fn default() -> Self {
112        Self::new()
113    }
114}
115
116/// Neumaier summation algorithm (improved Kahan summation)
117///
118/// This is an improved version of Kahan summation that handles the case
119/// where the next item to be added is larger in absolute value than the running sum.
120#[allow(dead_code)]
121pub fn neumaier_sum<T: Float>(values: &[T]) -> T {
122    if values.is_empty() {
123        return T::zero();
124    }
125
126    let mut sum = values[0];
127    let mut compensation = T::zero();
128
129    for &value in &values[1..] {
130        let t = sum + value;
131        if sum.abs() >= value.abs() {
132            compensation = compensation + ((sum - t) + value);
133        } else {
134            compensation = compensation + ((value - t) + sum);
135        }
136        sum = t;
137    }
138
139    sum + compensation
140}
141
142/// Pairwise summation algorithm
143///
144/// Recursively splits the array and sums pairs, reducing rounding error
145/// compared to sequential summation.
146#[allow(dead_code)]
147pub fn pairwise_sum<T: Float>(values: &[T]) -> T {
148    const SEQUENTIAL_THRESHOLD: usize = 128;
149
150    match values.len() {
151        0 => T::zero(),
152        1 => values[0],
153        n if n <= SEQUENTIAL_THRESHOLD => {
154            // Use Kahan summation for small arrays
155            let mut kahan = KahanSum::new();
156            for &v in values {
157                kahan.add(v);
158            }
159            kahan.sum()
160        }
161        n => {
162            let mid = n / 2;
163            pairwise_sum(&values[..mid]) + pairwise_sum(&values[mid..])
164        }
165    }
166}
167
168/// Stable mean calculation using compensated summation
169#[allow(dead_code)]
170pub fn stable_mean<T: Float>(values: &[T]) -> CoreResult<T> {
171    if values.is_empty() {
172        return Err(CoreError::ValidationError(ErrorContext::new(
173            "Cannot compute mean of empty array",
174        )));
175    }
176
177    let n = cast::<usize, T>(values.len()).ok_or_else(|| {
178        CoreError::TypeError(ErrorContext::new("Failed to convert array length to float"))
179    })?;
180
181    Ok(neumaier_sum(values) / n)
182}
183
184/// Welford's online algorithm for computing variance
185///
186/// This algorithm is numerically stable and computes variance in a single pass.
187pub struct WelfordVariance<T: Float> {
188    count: usize,
189    mean: T,
190    m2: T,
191}
192
193impl<T: Float> WelfordVariance<T> {
194    /// Create a new variance accumulator
195    pub fn new() -> Self {
196        Self {
197            count: 0,
198            mean: T::zero(),
199            m2: T::zero(),
200        }
201    }
202
203    /// Add a value to the accumulator
204    pub fn add(&mut self, value: T) {
205        self.count += 1;
206        let delta = value - self.mean;
207        self.mean = self.mean + delta / cast::<usize, T>(self.count).unwrap_or(T::one());
208        let delta2 = value - self.mean;
209        self.m2 = self.m2 + delta * delta2;
210    }
211
212    /// Get the current mean
213    pub fn mean(&self) -> Option<T> {
214        if self.count > 0 {
215            Some(self.mean)
216        } else {
217            None
218        }
219    }
220
221    /// Get the sample variance (Bessel's correction applied)
222    pub fn variance(&self) -> Option<T> {
223        if self.count > 1 {
224            Some(self.m2 / cast::<usize, T>(self.count - 1).unwrap_or(T::one()))
225        } else {
226            None
227        }
228    }
229
230    /// Get the population variance
231    pub fn population_variance(&self) -> Option<T> {
232        if self.count > 0 {
233            Some(self.m2 / cast::<usize, T>(self.count).unwrap_or(T::one()))
234        } else {
235            None
236        }
237    }
238
239    /// Get the standard deviation
240    pub fn std_dev(&self) -> Option<T> {
241        self.variance().map(|v| v.sqrt())
242    }
243}
244
245impl<T: Float> Default for WelfordVariance<T> {
246    fn default() -> Self {
247        Self::new()
248    }
249}
250
251/// Stable two-pass algorithm for variance calculation
252#[allow(dead_code)]
253pub fn stable_variance<T: Float>(values: &[T], ddof: usize) -> CoreResult<T> {
254    let n = values.len();
255    if n <= ddof {
256        return Err(CoreError::ValidationError(ErrorContext::new(
257            "Not enough values for the given degrees of freedom",
258        )));
259    }
260
261    // First pass: compute mean with compensated summation
262    let mean = stable_mean(values)?;
263
264    // Second pass: compute sum of squared deviations with compensation
265    let mut sum_sq = T::zero();
266    let mut compensation = T::zero();
267
268    for &value in values {
269        let deviation = value - mean;
270        let sq_deviation = deviation * deviation;
271        let y = sq_deviation - compensation;
272        let t = sum_sq + y;
273        compensation = (t - sum_sq) - y;
274        sum_sq = t;
275    }
276
277    let divisor = cast::<usize, T>(n - ddof).ok_or_else(|| {
278        CoreError::TypeError(ErrorContext::new("Failed to convert divisor to float"))
279    })?;
280
281    Ok(sum_sq / divisor)
282}
283
284/// Log-sum-exp trick for stable computation of log(sum(exp(x)))
285///
286/// This prevents overflow when computing the log of a sum of exponentials.
287#[allow(dead_code)]
288pub fn log_sum_exp<T: Float>(values: &[T]) -> T {
289    if values.is_empty() {
290        return T::neg_infinity();
291    }
292
293    // Find maximum value
294    let max_val = values.iter().fold(T::neg_infinity(), |a, &b| a.max(b));
295
296    if max_val.is_infinite() && max_val < T::zero() {
297        return max_val; // All values are -inf
298    }
299
300    // Compute log(sum(exp(x - max))) + max
301    let mut sum = T::zero();
302    for &value in values {
303        sum = sum + (value - max_val).exp();
304    }
305
306    max_val + sum.ln()
307}
308
309/// Stable softmax computation
310///
311/// Computes softmax(x) = exp(x) / sum(exp(x)) in a numerically stable way.
312#[allow(dead_code)]
313pub fn stable_softmax<T: Float>(values: &[T]) -> Vec<T> {
314    if values.is_empty() {
315        return vec![];
316    }
317
318    // Find maximum for numerical stability
319    let max_val = values.iter().fold(T::neg_infinity(), |a, &b| a.max(b));
320
321    // Compute exp(x - max)
322    let mut expvalues = Vec::with_capacity(values.len());
323    let mut sum = T::zero();
324
325    for &value in values {
326        let exp_val = (value - max_val).exp();
327        expvalues.push(exp_val);
328        sum = sum + exp_val;
329    }
330
331    // Normalize
332    for exp_val in &mut expvalues {
333        *exp_val = *exp_val / sum;
334    }
335
336    expvalues
337}
338
339/// Stable computation of log(1 + x) for small x
340#[allow(dead_code)]
341pub fn log1p_stable<T: Float>(x: T) -> T {
342    // Use built-in log1p if available, otherwise use series expansion for small x
343    if x.abs() < cast::<f64, T>(0.5).unwrap_or(T::zero()) {
344        // For small x, use Taylor series: log(1+x) ≈ x - x²/2 + x³/3 - ...
345        let x2 = x * x;
346        let x3 = x2 * x;
347        let x4 = x3 * x;
348        x - x2 / cast::<f64, T>(2.0).unwrap_or(T::one())
349            + x3 / cast::<f64, T>(3.0).unwrap_or(T::one())
350            - x4 / cast::<f64, T>(4.0).unwrap_or(T::one())
351    } else {
352        (T::one() + x).ln()
353    }
354}
355
356/// Stable computation of exp(x) - 1 for small x
357#[allow(dead_code)]
358pub fn expm1_stable<T: Float>(x: T) -> T {
359    if x.abs() < cast::<f64, T>(0.5).unwrap_or(T::zero()) {
360        // Use Taylor series: exp(x) - 1 ≈ x + x²/2 + x³/6 + ...
361        let x2 = x * x;
362        let x3 = x2 * x;
363        let x4 = x3 * x;
364        x + x2 / cast::<f64, T>(2.0).unwrap_or(T::one())
365            + x3 / cast::<f64, T>(6.0).unwrap_or(T::one())
366            + x4 / cast::<f64, T>(24.0).unwrap_or(T::one())
367    } else {
368        x.exp() - T::one()
369    }
370}
371
372/// Stable computation of sqrt(x² + y²) avoiding overflow
373#[allow(dead_code)]
374pub fn hypot_stable<T: Float>(x: T, y: T) -> T {
375    let x_abs = x.abs();
376    let y_abs = y.abs();
377
378    if x_abs > y_abs {
379        if x_abs.is_zero() {
380            T::zero()
381        } else {
382            let ratio = y_abs / x_abs;
383            x_abs * (T::one() + ratio * ratio).sqrt()
384        }
385    } else if y_abs.is_zero() {
386        T::zero()
387    } else {
388        let ratio = x_abs / y_abs;
389        y_abs * (T::one() + ratio * ratio).sqrt()
390    }
391}
392
393/// Stable angle reduction for trigonometric functions
394///
395/// Reduces angle to [-π, π] range while preserving precision for large angles.
396#[allow(dead_code)]
397pub fn reduce_angle<T: Float>(angle: T) -> T {
398    let two_pi = cast::<f64, T>(2.0).unwrap_or(T::one())
399        * cast::<f64, T>(std::f64::consts::PI).unwrap_or(T::one());
400    let pi = cast::<f64, T>(std::f64::consts::PI).unwrap_or(T::one());
401
402    // Use remainder to get value in (-2π, 2π)
403    let mut reduced = angle % two_pi;
404
405    // Normalize to [0, 2π)
406    if reduced < T::zero() {
407        reduced = reduced + two_pi;
408    }
409
410    // Then shift to [-π, π]
411    if reduced >= pi {
412        reduced - two_pi
413    } else {
414        reduced
415    }
416}
417
418/// Stable computation of (a*b) % m avoiding overflow
419#[allow(dead_code)]
420pub fn mulmod_stable<T: Float>(a: T, b: T, m: T) -> CoreResult<T> {
421    let m_f64 = m.to_f64().ok_or_else(|| {
422        CoreError::TypeError(ErrorContext::new(
423            "Failed to convert modulus to f64 for validation",
424        ))
425    })?;
426    check_positive(m_f64, "modulus")?;
427
428    if a.is_zero() || b.is_zero() {
429        return Ok(T::zero());
430    }
431
432    let a_mod = a % m;
433    let b_mod = b % m;
434
435    // Check if multiplication would overflow
436    let max_val = T::max_value();
437    if a_mod.abs() > max_val / b_mod.abs() {
438        // Use addition-based multiplication for large values
439        let mut result = T::zero();
440        let mut b_remaining = b_mod.abs();
441        let b_sign = if b_mod < T::zero() {
442            -T::one()
443        } else {
444            T::one()
445        };
446
447        while b_remaining > T::zero() {
448            if b_remaining >= T::one() {
449                result = (result + a_mod) % m;
450                b_remaining = b_remaining - T::one();
451            } else {
452                result = (result + a_mod * b_remaining) % m;
453                break;
454            }
455        }
456
457        Ok(result * b_sign)
458    } else {
459        Ok((a_mod * b_mod) % m)
460    }
461}
462
463/// Numerically stable sigmoid function
464#[allow(dead_code)]
465pub fn sigmoid_stable<T: Float>(x: T) -> T {
466    if x >= T::zero() {
467        let exp_neg_x = (-x).exp();
468        T::one() / (T::one() + exp_neg_x)
469    } else {
470        let exp_x = x.exp();
471        exp_x / (T::one() + exp_x)
472    }
473}
474
475/// Numerically stable log-sigmoid function
476#[allow(dead_code)]
477pub fn log_sigmoid_stable<T: Float>(x: T) -> T {
478    if x >= T::zero() {
479        -log1p_stable((-x).exp())
480    } else {
481        x - log1p_stable(x.exp())
482    }
483}
484
485/// Cross entropy loss with numerical stability
486#[allow(dead_code)]
487pub fn cross_entropy_stable<T: Float>(predictions: &[T], targets: &[T]) -> CoreResult<T> {
488    if predictions.len() != targets.len() {
489        return Err(CoreError::ValidationError(ErrorContext::new(
490            "Predictions and targets must have same length",
491        )));
492    }
493
494    let mut loss = T::zero();
495    let epsilon = cast::<f64, T>(1e-15).unwrap_or(T::epsilon()); // Small value to prevent log(0)
496
497    for (pred, target) in predictions.iter().zip(targets.iter()) {
498        // Clip _predictions to prevent log(0)
499        let pred_clipped = pred.max(epsilon).min(T::one() - epsilon);
500        loss = loss
501            - (*target * pred_clipped.ln() + (T::one() - *target) * (T::one() - pred_clipped).ln());
502    }
503
504    Ok(loss / cast::<usize, T>(predictions.len()).unwrap_or(T::one()))
505}
506
507/// Stable matrix norm computation
508#[allow(dead_code)]
509pub fn stablematrix_norm<T: Float>(matrix: &ArrayView2<T>, ord: MatrixNorm) -> CoreResult<T> {
510    validatematrix_not_empty(matrix)?;
511
512    match ord {
513        MatrixNorm::Frobenius => {
514            // Use compensated summation for Frobenius norm
515            let mut sum = T::zero();
516            let mut compensation = T::zero();
517
518            for &value in matrix.iter() {
519                let sq = value * value;
520                let y = sq - compensation;
521                let t = sum + y;
522                compensation = (t - sum) - y;
523                sum = t;
524            }
525
526            Ok(sum.sqrt())
527        }
528        MatrixNorm::One => {
529            // Maximum absolute column sum
530            let mut max_sum = T::zero();
531
532            for col in matrix.axis_iter(Axis(1)) {
533                let col_sum = stable_norm_1(&col.to_vec());
534                max_sum = max_sum.max(col_sum);
535            }
536
537            Ok(max_sum)
538        }
539        MatrixNorm::Infinity => {
540            // Maximum absolute row sum
541            let mut max_sum = T::zero();
542
543            for row in matrix.axis_iter(Axis(0)) {
544                let row_sum = stable_norm_1(&row.to_vec());
545                max_sum = max_sum.max(row_sum);
546            }
547
548            Ok(max_sum)
549        }
550    }
551}
552
/// Selects which matrix norm `stablematrix_norm` computes.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum MatrixNorm {
    /// Frobenius norm: square root of the sum of squared entries.
    Frobenius,
    /// 1-norm: maximum absolute column sum.
    One,
    /// Infinity norm: maximum absolute row sum.
    Infinity,
}
563
564/// Stable L1 norm computation
565#[allow(dead_code)]
566fn stable_norm_1<T: Float>(values: &[T]) -> T {
567    let mut sum = T::zero();
568    let mut compensation = T::zero();
569
570    for &value in values {
571        let abs_val = value.abs();
572        let y = abs_val - compensation;
573        let t = sum + y;
574        compensation = (t - sum) - y;
575        sum = t;
576    }
577
578    sum
579}
580
581/// Stable L2 norm computation avoiding overflow/underflow
582#[allow(dead_code)]
583pub fn stable_norm_2<T: Float>(values: &[T]) -> T {
584    if values.is_empty() {
585        return T::zero();
586    }
587
588    // Find maximum absolute value for scaling
589    let max_abs = values.iter().fold(T::zero(), |max, &x| max.max(x.abs()));
590
591    if max_abs.is_zero() {
592        return T::zero();
593    }
594
595    // Compute scaled norm
596    let mut sum = T::zero();
597    for &value in values {
598        let scaled = value / max_abs;
599        sum = sum + scaled * scaled;
600    }
601
602    max_abs * sum.sqrt()
603}
604
605/// Condition number estimation using 1-norm
606#[allow(dead_code)]
607pub fn condition_number_estimate<T: Float>(matrix: &ArrayView2<T>) -> CoreResult<T> {
608    validatematrix_not_empty(matrix)?;
609
610    if matrix.nrows() != matrix.ncols() {
611        return Err(CoreError::ValidationError(ErrorContext::new(
612            "Matrix must be square for condition number",
613        )));
614    }
615
616    // Compute 1-norm of matrix
617    let norm_a = stablematrix_norm(matrix, MatrixNorm::One)?;
618
619    // Estimate norm of inverse using power method
620    // This is a simplified version - a full implementation would use LAPACK's condition estimator
621    let n = matrix.nrows();
622    let mut x = Array1::from_elem(n, T::one() / cast::<usize, T>(n).unwrap_or(T::one()));
623    let mut y = Array1::zeros(n);
624
625    // Power iteration to estimate ||A^{-1}||_1
626    let max_iter = 10;
627    let mut norm_inv_estimate = T::zero();
628
629    for _ in 0..max_iter {
630        // y = A^T x
631        for i in 0..n {
632            y[i] = T::zero();
633            for j in 0..n {
634                y[i] = y[i] + matrix[[j, i]] * x[j];
635            }
636        }
637
638        // Normalize y
639        let y_norm = stable_norm_1(&y.to_vec());
640        if y_norm > T::zero() {
641            for i in 0..n {
642                y[i] = y[i] / y_norm;
643            }
644        }
645
646        // Solve A z = y (simplified - would use LU decomposition)
647        // For now, just estimate
648        norm_inv_estimate = norm_inv_estimate.max(y_norm);
649
650        x.assign(&y);
651    }
652
653    Ok(norm_a * norm_inv_estimate)
654}
655
656/// Helper function to validate matrix is not empty
657#[allow(dead_code)]
658fn validatematrix_not_empty<T>(matrix: &ArrayView2<T>) -> CoreResult<()> {
659    if matrix.is_empty() {
660        return Err(CoreError::ValidationError(ErrorContext::new(
661            "Matrix cannot be empty",
662        )));
663    }
664    Ok(())
665}
666
667/// Stable computation of binomial coefficients
668#[allow(dead_code)]
669pub fn binomial_stable(n: u64, k: u64) -> CoreResult<f64> {
670    if k > n {
671        return Ok(0.0);
672    }
673
674    let k = k.min(n - k); // Take advantage of symmetry
675
676    if k == 0 {
677        return Ok(1.0);
678    }
679
680    // Use log-space computation for large values
681    if n > 20 {
682        let mut log_result = 0.0;
683
684        for i in 0..k {
685            log_result += ((n - i) as f64).ln();
686            log_result -= ((i + 1) as f64).ln();
687        }
688
689        Ok(log_result.exp())
690    } else {
691        // Direct computation for small values
692        let mut result = 1.0;
693
694        for i in 0..k {
695            result *= (n - i) as f64;
696            result /= (i + 1) as f64;
697        }
698
699        Ok(result)
700    }
701}
702
703/// Numerically stable factorial computation
704#[allow(dead_code)]
705pub fn factorial_stable(n: u64) -> CoreResult<f64> {
706    if n == 0 || n == 1 {
707        return Ok(1.0);
708    }
709
710    // Use Stirling's approximation for large n
711    if n > 170 {
712        // For n > 170, n! overflows f64, so use log-space
713        let n_f64 = n as f64;
714        let log_factorial =
715            n_f64 * n_f64.ln() - n_f64 + 0.5 * (2.0 * std::f64::consts::PI * n_f64).ln();
716        return Ok(log_factorial.exp());
717    }
718
719    // Direct computation with overflow check
720    let mut result = 1.0;
721    for i in 2..=n {
722        result *= i as f64;
723        if !result.is_finite() {
724            return Err(CoreError::ComputationError(ErrorContext::new(
725                "Factorial overflow",
726            )));
727        }
728    }
729
730    Ok(result)
731}
732
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use std::f64::consts::PI;

    // Runs every value through a fresh Kahan accumulator and returns the total.
    fn kahan_total(values: &[f64]) -> f64 {
        let mut acc = KahanSum::new();
        for &v in values {
            acc.add(v);
        }
        acc.sum()
    }

    #[test]
    fn test_kahan_sum() {
        // One large term followed by tiny ones that naive summation would erode.
        assert_relative_eq!(
            kahan_total(&[1.0, 1e-8, 1e-8, 1e-8, 1e-8]),
            1.0 + 4e-8,
            epsilon = 1e-15
        );

        // Tiny terms that cancel in pairs must leave the total at 1.0.
        assert_relative_eq!(
            kahan_total(&[1.0, 1e-16, -1e-16, 1e-16, -1e-16]),
            1.0,
            epsilon = 1e-15
        );

        // Ten thousand increments of 0.01.
        assert_relative_eq!(kahan_total(&[0.01; 10_000]), 100.0, epsilon = 1e-10);
    }

    #[test]
    fn test_neumaier_sum() {
        // A small term sandwiched between two huge cancelling ones must survive.
        let total = neumaier_sum(&[1e20_f64, 1.0, -1e20]);
        assert_relative_eq!(total, 1.0, epsilon = 1e-10);
    }

    #[test]
    fn test_pairwise_sum() {
        let values: Vec<f64> = (0_i32..1000).map(|i| 0.1 + 0.001 * f64::from(i)).collect();
        let reference: f64 = values.iter().sum();
        assert_relative_eq!(pairwise_sum(&values), reference, epsilon = 1e-10);
    }

    #[test]
    fn test_welford_variance() {
        let mut welford = WelfordVariance::new();
        for v in (1_i32..=5).map(f64::from) {
            welford.add(v);
        }

        let mean = welford.mean().expect("Mean should be available");
        assert_relative_eq!(mean, 3.0, epsilon = 1e-10);

        let variance = welford.variance().expect("Variance should be available");
        assert_relative_eq!(variance, 2.5, epsilon = 1e-10);

        let std_dev = welford.std_dev().expect("Std dev should be available");
        assert_relative_eq!(std_dev, 2.5_f64.sqrt(), epsilon = 1e-10);
    }

    #[test]
    fn test_log_sum_exp() {
        // Three equal large exponents: exp(1000) alone would overflow.
        let lse = log_sum_exp(&[1000.0_f64; 3]);
        assert_relative_eq!(lse, 1000.0 + 3.0_f64.ln(), epsilon = 1e-10);

        // The log of an empty sum is -inf.
        let empty: [f64; 0] = [];
        assert!(log_sum_exp(&empty).is_infinite());
    }

    #[test]
    fn test_stable_softmax() {
        let probabilities = stable_softmax(&[1000.0_f64; 3]);
        for &p in &probabilities {
            assert_relative_eq!(p, 1.0 / 3.0, epsilon = 1e-10);
        }

        let total: f64 = probabilities.iter().sum();
        assert_relative_eq!(total, 1.0, epsilon = 1e-10);
    }

    #[test]
    fn test_hypot_stable() {
        // x² + y² would overflow for these magnitudes.
        let big = 1e200_f64;
        assert_relative_eq!(hypot_stable(big, big), 2.0_f64.sqrt() * big, epsilon = 1e-10);

        // One zero component.
        assert_eq!(hypot_stable(0.0, 5.0), 5.0);
        assert_eq!(hypot_stable(3.0, 0.0), 3.0);
    }

    #[test]
    fn test_sigmoid_stable() {
        // Saturates towards 1 for large positive input.
        assert_relative_eq!(sigmoid_stable(100.0), 1.0, epsilon = 1e-10);

        // Large negative input: tiny but strictly positive.
        let tiny = sigmoid_stable(-100.0);
        assert!(tiny > 0.0 && tiny < 1e-40);

        // Midpoint.
        assert_relative_eq!(sigmoid_stable(0.0), 0.5, epsilon = 1e-10);
    }

    #[test]
    fn test_reduce_angle() {
        let cases = [
            (3.0 * PI, -PI),
            (5.0 * PI, -PI),
            (-3.0 * PI, -PI),
            (2.0 * PI, 0.0),
            (-2.0 * PI, 0.0),
            (7.0 * PI, -PI),
            (-7.0 * PI, -PI),
            // Already inside the target range.
            (PI / 2.0, PI / 2.0),
        ];
        for &(angle, expected) in &cases {
            assert_relative_eq!(reduce_angle(angle), expected, epsilon = 1e-10);
        }
    }

    #[test]
    fn test_stable_norm_2() {
        // Without scaling, squares of 1e200 overflow and squares of 1e-200 underflow.
        for &magnitude in &[1e200_f64, 1e-200] {
            let norm = stable_norm_2(&[magnitude; 3]);
            assert_relative_eq!(norm, 3.0_f64.sqrt() * magnitude, epsilon = 1e-10);
        }
    }

    #[test]
    fn test_binomial_stable() {
        // Small values plus the edge cases k = 0, k = n and k > n.
        let cases = [
            (5, 2, 10.0),
            (10, 3, 120.0),
            (5, 0, 1.0),
            (5, 5, 1.0),
            (5, 6, 0.0),
        ];
        for &(n, k, expected) in &cases {
            assert_eq!(
                binomial_stable(n, k).expect("Binomial coefficient should succeed"),
                expected
            );
        }

        // Large arguments must stay finite and positive.
        let large =
            binomial_stable(100, 50).expect("Binomial coefficient should handle large values");
        assert!(large.is_finite() && large > 0.0);
    }

    #[test]
    fn test_factorial_stable() {
        assert_eq!(
            factorial_stable(0).expect("Factorial of 0 should succeed"),
            1.0
        );
        assert_eq!(
            factorial_stable(1).expect("Factorial of 1 should succeed"),
            1.0
        );
        assert_eq!(
            factorial_stable(5).expect("Factorial of 5 should succeed"),
            120.0
        );

        let ten = factorial_stable(10).expect("Factorial of 10 should succeed");
        assert_relative_eq!(ten, 3628800.0, epsilon = 1e-10);
    }
}