scouter_profile/profile/
stats.rs1use ndarray::prelude::*;
2use ndarray_stats::CorrelationExt;
3use num_traits::Float;
4use num_traits::FromPrimitive;
5use std::collections::HashMap;
6pub fn compute_feature_correlations<F>(
17 data: &ArrayView2<F>,
18 features: &[String],
19) -> HashMap<String, HashMap<String, f32>>
20where
21 F: Float + FromPrimitive + 'static,
22{
23 let mut feature_correlations: HashMap<String, HashMap<String, f32>> = HashMap::new();
24 let correlations = data.t().pearson_correlation().unwrap();
25
26 features.iter().enumerate().for_each(|(i, feature)| {
27 let mut feature_correlation: HashMap<String, f32> = HashMap::new();
28 features.iter().enumerate().for_each(|(j, other_feature)| {
29 if i != j {
30 let value = correlations[[i, j]].to_f32().unwrap();
31 feature_correlation.insert(other_feature.clone(), value);
33 }
34 });
35 feature_correlations.insert(feature.clone(), feature_correlation);
36 });
37
38 feature_correlations
39}
40
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::stack;
    use ndarray_rand::rand::rngs::StdRng;
    use ndarray_rand::rand::{Rng, SeedableRng};
    use ndarray_rand::rand_distr::{Distribution, Normal};

    /// Draws `size` samples of two standard-normal variables whose population Pearson
    /// correlation is `correlation` (expected to lie in `[-1, 1]`).
    ///
    /// Uses `y = rho * x + sqrt(1 - rho^2) * eps` with independent `x, eps ~ N(0, 1)`.
    fn generate_correlated_arrays<R: Rng>(
        rng: &mut R,
        size: usize,
        correlation: f64,
    ) -> (Array1<f64>, Array1<f64>) {
        let normal = Normal::new(0.0, 1.0).unwrap();
        let noise_scale = (1.0 - correlation.powi(2)).sqrt();

        let x: Array1<f64> = Array1::from_iter((0..size).map(|_| normal.sample(rng)));
        let y: Array1<f64> = Array1::from_iter(
            x.iter()
                .map(|&xi| correlation * xi + noise_scale * normal.sample(rng)),
        );

        (x, y)
    }

    /// Six columns forming three correlated pairs, generated from a fixed seed so the
    /// tests are reproducible run to run.
    fn correlated_fixture() -> (Array2<f64>, Vec<String>) {
        let mut rng = StdRng::seed_from_u64(42);
        let (x1, y1) = generate_correlated_arrays(&mut rng, 20000, 0.75);
        let (x2, y2) = generate_correlated_arrays(&mut rng, 20000, 0.33);
        let (x3, y3) = generate_correlated_arrays(&mut rng, 20000, -0.80);

        let data = stack![Axis(1), x1, y1, x2, y2, x3, y3];
        let features = ["x1", "y1", "x2", "y2", "x3", "y3"]
            .iter()
            .map(|name| name.to_string())
            .collect();

        (data, features)
    }

    #[test]
    fn test_correlation_2d_stats() {
        let (data, features) = correlated_fixture();
        let correlations = compute_feature_correlations(&data.view(), &features);

        assert!((correlations["x1"]["y1"] - 0.75).abs() < 0.1);
        assert!((correlations["x2"]["y2"] - 0.33).abs() < 0.1);
        assert!((correlations["x3"]["y3"] + 0.80).abs() < 0.1);
        // Columns generated independently of each other should be close to uncorrelated.
        assert!(correlations["x1"]["x2"].abs() < 0.1);
        assert!(correlations["y2"]["x3"].abs() < 0.1);
    }

    #[test]
    fn test_correlation_map_shape() {
        let (data, features) = correlated_fixture();
        let correlations = compute_feature_correlations(&data.view(), &features);

        assert_eq!(correlations.len(), features.len());
        for feature in &features {
            let others = &correlations[feature];
            // Self-correlation is omitted; every other feature is present.
            assert_eq!(others.len(), features.len() - 1);
            assert!(!others.contains_key(feature));
            for (other, value) in others {
                // Pearson correlation is symmetric (up to floating-point rounding).
                assert!((value - correlations[other][feature]).abs() < 1e-5);
            }
        }
    }

    #[test]
    fn test_no_features_yields_empty_map() {
        let data = Array2::<f64>::zeros((3, 2));
        assert!(compute_feature_correlations(&data.view(), &[]).is_empty());
    }
}