sklears_utils/
stats.rs

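//! Statistical utility functions for two-dimensional arrays: per-axis mean,
//! variance and standard deviation, plus the covariance matrix.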
2
3use scirs2_core::ndarray::{Array1, Array2, Axis, NdFloat};
4use scirs2_core::numeric::FromPrimitive;
5
6/// Compute mean along an axis
7pub fn mean_axis<'a, D>(array: &'a Array2<D>, axis: Axis) -> Array1<D>
8where
9    D: NdFloat + FromPrimitive + 'a,
10{
11    array.mean_axis(axis).unwrap()
12}
13
14/// Compute variance along an axis
15pub fn var_axis<'a, D>(array: &'a Array2<D>, axis: Axis, ddof: usize) -> Array1<D>
16where
17    D: NdFloat + FromPrimitive + 'a,
18{
19    let mean = array.mean_axis(axis).unwrap();
20    let n = array.len_of(axis);
21
22    if axis == Axis(0) {
23        // Variance along rows (for each column)
24        let mut var = Array1::zeros(array.ncols());
25        for j in 0..array.ncols() {
26            let col = array.column(j);
27            let m = mean[j];
28            let sum_sq: D = col.mapv(|x| (x - m).powi(2)).sum();
29            var[j] = sum_sq / D::from(n - ddof).unwrap();
30        }
31        var
32    } else {
33        // Variance along columns (for each row)
34        let mut var = Array1::zeros(array.nrows());
35        for i in 0..array.nrows() {
36            let row = array.row(i);
37            let m = mean[i];
38            let sum_sq: D = row.mapv(|x| (x - m).powi(2)).sum();
39            var[i] = sum_sq / D::from(n - ddof).unwrap();
40        }
41        var
42    }
43}
44
45/// Compute standard deviation along an axis
46pub fn std_axis<'a, D>(array: &'a Array2<D>, axis: Axis, ddof: usize) -> Array1<D>
47where
48    D: NdFloat + FromPrimitive + 'a,
49{
50    var_axis(array, axis, ddof).mapv(|v| v.sqrt())
51}
52
53/// Compute covariance matrix
54pub fn covariance<D>(x: &Array2<D>, ddof: usize) -> Array2<D>
55where
56    D: NdFloat + FromPrimitive,
57{
58    let n_samples = x.nrows();
59
60    // Center the data
61    let mean = x.mean_axis(Axis(0)).unwrap();
62    let centered = x - &mean;
63
64    // Compute covariance
65    let cov = centered.t().dot(&centered) / D::from(n_samples - ddof).unwrap();
66    cov
67}
68
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use scirs2_core::ndarray::array;

    /// 3x3 fixture: every column is `[k, k + 3, k + 6]`, every row is
    /// `[k, k + 1, k + 2]`.
    fn sample() -> Array2<f64> {
        array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]
    }

    /// Assert that `actual` matches `expected` element-wise within 1e-10.
    fn assert_all_close(actual: &Array1<f64>, expected: &[f64]) {
        assert_eq!(actual.len(), expected.len());
        for (a, e) in actual.iter().zip(expected) {
            assert_abs_diff_eq!(*a, *e, epsilon = 1e-10);
        }
    }

    #[test]
    fn test_mean_axis() {
        let x = sample();

        // Reducing along axis 0 gives one mean per column.
        assert_all_close(&mean_axis(&x, Axis(0)), &[4.0, 5.0, 6.0]);
        // Reducing along axis 1 gives one mean per row.
        assert_all_close(&mean_axis(&x, Axis(1)), &[2.0, 5.0, 8.0]);
    }

    #[test]
    fn test_var_axis() {
        let x = sample();

        // Columns deviate by (-3, 0, 3): sum of squares 18, divided by 3.
        assert_all_close(&var_axis(&x, Axis(0), 0), &[6.0, 6.0, 6.0]);
    }

    #[test]
    fn test_var_axis_rows() {
        let x = sample();

        // Rows deviate by (-1, 0, 1): sum of squares 2, divided by 3.
        let expected = 2.0 / 3.0;
        assert_all_close(&var_axis(&x, Axis(1), 0), &[expected, expected, expected]);
    }

    #[test]
    fn test_var_axis_ddof() {
        let x = sample();

        // With ddof = 1 the divisor becomes n - 1 = 2.
        assert_all_close(&var_axis(&x, Axis(0), 1), &[9.0, 9.0, 9.0]);
        assert_all_close(&var_axis(&x, Axis(1), 1), &[1.0, 1.0, 1.0]);
    }

    #[test]
    fn test_std_axis() {
        let x = sample();

        assert_all_close(&std_axis(&x, Axis(0), 1), &[3.0, 3.0, 3.0]);
        assert_all_close(&std_axis(&x, Axis(1), 1), &[1.0, 1.0, 1.0]);
    }

    #[test]
    fn test_covariance() {
        let x = sample();

        // Every centered column is (-3, 0, 3), so every entry is 18 / 2 = 9.
        let cov = covariance(&x, 1);
        assert_eq!(cov.dim(), (3, 3));
        for value in cov.iter() {
            assert_abs_diff_eq!(*value, 9.0, epsilon = 1e-10);
        }
    }
}