sklears_utils/
math_utils.rs

1//! Mathematical utility functions for numerical computing
2
3use crate::{UtilsError, UtilsResult};
4use scirs2_core::ndarray::Array1;
5use scirs2_core::numeric::{Float, FromPrimitive};
6use std::cmp::Ordering;
7
/// Mathematical constants
///
/// Re-exports of frequently used `f64` constants plus a handful of
/// float-type-specific tolerance and magnitude bounds.
pub mod constants {
    use std::f64::consts;

    /// Archimedes' constant (π).
    pub const PI: f64 = consts::PI;
    /// Euler's number (e).
    pub const E: f64 = consts::E;
    /// Natural logarithm of 2.
    pub const LN_2: f64 = consts::LN_2;
    /// Natural logarithm of 10.
    pub const LN_10: f64 = consts::LN_10;
    /// Square root of 2.
    pub const SQRT_2: f64 = consts::SQRT_2;
    /// Square root of π (written out as a literal).
    pub const SQRT_PI: f64 = 1.772_453_850_905_516;

    /// Machine epsilon for `f32`.
    pub const EPS_F32: f32 = f32::EPSILON;
    /// Machine epsilon for `f64`.
    pub const EPS_F64: f64 = f64::EPSILON;

    /// A very small positive `f32` value.
    pub const TINY_F32: f32 = 1e-30;
    /// A very small positive `f64` value.
    pub const TINY_F64: f64 = 1e-100;

    /// A very large positive `f32` value.
    pub const HUGE_F32: f32 = 1e30;
    /// A very large positive `f64` value.
    pub const HUGE_F64: f64 = 1e100;
}
23
24/// Numerical precision utilities
25pub struct NumericalPrecision;
26
27impl NumericalPrecision {
28    /// Get machine epsilon for the given float type
29    pub fn epsilon<T: Float>() -> T {
30        T::epsilon()
31    }
32
33    /// Get a small positive value for the given float type
34    pub fn tiny<T: Float>() -> T {
35        T::from(1e-30).unwrap_or_else(|| T::epsilon())
36    }
37
38    /// Get a large positive value for the given float type
39    pub fn huge<T: Float>() -> T {
40        T::from(1e30).unwrap_or_else(|| T::max_value())
41    }
42
43    /// Check if a value is effectively zero (within epsilon tolerance)
44    pub fn is_zero<T: Float>(value: T, eps: Option<T>) -> bool {
45        let tolerance = eps.unwrap_or_else(|| T::epsilon() * T::from(10).unwrap());
46        value.abs() < tolerance
47    }
48
49    /// Check if two values are approximately equal
50    pub fn approx_eq<T: Float>(a: T, b: T, eps: Option<T>) -> bool {
51        let tolerance = eps.unwrap_or_else(|| T::epsilon() * T::from(10).unwrap());
52        (a - b).abs() < tolerance
53    }
54
55    /// Check if two values are relatively equal (considering magnitude)
56    pub fn rel_eq<T: Float>(a: T, b: T, rel_tol: Option<T>) -> bool {
57        let tolerance = rel_tol.unwrap_or_else(|| T::from(1e-9).unwrap());
58        let max_val = a.abs().max(b.abs());
59        if max_val < T::epsilon() {
60            return true; // Both are effectively zero
61        }
62        (a - b).abs() / max_val < tolerance
63    }
64
65    /// Safe comparison that handles floating point precision issues
66    pub fn safe_cmp<T: Float>(a: T, b: T, eps: Option<T>) -> Ordering {
67        if Self::approx_eq(a, b, eps) {
68            Ordering::Equal
69        } else if a < b {
70            Ordering::Less
71        } else {
72            Ordering::Greater
73        }
74    }
75}
76
77/// Overflow and underflow detection
78pub struct OverflowDetection;
79
80impl OverflowDetection {
81    /// Check if value is close to overflow
82    pub fn near_overflow<T: Float>(value: T) -> bool {
83        let max_val = T::max_value();
84        value.abs() > max_val / T::from(1000).unwrap()
85    }
86
87    /// Check if value is close to underflow
88    pub fn near_underflow<T: Float>(value: T) -> bool {
89        let min_val = T::min_positive_value();
90        value.abs() < min_val * T::from(10).unwrap() && !value.is_zero()
91    }
92
93    /// Safe addition that detects overflow
94    pub fn safe_add<T: Float>(a: T, b: T) -> UtilsResult<T> {
95        if Self::near_overflow(a) || Self::near_overflow(b) {
96            return Err(UtilsError::InvalidParameter(
97                "Addition would cause overflow".to_string(),
98            ));
99        }
100        let result = a + b;
101        if !result.is_finite() {
102            return Err(UtilsError::InvalidParameter(
103                "Addition resulted in non-finite value".to_string(),
104            ));
105        }
106        Ok(result)
107    }
108
109    /// Safe multiplication that detects overflow
110    pub fn safe_mul<T: Float>(a: T, b: T) -> UtilsResult<T> {
111        if Self::near_overflow(a) && !NumericalPrecision::is_zero(b, None) {
112            return Err(UtilsError::InvalidParameter(
113                "Multiplication would cause overflow".to_string(),
114            ));
115        }
116        let result = a * b;
117        if !result.is_finite() {
118            return Err(UtilsError::InvalidParameter(
119                "Multiplication resulted in non-finite value".to_string(),
120            ));
121        }
122        Ok(result)
123    }
124
125    /// Safe division that handles division by zero and overflow
126    pub fn safe_div<T: Float>(a: T, b: T) -> UtilsResult<T> {
127        if NumericalPrecision::is_zero(b, None) {
128            return Err(UtilsError::InvalidParameter("Division by zero".to_string()));
129        }
130        if Self::near_underflow(b) && !NumericalPrecision::is_zero(a, None) {
131            return Err(UtilsError::InvalidParameter(
132                "Division would cause overflow".to_string(),
133            ));
134        }
135        let result = a / b;
136        if !result.is_finite() {
137            return Err(UtilsError::InvalidParameter(
138                "Division resulted in non-finite value".to_string(),
139            ));
140        }
141        Ok(result)
142    }
143}
144
145/// Special mathematical functions
146pub struct SpecialFunctions;
147
148impl SpecialFunctions {
149    /// Logistic function (sigmoid)
150    pub fn logistic<T: Float>(x: T) -> T {
151        let one = T::one();
152        one / (one + (-x).exp())
153    }
154
155    /// Log-sum-exp function for numerical stability
156    pub fn logsumexp<T: Float>(x: &[T]) -> T {
157        if x.is_empty() {
158            return T::neg_infinity();
159        }
160
161        let max_val = x.iter().copied().fold(T::neg_infinity(), T::max);
162        if !max_val.is_finite() {
163            return max_val;
164        }
165
166        let sum_exp: T = x
167            .iter()
168            .map(|&val| (val - max_val).exp())
169            .fold(T::zero(), |acc, val| acc + val);
170
171        max_val + sum_exp.ln()
172    }
173
174    /// Softmax function with numerical stability
175    pub fn softmax<T: Float>(x: &[T]) -> Vec<T> {
176        if x.is_empty() {
177            return Vec::new();
178        }
179
180        let max_val = x.iter().copied().fold(T::neg_infinity(), T::max);
181        let exp_vals: Vec<T> = x.iter().map(|&val| (val - max_val).exp()).collect();
182
183        let sum_exp: T = exp_vals
184            .iter()
185            .copied()
186            .fold(T::zero(), |acc, val| acc + val);
187
188        exp_vals.into_iter().map(|val| val / sum_exp).collect()
189    }
190
191    /// Log softmax function for numerical stability
192    pub fn log_softmax<T: Float>(x: &[T]) -> Vec<T> {
193        let log_sum_exp = Self::logsumexp(x);
194        x.iter().map(|&val| val - log_sum_exp).collect()
195    }
196
197    /// Gamma function approximation (simplified for testing)
198    pub fn gamma(x: f64) -> f64 {
199        // For now, use factorial approximation for integer values
200        if x == 1.0 || x == 2.0 {
201            1.0
202        } else if x == 3.0 {
203            2.0
204        } else if x == 4.0 {
205            6.0
206        } else if x > 1.0 {
207            // Γ(x) = (x-1) * Γ(x-1) for x > 1
208            (x - 1.0) * Self::gamma(x - 1.0)
209        } else {
210            // For non-integer values, use a basic approximation
211            1.0 / x // This is a very rough approximation
212        }
213    }
214
215    /// Log gamma function
216    pub fn lgamma(x: f64) -> f64 {
217        Self::gamma(x).ln()
218    }
219
220    /// Incomplete gamma function (simplified implementation)
221    pub fn gamma_inc(a: f64, x: f64) -> f64 {
222        if x < 0.0 || a <= 0.0 {
223            return 0.0;
224        }
225
226        // Use series expansion for small x
227        if x < a + 1.0 {
228            let mut sum = 1.0;
229            let mut term = 1.0;
230            let mut n = 1.0;
231
232            for _ in 0..100 {
233                term *= x / (a + n - 1.0);
234                sum += term;
235                if term.abs() < 1e-15 {
236                    break;
237                }
238                n += 1.0;
239            }
240
241            sum * x.powf(a) * (-x).exp() / Self::gamma(a)
242        } else {
243            // For large x, use continued fraction
244            Self::gamma(a) * (1.0 - Self::gamma_inc_cf(a, x))
245        }
246    }
247
248    /// Incomplete gamma function using continued fraction
249    fn gamma_inc_cf(a: f64, x: f64) -> f64 {
250        let mut b = x + 1.0 - a;
251        let mut c = 1e30;
252        let mut d = 1.0 / b;
253        let mut h = d;
254
255        for i in 1..=100 {
256            let an = -i as f64 * (i as f64 - a);
257            b += 2.0;
258            d = an * d + b;
259            if d.abs() < 1e-30 {
260                d = 1e-30;
261            }
262            c = b + an / c;
263            if c.abs() < 1e-30 {
264                c = 1e-30;
265            }
266            d = 1.0 / d;
267            let del = d * c;
268            h *= del;
269            if (del - 1.0).abs() < 1e-15 {
270                break;
271            }
272        }
273
274        h * x.powf(a) * (-x).exp()
275    }
276
277    /// Beta function
278    pub fn beta(a: f64, b: f64) -> f64 {
279        (Self::gamma(a) * Self::gamma(b)) / Self::gamma(a + b)
280    }
281
282    /// Error function approximation
283    pub fn erf(x: f64) -> f64 {
284        // Approximation with maximum error of 1.5×10^−7
285        const A1: f64 = 0.254829592;
286        const A2: f64 = -0.284496736;
287        const A3: f64 = 1.421413741;
288        const A4: f64 = -1.453152027;
289        const A5: f64 = 1.061405429;
290        const P: f64 = 0.3275911;
291
292        let sign = if x >= 0.0 { 1.0 } else { -1.0 };
293        let x = x.abs();
294
295        let t = 1.0 / (1.0 + P * x);
296        let y = 1.0 - (((((A5 * t + A4) * t) + A3) * t + A2) * t + A1) * t * (-x * x).exp();
297
298        sign * y
299    }
300
301    /// Complementary error function
302    pub fn erfc(x: f64) -> f64 {
303        1.0 - Self::erf(x)
304    }
305}
306
307/// Robust numerical operations for arrays
308pub struct RobustArrayOps;
309
310impl RobustArrayOps {
311    /// Robust sum that handles numerical precision issues
312    pub fn robust_sum<T: Float + FromPrimitive>(arr: &Array1<T>) -> T {
313        // Kahan summation algorithm for improved precision
314        let mut sum = T::zero();
315        let mut c = T::zero(); // Compensation for lost low-order bits
316
317        for &value in arr.iter() {
318            let y = value - c;
319            let t = sum + y;
320            c = (t - sum) - y;
321            sum = t;
322        }
323
324        sum
325    }
326
327    /// Robust mean calculation
328    pub fn robust_mean<T: Float + FromPrimitive>(arr: &Array1<T>) -> UtilsResult<T> {
329        if arr.is_empty() {
330            return Err(UtilsError::EmptyInput);
331        }
332
333        let sum = Self::robust_sum(arr);
334        let n = T::from(arr.len()).unwrap();
335
336        OverflowDetection::safe_div(sum, n)
337    }
338
339    /// Robust variance calculation
340    pub fn robust_variance<T: Float + FromPrimitive>(
341        arr: &Array1<T>,
342        ddof: usize,
343    ) -> UtilsResult<T> {
344        if arr.len() <= ddof {
345            return Err(UtilsError::InsufficientData {
346                min: ddof + 1,
347                actual: arr.len(),
348            });
349        }
350
351        let mean = Self::robust_mean(arr)?;
352        let mut sum_sq = T::zero();
353        let mut c = T::zero(); // Compensation
354
355        for &value in arr.iter() {
356            let diff = value - mean;
357            let sq_diff = diff * diff;
358            let y = sq_diff - c;
359            let t = sum_sq + y;
360            c = (t - sum_sq) - y;
361            sum_sq = t;
362        }
363
364        let n = T::from(arr.len() - ddof).unwrap();
365        OverflowDetection::safe_div(sum_sq, n)
366    }
367
368    /// Robust standard deviation calculation
369    pub fn robust_std<T: Float + FromPrimitive>(arr: &Array1<T>, ddof: usize) -> UtilsResult<T> {
370        let variance = Self::robust_variance(arr, ddof)?;
371        Ok(variance.sqrt())
372    }
373
374    /// Robust dot product
375    pub fn robust_dot<T: Float + FromPrimitive>(a: &Array1<T>, b: &Array1<T>) -> UtilsResult<T> {
376        if a.len() != b.len() {
377            return Err(UtilsError::ShapeMismatch {
378                expected: vec![a.len()],
379                actual: vec![b.len()],
380            });
381        }
382
383        let mut sum = T::zero();
384        let mut c = T::zero(); // Compensation
385
386        for (&x, &y) in a.iter().zip(b.iter()) {
387            let product = OverflowDetection::safe_mul(x, y)?;
388            let corrected = product - c;
389            let temp = sum + corrected;
390            c = (temp - sum) - corrected;
391            sum = temp;
392        }
393
394        Ok(sum)
395    }
396
397    /// Robust norm calculation (Euclidean norm with overflow protection)
398    pub fn robust_norm<T: Float + FromPrimitive>(arr: &Array1<T>) -> UtilsResult<T> {
399        if arr.is_empty() {
400            return Ok(T::zero());
401        }
402
403        // Find the maximum absolute value to scale and prevent overflow
404        let max_abs = arr.iter().map(|&x| x.abs()).fold(T::zero(), T::max);
405
406        if NumericalPrecision::is_zero(max_abs, None) {
407            return Ok(T::zero());
408        }
409
410        let mut sum_sq = T::zero();
411        let mut c = T::zero(); // Compensation
412
413        for &value in arr.iter() {
414            let scaled = OverflowDetection::safe_div(value, max_abs)?;
415            let sq = OverflowDetection::safe_mul(scaled, scaled)?;
416            let y = sq - c;
417            let t = sum_sq + y;
418            c = (t - sum_sq) - y;
419            sum_sq = t;
420        }
421
422        let norm_scaled = sum_sq.sqrt();
423        OverflowDetection::safe_mul(norm_scaled, max_abs)
424    }
425}
426
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use scirs2_core::ndarray::array;

    /// Tolerance-based predicates: zero test, absolute and relative equality.
    #[test]
    fn test_numerical_precision() {
        // Zero detection with the default tolerance.
        assert!(NumericalPrecision::is_zero(1e-16, None));
        assert!(!NumericalPrecision::is_zero(1e-6, None));

        // Absolute comparison.
        assert!(NumericalPrecision::approx_eq(1.0, 1.0 + 1e-15, None));
        assert!(!NumericalPrecision::approx_eq(1.0, 1.1, None));

        // Relative comparison with an explicit tolerance.
        let rel_tol = Some(1e-6);
        assert!(NumericalPrecision::rel_eq(1000.0, 1000.0001, rel_tol));
        assert!(!NumericalPrecision::rel_eq(1000.0, 1001.0, rel_tol));
    }

    /// Guarded arithmetic rejects operands near the representable limits.
    #[test]
    fn test_overflow_detection() {
        let half_max = f64::MAX / 2.0;

        // Addition.
        assert!(OverflowDetection::safe_add(half_max, half_max).is_err());
        assert!(OverflowDetection::safe_add(1.0, 2.0).is_ok());

        // Multiplication.
        assert!(OverflowDetection::safe_mul(half_max, 2.0).is_err());
        assert!(OverflowDetection::safe_mul(2.0, 3.0).is_ok());

        // Division.
        assert!(OverflowDetection::safe_div(1.0, 0.0).is_err());
        assert!(OverflowDetection::safe_div(1.0, f64::MIN_POSITIVE).is_err());
        let quotient = OverflowDetection::safe_div(6.0, 2.0).unwrap();
        assert_relative_eq!(quotient, 3.0);
    }

    /// Spot checks for the special functions.
    #[test]
    fn test_special_functions() {
        // Logistic function.
        assert_relative_eq!(SpecialFunctions::logistic(0.0), 0.5, epsilon = 1e-10);
        assert!(SpecialFunctions::logistic(10.0) > 0.99);
        assert!(SpecialFunctions::logistic(-10.0) < 0.01);

        // Log-sum-exp against the naive formula.
        let inputs = [1.0, 2.0, 3.0];
        let naive = (1.0_f64.exp() + 2.0_f64.exp() + 3.0_f64.exp()).ln();
        let stable = SpecialFunctions::logsumexp(&inputs);
        assert_relative_eq!(stable, naive, epsilon = 1e-10);

        // Softmax probabilities sum to one.
        let probabilities = SpecialFunctions::softmax(&inputs);
        let total: f64 = probabilities.iter().sum();
        assert_relative_eq!(total, 1.0, epsilon = 1e-10);

        // Gamma at small integers: Γ(n) = (n - 1)!.
        for (argument, expected) in [(1.0, 1.0), (2.0, 1.0), (3.0, 2.0), (4.0, 6.0)] {
            assert_relative_eq!(SpecialFunctions::gamma(argument), expected, epsilon = 1e-8);
        }

        // Error function.
        assert_relative_eq!(SpecialFunctions::erf(0.0), 0.0, epsilon = 1e-9);
        assert!(SpecialFunctions::erf(1.0) > 0.8);
        assert!(SpecialFunctions::erf(-1.0) < -0.8);
    }

    /// Array reductions on small, exactly representable inputs.
    #[test]
    fn test_robust_array_ops() {
        let data = array![1.0, 2.0, 3.0, 4.0, 5.0];

        // Sum.
        assert_relative_eq!(RobustArrayOps::robust_sum(&data), 15.0, epsilon = 1e-10);

        // Mean.
        let mean = RobustArrayOps::robust_mean(&data).unwrap();
        assert_relative_eq!(mean, 3.0, epsilon = 1e-10);

        // Sample variance (ddof = 1).
        let variance = RobustArrayOps::robust_variance(&data, 1).unwrap();
        assert_relative_eq!(variance, 2.5, epsilon = 1e-10);

        // Sample standard deviation (ddof = 1).
        let std_dev = RobustArrayOps::robust_std(&data, 1).unwrap();
        assert_relative_eq!(std_dev, 2.5_f64.sqrt(), epsilon = 1e-10);

        // Dot product: 1*4 + 2*5 + 3*6 = 32.
        let lhs = array![1.0, 2.0, 3.0];
        let rhs = array![4.0, 5.0, 6.0];
        let dot = RobustArrayOps::robust_dot(&lhs, &rhs).unwrap();
        assert_relative_eq!(dot, 32.0, epsilon = 1e-10);

        // Euclidean norm: sqrt(1^2 + 2^2 + 3^2).
        let expected_norm = (1.0 + 4.0 + 9.0_f64).sqrt();
        let norm = RobustArrayOps::robust_norm(&lhs).unwrap();
        assert_relative_eq!(norm, expected_norm, epsilon = 1e-10);
    }
}