/// Reasons a covariance or correlation computation can be rejected.
///
/// Implements [`std::fmt::Display`] and [`std::error::Error`] so it can be
/// propagated with `?` into `Box<dyn std::error::Error>` and printed for users.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum CorrelationError {
    /// At least one of the two input slices contains no values.
    EmptyInput,
    /// The two input slices do not have the same number of values.
    MismatchedLengths,
    /// The correlation coefficient is undefined because its denominator is
    /// zero, i.e. an input has no spread around its mean.
    ZeroVariance,
}

impl std::fmt::Display for CorrelationError {
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Exhaustive match on purpose: adding a variant must force a message.
        let message = match self {
            Self::EmptyInput => "input slices must not be empty",
            Self::MismatchedLengths => "input slices must have the same length",
            Self::ZeroVariance => "correlation is undefined for inputs with zero variance",
        };
        formatter.write_str(message)
    }
}

impl std::error::Error for CorrelationError {}
24
25pub fn covariance(x_values: &[f64], y_values: &[f64]) -> Result<f64, CorrelationError> {
26 validate_pair(x_values, y_values)?;
27
28 let x_mean = mean(x_values);
29 let y_mean = mean(y_values);
30 let covariance_sum: f64 = x_values
31 .iter()
32 .zip(y_values.iter())
33 .map(|(x_value, y_value)| (x_value - x_mean) * (y_value - y_mean))
34 .sum();
35
36 Ok(covariance_sum / x_values.len() as f64)
37}
38
39pub fn pearson_correlation(x_values: &[f64], y_values: &[f64]) -> Result<f64, CorrelationError> {
40 validate_pair(x_values, y_values)?;
41
42 let x_mean = mean(x_values);
43 let y_mean = mean(y_values);
44
45 let numerator: f64 = x_values
46 .iter()
47 .zip(y_values.iter())
48 .map(|(x_value, y_value)| (x_value - x_mean) * (y_value - y_mean))
49 .sum();
50 let x_squared_deviation_sum: f64 = x_values.iter().map(|value| (value - x_mean).powi(2)).sum();
51 let y_squared_deviation_sum: f64 = y_values.iter().map(|value| (value - y_mean).powi(2)).sum();
52 let denominator = x_squared_deviation_sum.sqrt() * y_squared_deviation_sum.sqrt();
53
54 if denominator == 0.0 {
55 return Err(CorrelationError::ZeroVariance);
56 }
57
58 Ok(numerator / denominator)
59}
60
61fn validate_pair(x_values: &[f64], y_values: &[f64]) -> Result<(), CorrelationError> {
62 if x_values.is_empty() || y_values.is_empty() {
63 return Err(CorrelationError::EmptyInput);
64 }
65
66 if x_values.len() != y_values.len() {
67 return Err(CorrelationError::MismatchedLengths);
68 }
69
70 Ok(())
71}
72
/// Returns the arithmetic mean of `values`.
///
/// An empty slice yields NaN (zero divided by zero).
fn mean(values: &[f64]) -> f64 {
    let total: f64 = values.iter().sum();
    let count = values.len() as f64;
    total / count
}
76
#[cfg(test)]
mod tests {
    use super::{covariance, pearson_correlation, CorrelationError};

    /// Largest absolute difference at which two floats still count as equal.
    const TOLERANCE: f64 = 1.0e-10;

    /// Panics unless `left` and `right` differ by less than `TOLERANCE`.
    fn approx_eq(left: f64, right: f64) {
        let difference = (left - right).abs();
        assert!(difference < TOLERANCE, "left={left}, right={right}");
    }

    #[test]
    fn computes_covariance_and_positive_correlation() {
        // y is exactly 2 * x, so the samples are perfectly correlated.
        let xs = [1.0, 2.0, 3.0, 4.0, 5.0];
        let ys = [2.0, 4.0, 6.0, 8.0, 10.0];

        let observed_covariance = covariance(&xs, &ys).unwrap();
        approx_eq(observed_covariance, 4.0);

        let observed_correlation = pearson_correlation(&xs, &ys).unwrap();
        approx_eq(observed_correlation, 1.0);
    }

    #[test]
    fn computes_negative_correlation() {
        // y decreases by exactly as much as x increases.
        let ascending = [1.0, 2.0, 3.0];
        let descending = [3.0, 2.0, 1.0];

        let observed_correlation = pearson_correlation(&ascending, &descending).unwrap();
        approx_eq(observed_correlation, -1.0);
    }

    #[test]
    fn handles_single_value_covariance() {
        let single_x = [5.0];
        let single_y = [8.0];

        // One pair has no spread: covariance is zero, correlation is undefined.
        approx_eq(covariance(&single_x, &single_y).unwrap(), 0.0);

        let correlation = pearson_correlation(&single_x, &single_y);
        assert_eq!(correlation, Err(CorrelationError::ZeroVariance));
    }

    #[test]
    fn rejects_invalid_inputs() {
        let no_values: [f64; 0] = [];
        let empty_result = covariance(&no_values, &no_values);
        assert_eq!(empty_result, Err(CorrelationError::EmptyInput));

        let mismatched_result = covariance(&[1.0, 2.0], &[1.0]);
        assert_eq!(mismatched_result, Err(CorrelationError::MismatchedLengths));

        // The first sample is constant, so it has no variance.
        let constant_result = pearson_correlation(&[2.0, 2.0], &[1.0, 2.0]);
        assert_eq!(constant_result, Err(CorrelationError::ZeroVariance));
    }
}