Skip to main content

use_correlation/
lib.rs

1//! Correlation helpers for `f64` slices.
2//!
3//! The crate provides a population covariance helper and a Pearson correlation
4//! helper for paired slices of equal length.
5//!
6//! # Examples
7//!
8//! ```rust
9//! use use_correlation::{covariance, pearson_correlation};
10//!
11//! let x = [1.0, 2.0, 3.0, 4.0, 5.0];
12//! let y = [2.0, 4.0, 6.0, 8.0, 10.0];
13//!
14//! assert_eq!(covariance(&x, &y).unwrap(), 4.0);
15//! assert!((pearson_correlation(&x, &y).unwrap() - 1.0).abs() < 1.0e-12);
16//! ```
17
/// Errors reported by the covariance and correlation helpers.
///
/// The type implements [`std::fmt::Display`] and [`std::error::Error`], so it
/// can be propagated with `?` into `Box<dyn std::error::Error>` and printed
/// for end users.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CorrelationError {
    /// At least one of the two input slices was empty.
    EmptyInput,
    /// The two input slices did not contain the same number of values.
    MismatchedLengths,
    /// The correlation denominator was zero, which happens when at least one
    /// input has no spread around its mean.
    ZeroVariance,
}

impl std::fmt::Display for CorrelationError {
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Exhaustive match on purpose: adding a variant must force a message.
        let message = match self {
            Self::EmptyInput => "input slices must not be empty",
            Self::MismatchedLengths => "input slices must have the same length",
            Self::ZeroVariance => "correlation is undefined when an input has zero variance",
        };
        formatter.write_str(message)
    }
}

impl std::error::Error for CorrelationError {}
24
25pub fn covariance(x_values: &[f64], y_values: &[f64]) -> Result<f64, CorrelationError> {
26    validate_pair(x_values, y_values)?;
27
28    let x_mean = mean(x_values);
29    let y_mean = mean(y_values);
30    let covariance_sum: f64 = x_values
31        .iter()
32        .zip(y_values.iter())
33        .map(|(x_value, y_value)| (x_value - x_mean) * (y_value - y_mean))
34        .sum();
35
36    Ok(covariance_sum / x_values.len() as f64)
37}
38
39pub fn pearson_correlation(x_values: &[f64], y_values: &[f64]) -> Result<f64, CorrelationError> {
40    validate_pair(x_values, y_values)?;
41
42    let x_mean = mean(x_values);
43    let y_mean = mean(y_values);
44
45    let numerator: f64 = x_values
46        .iter()
47        .zip(y_values.iter())
48        .map(|(x_value, y_value)| (x_value - x_mean) * (y_value - y_mean))
49        .sum();
50    let x_squared_deviation_sum: f64 = x_values.iter().map(|value| (value - x_mean).powi(2)).sum();
51    let y_squared_deviation_sum: f64 = y_values.iter().map(|value| (value - y_mean).powi(2)).sum();
52    let denominator = x_squared_deviation_sum.sqrt() * y_squared_deviation_sum.sqrt();
53
54    if denominator == 0.0 {
55        return Err(CorrelationError::ZeroVariance);
56    }
57
58    Ok(numerator / denominator)
59}
60
61fn validate_pair(x_values: &[f64], y_values: &[f64]) -> Result<(), CorrelationError> {
62    if x_values.is_empty() || y_values.is_empty() {
63        return Err(CorrelationError::EmptyInput);
64    }
65
66    if x_values.len() != y_values.len() {
67        return Err(CorrelationError::MismatchedLengths);
68    }
69
70    Ok(())
71}
72
/// Returns the arithmetic mean of `values`.
///
/// An empty slice yields NaN (zero divided by zero).
fn mean(values: &[f64]) -> f64 {
    let total: f64 = values.iter().sum();
    let count = values.len() as f64;
    total / count
}
76
#[cfg(test)]
mod tests {
    use super::{covariance, pearson_correlation, CorrelationError};

    /// Absolute tolerance for floating-point comparisons in these tests.
    const TOLERANCE: f64 = 1.0e-10;

    /// Panics unless `left` and `right` differ by less than `TOLERANCE`.
    fn approx_eq(left: f64, right: f64) {
        let gap = (left - right).abs();
        assert!(gap < TOLERANCE, "left={left}, right={right}");
    }

    /// A perfectly linear, increasing relationship.
    #[test]
    fn computes_covariance_and_positive_correlation() {
        let x = [1.0, 2.0, 3.0, 4.0, 5.0];
        let y = [2.0, 4.0, 6.0, 8.0, 10.0];

        let population_covariance = covariance(&x, &y).unwrap();
        let coefficient = pearson_correlation(&x, &y).unwrap();

        approx_eq(population_covariance, 4.0);
        approx_eq(coefficient, 1.0);
    }

    /// A perfectly linear, decreasing relationship.
    #[test]
    fn computes_negative_correlation() {
        let x = [1.0, 2.0, 3.0];
        let y = [3.0, 2.0, 1.0];

        let coefficient = pearson_correlation(&x, &y).unwrap();

        approx_eq(coefficient, -1.0);
    }

    /// One pair has zero covariance but no defined correlation.
    #[test]
    fn handles_single_value_covariance() {
        let (x, y) = ([5.0], [8.0]);

        approx_eq(covariance(&x, &y).unwrap(), 0.0);

        let correlation = pearson_correlation(&x, &y);
        assert_eq!(correlation, Err(CorrelationError::ZeroVariance));
    }

    /// Each error variant is reachable through the public API.
    #[test]
    fn rejects_invalid_inputs() {
        let empty = covariance(&[], &[]);
        assert_eq!(empty, Err(CorrelationError::EmptyInput));

        let mismatched = covariance(&[1.0, 2.0], &[1.0]);
        assert_eq!(mismatched, Err(CorrelationError::MismatchedLengths));

        let constant_x = pearson_correlation(&[2.0, 2.0], &[1.0, 2.0]);
        assert_eq!(constant_x, Err(CorrelationError::ZeroVariance));
    }
}
123}