// scirs2_core/numeric/stability_toolkit.rs

//! # Numeric Stability Toolkit
//!
//! Extended numeric stability utilities that complement the core `stability` module.
//!
//! This module provides:
//! - **Error helpers**: `relative_error`, `absolute_error`
//! - **Compensated summation**: Kahan (`compensated_sum`) and Neumaier (`compensated_sum_array`) variants
//! - **Stable activations**: `softmax_array`, `sigmoid_array` (ndarray-based)
//! - **Condition estimation**: `condition_number_1d` for 1-D ratio analysis
//! - **Numerical differentiation**: `numerical_gradient` with forward/backward/central modes
//! - **Gradient checking**: `check_gradient` to compare analytical vs numerical gradients
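//!
//! A minimal usage sketch (doc-test marked `ignore`, since the public
//! re-export path is assumed here rather than taken from this file):
//!
//! ```ignore
//! use scirs2_core::numeric::stability_toolkit::{compensated_sum, relative_error};
//!
//! // Relative error of an approximation against a reference value.
//! let err = relative_error(1.01_f64, 1.0);
//! assert!((err - 0.01).abs() < 1e-10);
//!
//! // Kahan-compensated sum of many small addends.
//! let vals = vec![0.01_f64; 10_000];
//! assert!((compensated_sum(&vals) - 100.0).abs() < 1e-10);
//! ```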

use crate::error::{CoreError, CoreResult, ErrorContext};
use ::ndarray::{Array1, ArrayView1};
use num_traits::{Float, FromPrimitive};
use std::fmt::{Debug, Display};

// ---------------------------------------------------------------------------
// Error measurement helpers
// ---------------------------------------------------------------------------

/// Absolute error between two values: |a - b|.
pub fn absolute_error<T: Float>(a: T, b: T) -> T {
    (a - b).abs()
}

/// Relative error between `computed` and `reference`: |computed - reference| / |reference|.
///
/// Returns `T::infinity()` when `reference` is zero and `computed` is non-zero.
/// Returns `T::zero()` when both are zero.
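///
/// # Examples
///
/// A small illustration (marked `ignore`; the import path is assumed):
///
/// ```ignore
/// use scirs2_core::numeric::stability_toolkit::relative_error;
///
/// assert!((relative_error(1.01_f64, 1.0) - 0.01).abs() < 1e-10);
/// assert!(relative_error(1.0_f64, 0.0).is_infinite());
/// ```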
pub fn relative_error<T: Float>(computed: T, reference: T) -> T {
    let diff = (computed - reference).abs();
    let denom = reference.abs();
    if denom.is_zero() {
        if diff.is_zero() {
            T::zero()
        } else {
            T::infinity()
        }
    } else {
        diff / denom
    }
}

/// Element-wise relative errors between two arrays.
/// Returns `Err` if the arrays have different lengths.
pub fn relative_errors<T: Float + Display>(
    computed: &ArrayView1<T>,
    reference: &ArrayView1<T>,
) -> CoreResult<Array1<T>> {
    if computed.len() != reference.len() {
        return Err(CoreError::ShapeError(ErrorContext::new(format!(
            "Array length mismatch: computed has {} elements, reference has {}",
            computed.len(),
            reference.len()
        ))));
    }
    let out: Vec<T> = computed
        .iter()
        .zip(reference.iter())
        .map(|(&c, &r)| relative_error(c, r))
        .collect();
    Ok(Array1::from_vec(out))
}

/// Maximum relative error across all elements.
pub fn max_relative_error<T: Float + Display>(
    computed: &ArrayView1<T>,
    reference: &ArrayView1<T>,
) -> CoreResult<T> {
    let errs = relative_errors(computed, reference)?;
    Ok(errs
        .iter()
        .copied()
        .fold(T::zero(), |acc, e| if e > acc { e } else { acc }))
}

// ---------------------------------------------------------------------------
// Compensated summation (convenience)
// ---------------------------------------------------------------------------

/// Compute a compensated (Kahan) sum of a slice.
///
/// This is a standalone implementation of Kahan's algorithm, mirroring the
/// accumulation scheme used by the core `stability` module.
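///
/// # Examples
///
/// Illustrative only (marked `ignore`; the import path is assumed):
///
/// ```ignore
/// use scirs2_core::numeric::stability_toolkit::compensated_sum;
///
/// // 10,000 copies of 0.01 sum to 100.0 under compensation.
/// let vals = vec![0.01_f64; 10_000];
/// assert!((compensated_sum(&vals) - 100.0).abs() < 1e-10);
/// ```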
pub fn compensated_sum<T: Float>(values: &[T]) -> T {
    let mut sum = T::zero();
    let mut compensation = T::zero();
    for &val in values {
        // Apply the running compensation, then recover the low-order bits
        // lost when the corrected addend is folded into the sum.
        let y = val - compensation;
        let t = sum + y;
        compensation = (t - sum) - y;
        sum = t;
    }
    sum
}

/// Compute a compensated (Neumaier) sum of an ndarray view.
///
/// Uses Neumaier's improvement over Kahan summation: when the addend
/// is larger in magnitude than the running sum the compensation tracks
/// the smaller value, giving correct results even for inputs like
/// `[1e20, 1.0, -1e20]`.
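///
/// # Examples
///
/// A sketch of the case from the description above (marked `ignore`; paths assumed):
///
/// ```ignore
/// use ndarray::array;
/// use scirs2_core::numeric::stability_toolkit::compensated_sum_array;
///
/// let arr = array![1e20_f64, 1.0, -1e20];
/// // Plain Kahan summation loses the 1.0 here; Neumaier recovers it.
/// assert!((compensated_sum_array(&arr.view()) - 1.0).abs() < 1e-5);
/// ```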
pub fn compensated_sum_array<T: Float>(values: &ArrayView1<T>) -> T {
    if values.is_empty() {
        return T::zero();
    }
    let mut sum = values[0];
    let mut compensation = T::zero();
    for &val in values.iter().skip(1) {
        let t = sum + val;
        // Track the low-order bits of whichever operand was smaller in
        // magnitude; the larger one dominates the rounded result `t`.
        if sum.abs() >= val.abs() {
            compensation = compensation + ((sum - t) + val);
        } else {
            compensation = compensation + ((val - t) + sum);
        }
        sum = t;
    }
    // The compensation is applied once, at the end (Neumaier's variant).
    sum + compensation
}

// ---------------------------------------------------------------------------
// Pairwise summation (array-friendly)
// ---------------------------------------------------------------------------

/// Pairwise summation for an ndarray view.
///
/// Recursively splits the array and sums halves, achieving O(log n) error growth.
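///
/// # Examples
///
/// Illustrative sketch (marked `ignore`; paths assumed):
///
/// ```ignore
/// use ndarray::Array1;
/// use scirs2_core::numeric::stability_toolkit::pairwise_sum_array;
///
/// let arr: Array1<f64> = Array1::from_vec((0..500).map(|i| 0.1 + 0.001 * i as f64).collect());
/// let total = pairwise_sum_array(&arr.view());
/// assert!((total - arr.iter().sum::<f64>()).abs() < 1e-8);
/// ```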
pub fn pairwise_sum_array<T: Float>(values: &ArrayView1<T>) -> T {
    // Below this length, fall back to compensated summation rather than
    // paying the recursion overhead.
    const THRESHOLD: usize = 128;
    let n = values.len();
    match n {
        0 => T::zero(),
        1 => values[0],
        _ if n <= THRESHOLD => compensated_sum_array(values),
        _ => {
            let mid = n / 2;
            let left = values.slice(::ndarray::s![..mid]);
            let right = values.slice(::ndarray::s![mid..]);
            pairwise_sum_array(&left) + pairwise_sum_array(&right)
        }
    }
}

// ---------------------------------------------------------------------------
// Stable softmax / sigmoid for arrays
// ---------------------------------------------------------------------------

/// Numerically stable softmax for an ndarray 1-D array.
///
/// Subtracts the maximum before exponentiation to prevent overflow.
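///
/// # Examples
///
/// Large inputs that would overflow a naive softmax (marked `ignore`; paths assumed):
///
/// ```ignore
/// use ndarray::array;
/// use scirs2_core::numeric::stability_toolkit::softmax_array;
///
/// let vals = array![1000.0_f64, 1000.0, 1000.0];
/// let sm = softmax_array(&vals.view());
/// // exp(1000) overflows f64, but the max-shifted form yields uniform weights.
/// assert!((sm[0] - 1.0 / 3.0).abs() < 1e-10);
/// ```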
pub fn softmax_array<T: Float>(values: &ArrayView1<T>) -> Array1<T> {
    if values.is_empty() {
        return Array1::from_vec(vec![]);
    }
    let max_val = values
        .iter()
        .copied()
        .fold(T::neg_infinity(), |a, b| a.max(b));

    let exp_vals: Vec<T> = values.iter().map(|&v| (v - max_val).exp()).collect();
    let sum: T = exp_vals.iter().copied().fold(T::zero(), |a, b| a + b);
    Array1::from_vec(exp_vals.into_iter().map(|e| e / sum).collect())
}

/// Numerically stable sigmoid for an ndarray 1-D array.
///
/// Branches on the sign of each input so `exp` is only evaluated on
/// non-positive arguments, avoiding overflow for large |x|.
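///
/// # Examples
///
/// Illustrative sketch (marked `ignore`; paths assumed):
///
/// ```ignore
/// use ndarray::array;
/// use scirs2_core::numeric::stability_toolkit::sigmoid_array;
///
/// let sig = sigmoid_array(&array![0.0_f64, 100.0, -100.0].view());
/// assert!((sig[0] - 0.5).abs() < 1e-10);
/// assert!((sig[1] - 1.0).abs() < 1e-10); // saturates without overflow
/// ```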
pub fn sigmoid_array<T: Float>(values: &ArrayView1<T>) -> Array1<T> {
    let out: Vec<T> = values
        .iter()
        .map(|&x| {
            if x >= T::zero() {
                let exp_neg = (-x).exp();
                T::one() / (T::one() + exp_neg)
            } else {
                let exp_x = x.exp();
                exp_x / (T::one() + exp_x)
            }
        })
        .collect();
    Array1::from_vec(out)
}

/// Numerically stable log-sum-exp for an ndarray 1-D array.
///
/// Computes `ln(sum(exp(v_i)))` by shifting by the maximum value, so the
/// exponentials cannot overflow.
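///
/// # Examples
///
/// Illustrative sketch (marked `ignore`; paths assumed):
///
/// ```ignore
/// use ndarray::array;
/// use scirs2_core::numeric::stability_toolkit::log_sum_exp_array;
///
/// let vals = array![1000.0_f64, 1000.0, 1000.0];
/// let lse = log_sum_exp_array(&vals.view());
/// assert!((lse - (1000.0 + 3.0_f64.ln())).abs() < 1e-10);
/// ```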
pub fn log_sum_exp_array<T: Float>(values: &ArrayView1<T>) -> T {
    if values.is_empty() {
        return T::neg_infinity();
    }
    let max_val = values
        .iter()
        .copied()
        .fold(T::neg_infinity(), |a, b| a.max(b));
    // If every element is -inf, the shifted sum would be NaN; return -inf directly.
    if max_val.is_infinite() && max_val < T::zero() {
        return max_val;
    }
    let sum: T = values
        .iter()
        .map(|&v| (v - max_val).exp())
        .fold(T::zero(), |a, b| a + b);
    max_val + sum.ln()
}

// ---------------------------------------------------------------------------
// Condition number estimation for 1-D (ratio of max/min absolute values)
// ---------------------------------------------------------------------------

/// Estimate a "condition number" for a 1-D array as max(|x|) / min_nonzero(|x|).
///
/// This gives insight into the dynamic range and potential for cancellation.
/// Returns `Err` if the array is empty or all zeros.
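///
/// # Examples
///
/// Illustrative sketch (marked `ignore`; paths assumed):
///
/// ```ignore
/// use ndarray::array;
/// use scirs2_core::numeric::stability_toolkit::condition_number_1d;
///
/// let vals = array![1.0_f64, 10.0, 100.0];
/// // max(|x|) / min_nonzero(|x|) = 100 / 1
/// assert!((condition_number_1d(&vals.view()).unwrap() - 100.0).abs() < 1e-10);
/// ```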
pub fn condition_number_1d<T: Float + Display>(values: &ArrayView1<T>) -> CoreResult<T> {
    if values.is_empty() {
        return Err(CoreError::ValueError(ErrorContext::new(
            "Cannot compute condition number of empty array",
        )));
    }
    let mut max_abs = T::zero();
    let mut min_abs = T::infinity();
    for &v in values.iter() {
        let a = v.abs();
        if a > max_abs {
            max_abs = a;
        }
        if a > T::zero() && a < min_abs {
            min_abs = a;
        }
    }
    if max_abs.is_zero() {
        return Err(CoreError::ValueError(ErrorContext::new(
            "All elements are zero; condition number is undefined",
        )));
    }
    // Reachable only when no finite non-zero element exists
    // (e.g. every non-zero entry is infinite).
    if min_abs.is_infinite() {
        return Err(CoreError::ValueError(ErrorContext::new(
            "No non-zero elements found for condition number",
        )));
    }
    Ok(max_abs / min_abs)
}

// ---------------------------------------------------------------------------
// Numerical differentiation
// ---------------------------------------------------------------------------

/// Mode of finite difference approximation.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DifferenceMode {
    /// f'(x) ~ (f(x+h) - f(x)) / h
    Forward,
    /// f'(x) ~ (f(x) - f(x-h)) / h
    Backward,
    /// f'(x) ~ (f(x+h) - f(x-h)) / (2h)  -- O(h^2) accuracy
    Central,
}

/// Compute the numerical gradient of a scalar function at point `x`.
///
/// * `f` -- function from R^n -> R, taking a slice and returning a scalar.
/// * `x` -- the point at which to evaluate the gradient.
/// * `h` -- step size (e.g. 1e-5).
/// * `mode` -- finite difference mode.
///
/// Returns an `Array1<T>` of the same length as `x`.
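///
/// # Examples
///
/// Illustrative sketch (marked `ignore`; paths assumed):
///
/// ```ignore
/// use scirs2_core::numeric::stability_toolkit::{numerical_gradient, DifferenceMode};
///
/// // f(x) = x0^2 + x1^2 has gradient [2*x0, 2*x1].
/// let f = |x: &[f64]| x[0] * x[0] + x[1] * x[1];
/// let grad = numerical_gradient(&f, &[3.0, 4.0], 1e-5, DifferenceMode::Central).unwrap();
/// assert!((grad[0] - 6.0).abs() < 1e-8);
/// assert!((grad[1] - 8.0).abs() < 1e-8);
/// ```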
pub fn numerical_gradient<T, F>(f: &F, x: &[T], h: T, mode: DifferenceMode) -> CoreResult<Array1<T>>
where
    T: Float + FromPrimitive + Debug,
    F: Fn(&[T]) -> T,
{
    let n = x.len();
    let two = T::from_f64(2.0).ok_or_else(|| {
        CoreError::TypeError(ErrorContext::new("Failed to convert 2.0 to target type"))
    })?;

    let mut grad = Array1::zeros(n);
    let mut x_perturbed = x.to_vec();

    for i in 0..n {
        let original = x_perturbed[i];

        match mode {
            DifferenceMode::Forward => {
                x_perturbed[i] = original + h;
                let f_plus = f(&x_perturbed);
                x_perturbed[i] = original;
                let f_0 = f(&x_perturbed);
                grad[i] = (f_plus - f_0) / h;
            }
            DifferenceMode::Backward => {
                x_perturbed[i] = original;
                let f_0 = f(&x_perturbed);
                x_perturbed[i] = original - h;
                let f_minus = f(&x_perturbed);
                grad[i] = (f_0 - f_minus) / h;
            }
            DifferenceMode::Central => {
                x_perturbed[i] = original + h;
                let f_plus = f(&x_perturbed);
                x_perturbed[i] = original - h;
                let f_minus = f(&x_perturbed);
                grad[i] = (f_plus - f_minus) / (two * h);
            }
        }

        // Restore the original coordinate before perturbing the next one.
        x_perturbed[i] = original;
    }

    Ok(grad)
}

// ---------------------------------------------------------------------------
// Gradient checking
// ---------------------------------------------------------------------------

/// Result of a gradient check.
#[derive(Debug, Clone)]
pub struct GradientCheckResult<T: Float> {
    /// Element-wise relative errors between analytical and numerical gradients.
    pub relative_errors: Array1<T>,
    /// Maximum relative error.
    pub max_relative_error: T,
    /// Mean relative error.
    pub mean_relative_error: T,
    /// Whether the check passed (max_relative_error < tolerance).
    pub passed: bool,
}

impl<T: Float + Display> std::fmt::Display for GradientCheckResult<T> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(
            f,
            "GradientCheck(passed={}, max_rel_err={}, mean_rel_err={})",
            self.passed, self.max_relative_error, self.mean_relative_error,
        )
    }
}

/// Check an analytical gradient against a numerical gradient.
///
/// The numerical reference is computed with central differences.
///
/// * `f` -- scalar function R^n -> R.
/// * `analytical_grad` -- the gradient your code computes at `x`.
/// * `x` -- the point at which the gradient was computed.
/// * `h` -- finite difference step size (e.g. 1e-5).
/// * `tolerance` -- maximum allowed relative error per component.
///
/// Returns a `GradientCheckResult` with element-wise details.
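///
/// # Examples
///
/// Illustrative sketch (marked `ignore`; paths assumed):
///
/// ```ignore
/// use ndarray::array;
/// use scirs2_core::numeric::stability_toolkit::check_gradient;
///
/// // f(x) = x0^2 + 2*x1^2 has gradient [2*x0, 4*x1].
/// let f = |x: &[f64]| x[0] * x[0] + 2.0 * x[1] * x[1];
/// let analytical = array![6.0, 16.0]; // evaluated at x = [3, 4]
/// let result = check_gradient(&f, &analytical.view(), &[3.0, 4.0], 1e-5, 1e-4).unwrap();
/// assert!(result.passed);
/// ```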
pub fn check_gradient<T, F>(
    f: &F,
    analytical_grad: &ArrayView1<T>,
    x: &[T],
    h: T,
    tolerance: T,
) -> CoreResult<GradientCheckResult<T>>
where
    T: Float + FromPrimitive + Debug + Display,
    F: Fn(&[T]) -> T,
{
    if analytical_grad.len() != x.len() {
        return Err(CoreError::ShapeError(ErrorContext::new(format!(
            "Analytical gradient length {} does not match input dimension {}",
            analytical_grad.len(),
            x.len()
        ))));
    }

    let numerical = numerical_gradient(f, x, h, DifferenceMode::Central)?;
    let rel_errs = relative_errors(analytical_grad, &numerical.view())?;
    let max_err = rel_errs
        .iter()
        .copied()
        .fold(T::zero(), |a, b| if b > a { b } else { a });

    let n_f = T::from_usize(rel_errs.len().max(1)).unwrap_or(T::one());
    let sum_err = rel_errs.iter().copied().fold(T::zero(), |a, b| a + b);
    let mean_err = sum_err / n_f;

    Ok(GradientCheckResult {
        relative_errors: rel_errs,
        max_relative_error: max_err,
        mean_relative_error: mean_err,
        passed: max_err < tolerance,
    })
}

// ---------------------------------------------------------------------------
// Tests
// ---------------------------------------------------------------------------

#[cfg(test)]
mod tests {
    use super::*;
    use ::ndarray::array;

    #[test]
    fn test_absolute_error() {
        assert!((absolute_error(3.0_f64, 3.0) - 0.0).abs() < 1e-15);
        assert!((absolute_error(3.5_f64, 3.0) - 0.5).abs() < 1e-15);
    }

    #[test]
    fn test_relative_error_basic() {
        assert!((relative_error(1.01_f64, 1.0) - 0.01).abs() < 1e-10);
        assert!((relative_error(0.0_f64, 0.0) - 0.0).abs() < 1e-15);
        assert!(relative_error(1.0_f64, 0.0).is_infinite());
    }

    #[test]
    fn test_relative_errors_array() {
        let computed = array![1.01, 2.02, 3.03];
        let reference = array![1.0, 2.0, 3.0];
        let errs = relative_errors(&computed.view(), &reference.view()).expect("should succeed");
        assert_eq!(errs.len(), 3);
        for &e in errs.iter() {
            assert!(e < 0.02);
        }
    }

    #[test]
    fn test_relative_errors_mismatch() {
        let a = array![1.0, 2.0];
        let b = array![1.0];
        assert!(relative_errors(&a.view(), &b.view()).is_err());
    }

    #[test]
    fn test_compensated_sum_accuracy() {
        // Many small values that lose precision with naive sum
        let values: Vec<f64> = (0..10_000).map(|_| 0.01).collect();
        let result = compensated_sum(&values);
        assert!((result - 100.0).abs() < 1e-10);
    }

    #[test]
    fn test_compensated_sum_array_view() {
        let arr = array![1e20, 1.0, -1e20];
        let result = compensated_sum_array(&arr.view());
        assert!((result - 1.0).abs() < 1e-5);
    }

    #[test]
    fn test_pairwise_sum_array() {
        let arr: Array1<f64> = Array1::from_vec((0..500).map(|i| 0.1 + 0.001 * i as f64).collect());
        let pw = pairwise_sum_array(&arr.view());
        let naive: f64 = arr.iter().sum();
        assert!((pw - naive).abs() < 1e-8);
    }

    #[test]
    fn test_softmax_array() {
        let vals = array![1000.0_f64, 1000.0, 1000.0];
        let sm = softmax_array(&vals.view());
        for &p in sm.iter() {
            assert!((p - 1.0 / 3.0).abs() < 1e-10);
        }
        let total: f64 = sm.iter().sum();
        assert!((total - 1.0).abs() < 1e-10);
    }

    #[test]
    fn test_softmax_empty() {
        let vals: Array1<f64> = Array1::from_vec(vec![]);
        let sm = softmax_array(&vals.view());
        assert!(sm.is_empty());
    }

    #[test]
    fn test_sigmoid_array() {
        let vals = array![0.0_f64, 100.0, -100.0];
        let sig = sigmoid_array(&vals.view());
        assert!((sig[0] - 0.5).abs() < 1e-10);
        assert!((sig[1] - 1.0).abs() < 1e-10);
        assert!(sig[2] < 1e-30);
    }

    #[test]
    fn test_log_sum_exp_array() {
        let vals = array![1000.0_f64, 1000.0, 1000.0];
        let lse = log_sum_exp_array(&vals.view());
        let expected = 1000.0 + 3.0_f64.ln();
        assert!((lse - expected).abs() < 1e-10);
    }

    #[test]
    fn test_log_sum_exp_array_empty() {
        let vals: Array1<f64> = Array1::from_vec(vec![]);
        let lse = log_sum_exp_array(&vals.view());
        assert!(lse.is_infinite() && lse < 0.0);
    }

    #[test]
    fn test_condition_number_1d() {
        let vals = array![1.0_f64, 10.0, 100.0];
        let cn = condition_number_1d(&vals.view()).expect("should succeed");
        assert!((cn - 100.0).abs() < 1e-10);
    }

    #[test]
    fn test_condition_number_1d_all_zeros() {
        let vals = array![0.0_f64, 0.0];
        assert!(condition_number_1d(&vals.view()).is_err());
    }

    #[test]
    fn test_condition_number_1d_empty() {
        let vals: Array1<f64> = Array1::from_vec(vec![]);
        assert!(condition_number_1d(&vals.view()).is_err());
    }

    #[test]
    fn test_numerical_gradient_forward() {
        // f(x) = x0^2 + x1^2, grad = [2*x0, 2*x1]
        let f = |x: &[f64]| x[0] * x[0] + x[1] * x[1];
        let x = [3.0, 4.0];
        let grad =
            numerical_gradient(&f, &x, 1e-7, DifferenceMode::Forward).expect("should succeed");
        assert!((grad[0] - 6.0).abs() < 1e-4);
        assert!((grad[1] - 8.0).abs() < 1e-4);
    }

    #[test]
    fn test_numerical_gradient_backward() {
        let f = |x: &[f64]| x[0] * x[0] + x[1] * x[1];
        let x = [3.0, 4.0];
        let grad =
            numerical_gradient(&f, &x, 1e-7, DifferenceMode::Backward).expect("should succeed");
        assert!((grad[0] - 6.0).abs() < 1e-4);
        assert!((grad[1] - 8.0).abs() < 1e-4);
    }

    #[test]
    fn test_numerical_gradient_central() {
        let f = |x: &[f64]| x[0] * x[0] + x[1] * x[1];
        let x = [3.0, 4.0];
        let grad =
            numerical_gradient(&f, &x, 1e-5, DifferenceMode::Central).expect("should succeed");
        // Central difference should be more accurate
        assert!((grad[0] - 6.0).abs() < 1e-8);
        assert!((grad[1] - 8.0).abs() < 1e-8);
    }

    #[test]
    fn test_numerical_gradient_sin() {
        // f(x) = sin(x0), grad = [cos(x0)]
        let f = |x: &[f64]| x[0].sin();
        let x = [std::f64::consts::PI / 4.0];
        let grad =
            numerical_gradient(&f, &x, 1e-7, DifferenceMode::Central).expect("should succeed");
        let expected = (std::f64::consts::PI / 4.0).cos();
        assert!((grad[0] - expected).abs() < 1e-8);
    }

    #[test]
    fn test_check_gradient_passes() {
        let f = |x: &[f64]| x[0] * x[0] + 2.0 * x[1] * x[1];
        let x = [3.0, 4.0];
        let analytical = array![6.0, 16.0]; // [2*x0, 4*x1]
        let result =
            check_gradient(&f, &analytical.view(), &x, 1e-5, 1e-4).expect("should succeed");
        assert!(result.passed, "gradient check should pass");
        assert!(result.max_relative_error < 1e-4);
    }

    #[test]
    fn test_check_gradient_fails() {
        let f = |x: &[f64]| x[0] * x[0] + 2.0 * x[1] * x[1];
        let x = [3.0, 4.0];
        let bad_analytical = array![100.0, 200.0]; // wrong gradient
        let result =
            check_gradient(&f, &bad_analytical.view(), &x, 1e-5, 1e-4).expect("should succeed");
        assert!(
            !result.passed,
            "gradient check should fail with wrong gradient"
        );
    }

    #[test]
    fn test_check_gradient_dimension_mismatch() {
        let f = |x: &[f64]| x[0];
        let x = [1.0, 2.0];
        let analytical = array![1.0]; // wrong dimension
        assert!(check_gradient(&f, &analytical.view(), &x, 1e-5, 1e-4).is_err());
    }

    #[test]
    fn test_max_relative_error() {
        let a = array![1.1_f64, 2.2, 3.3];
        let b = array![1.0, 2.0, 3.0];
        let mre = max_relative_error(&a.view(), &b.view()).expect("should succeed");
        assert!(mre > 0.09 && mre < 0.11);
    }

    #[test]
    fn test_compensated_sum_empty() {
        let empty: Vec<f64> = vec![];
        assert!((compensated_sum(&empty) - 0.0).abs() < 1e-15);
    }

    #[test]
    fn test_gradient_check_display() {
        let f = |x: &[f64]| x[0] * x[0];
        let x = [2.0];
        let analytical = array![4.0];
        let result =
            check_gradient(&f, &analytical.view(), &x, 1e-5, 1e-4).expect("should succeed");
        let display = format!("{result}");
        assert!(display.contains("GradientCheck"));
    }
}
607}