scouter_profile/profile/
stats.rs

1use ndarray::prelude::*;
2use ndarray_stats::CorrelationExt;
3use num_traits::Float;
4use num_traits::FromPrimitive;
5use std::collections::HashMap;
6// compute_feature_correlations computes the correlation between features in a 2D array
7//
8// # Arguments
9//
10// * `data` - A 2D array of data
11// * `features` - A vector of feature names
12//
13// # Returns
14//
15// A HashMap of feature names to a HashMap of other feature names to the correlation value
16pub fn compute_feature_correlations<F>(
17    data: &ArrayView2<F>,
18    features: &[String],
19) -> HashMap<String, HashMap<String, f32>>
20where
21    F: Float + FromPrimitive + 'static,
22{
23    let mut feature_correlations: HashMap<String, HashMap<String, f32>> = HashMap::new();
24    let correlations = data.t().pearson_correlation().unwrap();
25
26    features.iter().enumerate().for_each(|(i, feature)| {
27        let mut feature_correlation: HashMap<String, f32> = HashMap::new();
28        features.iter().enumerate().for_each(|(j, other_feature)| {
29            if i != j {
30                let value = correlations[[i, j]].to_f32().unwrap();
31                // extract the correlation value
32                feature_correlation.insert(other_feature.clone(), value);
33            }
34        });
35        feature_correlations.insert(feature.clone(), feature_correlation);
36    });
37
38    feature_correlations
39}
40
#[cfg(test)]
mod tests {

    use super::*;
    use ndarray::stack;
    use ndarray_rand::rand::thread_rng;
    use ndarray_rand::rand_distr::{Distribution, Normal};

    /// Draws two random arrays of length `size` whose expected Pearson correlation
    /// is `correlation`.
    ///
    /// `x` is sampled from a standard normal distribution and `y` is built as
    /// `correlation * x + sqrt(1 - correlation^2) * noise`, where `noise` is a fresh
    /// standard normal sample. `correlation` must lie in `[-1, 1]`, otherwise the
    /// square root is NaN.
    fn generate_correlated_arrays(size: usize, correlation: f64) -> (Array1<f64>, Array1<f64>) {
        let mut rng = thread_rng();
        let normal = Normal::new(0.0, 1.0).unwrap();

        let x: Array1<f64> = Array1::from_iter((0..size).map(|_| normal.sample(&mut rng)));

        let y: Array1<f64> = Array1::from_iter((0..size).map(|i| {
            correlation * x[i] + (1.0 - correlation.powi(2)).sqrt() * normal.sample(&mut rng)
        }));

        (x, y)
    }

    #[test]
    fn test_correlation_2d_stats() {
        // generate first set
        let (x1, y1) = generate_correlated_arrays(20000, 0.75);

        // generate second set
        let (x2, y2) = generate_correlated_arrays(20000, 0.33);

        // generate third set (negatively correlated)
        let (x3, y3) = generate_correlated_arrays(20000, -0.80);

        // combine the six arrays as columns, one column per feature
        let data = stack![Axis(1), x1, y1, x2, y2, x3, y3];
        let features = vec![
            "x1".to_string(),
            "y1".to_string(),
            "x2".to_string(),
            "y2".to_string(),
            "x3".to_string(),
            "y3".to_string(),
        ];

        let correlations = compute_feature_correlations(&data.view(), &features);

        // every feature gets an entry holding all the *other* features, never itself
        assert_eq!(correlations.len(), features.len());
        for feature in &features {
            let others = &correlations[feature];
            assert_eq!(others.len(), features.len() - 1);
            assert!(!others.contains_key(feature));
        }

        // the samples are random, so compare against the target with a tolerance
        assert!((correlations["x1"]["y1"] - 0.75).abs() < 0.1);
        assert!((correlations["x2"]["y2"] - 0.33).abs() < 0.1);
        assert!((correlations["x3"]["y3"] + 0.80).abs() < 0.1);
    }
}