Skip to main content

wickra_core/indicators/
rolling_correlation.rs

1//! Rolling Pearson correlation of the period-over-period *returns* of two series.
2
3use std::collections::VecDeque;
4
5use crate::error::{Error, Result};
6use crate::traits::Indicator;
7
/// Trailing-window Pearson correlation between the one-step **returns** of
/// two synchronised series.
///
/// [`crate::PearsonCorrelation`] operates on the raw *levels* `(x, y)`. This
/// indicator instead differences each channel into a one-step return first and
/// then correlates those returns across the most recent `period` pairs:
///
/// ```text
/// rxₜ = xₜ − xₜ₋₁          ryₜ = yₜ − yₜ₋₁
/// corr = cov(rx, ry) / √(var(rx) · var(ry))
/// ```
///
/// For hedging and portfolio risk it is the return correlation that counts:
/// two assets may trend together (high level correlation) even though their
/// day-to-day moves are almost independent (low return correlation).
///
/// # Output
///
/// The result lies in `[−1, +1]`. When a return channel is flat the ratio is
/// undefined, and the indicator reports `0` instead of `NaN`. The ratio is
/// also clamped to `[−1, +1]` so that tiny floating-point overshoots near the
/// boundaries never leak out.
///
/// # Cost and warmup
///
/// Every `update` is O(1): five running sums (`Σrx`, `Σry`, `Σrx²`, `Σry²`,
/// `Σrxry`) are adjusted as the window of returns slides. Because the first
/// level of each channel yields no return, a correlation over `period` pairs
/// requires `period + 1` updates before the first value is produced.
///
/// # Example
///
/// ```
/// use wickra_core::{Indicator, RollingCorrelation};
///
/// let mut rc = RollingCorrelation::new(10).unwrap();
/// let mut last = None;
/// for i in 0..40 {
///     // A varying path where y always moves with x ⇒ return correlation +1.
///     let x = (f64::from(i) * 0.5).sin() * 10.0;
///     last = rc.update((x, 2.0 * x));
/// }
/// assert!((last.unwrap() - 1.0).abs() < 1e-9);
/// ```
#[derive(Debug, Clone)]
pub struct RollingCorrelation {
    /// Number of return pairs the correlation is computed over.
    period: usize,
    /// Most recently seen level pair `(x, y)`; `None` until the first update.
    prev: Option<(f64, f64)>,
    /// Return pairs `(rx, ry)` currently inside the window, oldest first.
    window: VecDeque<(f64, f64)>,
    /// Running `Σrx` over the window.
    sum_x: f64,
    /// Running `Σry` over the window.
    sum_y: f64,
    /// Running `Σrx²` over the window.
    sum_xx: f64,
    /// Running `Σry²` over the window.
    sum_yy: f64,
    /// Running `Σrx·ry` over the window.
    sum_xy: f64,
}
56
57impl RollingCorrelation {
58    /// Construct a new rolling return-correlation.
59    ///
60    /// # Errors
61    /// Returns [`Error::InvalidPeriod`] if `period < 2` — correlation is
62    /// undefined for fewer than two return pairs.
63    pub fn new(period: usize) -> Result<Self> {
64        if period < 2 {
65            return Err(Error::InvalidPeriod {
66                message: "rolling correlation needs period >= 2",
67            });
68        }
69        Ok(Self {
70            period,
71            prev: None,
72            window: VecDeque::with_capacity(period),
73            sum_x: 0.0,
74            sum_y: 0.0,
75            sum_xx: 0.0,
76            sum_yy: 0.0,
77            sum_xy: 0.0,
78        })
79    }
80
81    /// Configured window of returns.
82    pub const fn period(&self) -> usize {
83        self.period
84    }
85}
86
87impl Indicator for RollingCorrelation {
88    type Input = (f64, f64);
89    type Output = f64;
90
91    fn update(&mut self, input: (f64, f64)) -> Option<f64> {
92        let (x, y) = input;
93        let Some((px, py)) = self.prev else {
94            // First level in each channel: store it, no return yet.
95            self.prev = Some((x, y));
96            return None;
97        };
98        self.prev = Some((x, y));
99        let (rx, ry) = (x - px, y - py);
100        if self.window.len() == self.period {
101            let (ox, oy) = self.window.pop_front().expect("non-empty");
102            self.sum_x -= ox;
103            self.sum_y -= oy;
104            self.sum_xx -= ox * ox;
105            self.sum_yy -= oy * oy;
106            self.sum_xy -= ox * oy;
107        }
108        self.window.push_back((rx, ry));
109        self.sum_x += rx;
110        self.sum_y += ry;
111        self.sum_xx += rx * rx;
112        self.sum_yy += ry * ry;
113        self.sum_xy += rx * ry;
114        if self.window.len() < self.period {
115            return None;
116        }
117        let n = self.period as f64;
118        let mean_x = self.sum_x / n;
119        let mean_y = self.sum_y / n;
120        let var_x = (self.sum_xx / n - mean_x * mean_x).max(0.0);
121        let var_y = (self.sum_yy / n - mean_y * mean_y).max(0.0);
122        let cov = self.sum_xy / n - mean_x * mean_y;
123        let denom = (var_x * var_y).sqrt();
124        if denom == 0.0 {
125            // At least one return channel is flat: correlation is undefined.
126            return Some(0.0);
127        }
128        Some((cov / denom).clamp(-1.0, 1.0))
129    }
130
131    fn reset(&mut self) {
132        self.prev = None;
133        self.window.clear();
134        self.sum_x = 0.0;
135        self.sum_y = 0.0;
136        self.sum_xx = 0.0;
137        self.sum_yy = 0.0;
138        self.sum_xy = 0.0;
139    }
140
141    fn warmup_period(&self) -> usize {
142        self.period + 1
143    }
144
145    fn is_ready(&self) -> bool {
146        self.window.len() == self.period
147    }
148
149    fn name(&self) -> &'static str {
150        "RollingCorrelation"
151    }
152}
153
#[cfg(test)]
mod tests {
    use super::*;
    use crate::traits::BatchExt;
    use approx::assert_relative_eq;

    /// Runs a fresh indicator of the given `period` over `pairs` in batch mode
    /// and returns the last value it emitted.
    fn final_value(period: usize, pairs: &[(f64, f64)]) -> f64 {
        RollingCorrelation::new(period)
            .unwrap()
            .batch(pairs)
            .into_iter()
            .flatten()
            .last()
            .unwrap()
    }

    /// Twenty samples of a sine path for `x`, with `y` derived from `x` by `to_y`.
    fn sine_pairs(to_y: impl Fn(f64) -> f64) -> Vec<(f64, f64)> {
        (0..20)
            .map(|i| {
                let x = (f64::from(i) * 0.5).sin() * 10.0;
                (x, to_y(x))
            })
            .collect()
    }

    #[test]
    fn rejects_period_below_two() {
        for too_small in [0, 1] {
            assert!(RollingCorrelation::new(too_small).is_err());
        }
        assert!(RollingCorrelation::new(2).is_ok());
    }

    #[test]
    fn accessors_and_metadata() {
        let rc = RollingCorrelation::new(14).unwrap();
        assert_eq!(rc.period(), 14);
        assert_eq!(rc.warmup_period(), 15);
        assert_eq!(rc.name(), "RollingCorrelation");
        assert!(!rc.is_ready());
    }

    #[test]
    fn warmup_needs_period_plus_one() {
        let mut rc = RollingCorrelation::new(3).unwrap();
        // The seed level plus the first two returns are all still warmup.
        for level in [(1.0, 1.0), (2.0, 3.0), (3.0, 5.0)] {
            assert_eq!(rc.update(level), None);
        }
        // The third return fills the window.
        assert!(rc.update((4.0, 7.0)).is_some());
        assert!(rc.is_ready());
    }

    #[test]
    fn comoving_returns_are_plus_one() {
        // Every move in y is twice the move in x ⇒ perfectly correlated returns.
        let pairs = sine_pairs(|x| 2.0 * x + 100.0);
        assert_relative_eq!(final_value(8, &pairs), 1.0, epsilon = 1e-9);
    }

    #[test]
    fn opposing_returns_are_minus_one() {
        // Every move in y opposes the move in x ⇒ perfectly anti-correlated.
        let pairs = sine_pairs(|x| -1.5 * x + 50.0);
        assert_relative_eq!(final_value(8, &pairs), -1.0, epsilon = 1e-9);
    }

    #[test]
    fn flat_return_channel_yields_zero() {
        // Constant y ⇒ all of its returns are zero ⇒ undefined ratio ⇒ 0.
        let pairs: Vec<(f64, f64)> = (0..20).map(|i| (f64::from(i), 7.0)).collect();
        assert_relative_eq!(final_value(6, &pairs), 0.0, epsilon = 1e-12);
    }

    #[test]
    fn output_in_range() {
        let pairs: Vec<(f64, f64)> = (0..80)
            .map(f64::from)
            .map(|t| (100.0 + t.sin() * 5.0, 50.0 + (t * 0.3).cos() * 3.0))
            .collect();
        let mut rc = RollingCorrelation::new(20).unwrap();
        let all_bounded = rc
            .batch(&pairs)
            .into_iter()
            .flatten()
            .all(|v| (-1.0..=1.0).contains(&v));
        assert!(all_bounded);
    }

    #[test]
    fn reset_clears_state() {
        let mut rc = RollingCorrelation::new(4).unwrap();
        rc.batch(&[(1.0, 2.0), (2.0, 4.0), (3.0, 6.0), (4.0, 8.0), (5.0, 10.0)]);
        assert!(rc.is_ready());

        rc.reset();
        assert!(!rc.is_ready());
        // After a reset the next level is a seed again and yields nothing.
        assert_eq!(rc.update((1.0, 1.0)), None);
    }

    #[test]
    fn batch_equals_streaming() {
        let pairs: Vec<(f64, f64)> = (0..60)
            .map(f64::from)
            .map(|t| (t.sin(), (t * 0.5).cos()))
            .collect();
        let batch = RollingCorrelation::new(14).unwrap().batch(&pairs);

        let mut rc = RollingCorrelation::new(14).unwrap();
        let streamed: Vec<_> = pairs.iter().copied().map(|p| rc.update(p)).collect();
        assert_eq!(batch, streamed);
    }
}
275}