Skip to main content

wickra_core/indicators/
rolling_correlation.rs

1//! Rolling Pearson correlation of the period-over-period *returns* of two series.
2
3use std::collections::VecDeque;
4
5use crate::error::{Error, Result};
6use crate::traits::Indicator;
7
8/// Rolling correlation of the **returns** of two synchronised series.
9///
10/// Where [`crate::PearsonCorrelation`] correlates the raw *levels* `(x, y)`,
11/// this indicator first differences each channel into a one-step return and
12/// correlates those returns over the trailing window:
13///
14/// ```text
15/// rxₜ = xₜ − xₜ₋₁          ryₜ = yₜ − yₜ₋₁
16/// corr = cov(rx, ry) / √(var(rx) · var(ry))
17/// ```
18///
19/// Return correlation is the quantity that matters for hedging and portfolio
20/// risk: two assets can trend together (high level correlation) while their
21/// day-to-day moves are nearly independent (low return correlation). The output
22/// is in `[−1, +1]`; a flat return channel makes the ratio undefined and the
23/// indicator reports `0` rather than `NaN`. The value is clamped to `[−1, +1]`
24/// to absorb tiny floating-point overshoots near the boundaries.
25///
26/// Each `update` is O(1): the five running sums (`Σrx`, `Σry`, `Σrx²`, `Σry²`,
27/// `Σrxry`) are maintained as the window of returns slides. The first level in
28/// each channel produces no return, so a `period`-pair correlation needs
29/// `period + 1` updates of warmup.
30///
31/// # Example
32///
33/// ```
34/// use wickra_core::{Indicator, RollingCorrelation};
35///
36/// let mut rc = RollingCorrelation::new(10).unwrap();
37/// let mut last = None;
38/// for i in 0..40 {
39///     // A varying path where y always moves with x ⇒ return correlation +1.
40///     let x = (f64::from(i) * 0.5).sin() * 10.0;
41///     last = rc.update((x, 2.0 * x));
42/// }
43/// assert!((last.unwrap() - 1.0).abs() < 1e-9);
44/// ```
45#[derive(Debug, Clone)]
46pub struct RollingCorrelation {
47    period: usize,
48    prev: Option<(f64, f64)>,
49    window: VecDeque<(f64, f64)>,
50    sum_x: f64,
51    sum_y: f64,
52    sum_xx: f64,
53    sum_yy: f64,
54    sum_xy: f64,
55}
56
57impl RollingCorrelation {
58    /// Construct a new rolling return-correlation.
59    ///
60    /// # Errors
61    /// Returns [`Error::InvalidPeriod`] if `period < 2` — correlation is
62    /// undefined for fewer than two return pairs.
63    pub fn new(period: usize) -> Result<Self> {
64        if period < 2 {
65            return Err(Error::InvalidPeriod {
66                message: "rolling correlation needs period >= 2",
67            });
68        }
69        Ok(Self {
70            period,
71            prev: None,
72            window: VecDeque::with_capacity(period),
73            sum_x: 0.0,
74            sum_y: 0.0,
75            sum_xx: 0.0,
76            sum_yy: 0.0,
77            sum_xy: 0.0,
78        })
79    }
80
81    /// Configured window of returns.
82    pub const fn period(&self) -> usize {
83        self.period
84    }
85}
86
87impl Indicator for RollingCorrelation {
88    type Input = (f64, f64);
89    type Output = f64;
90
91    fn update(&mut self, input: (f64, f64)) -> Option<f64> {
92        let (x, y) = input;
93        if !x.is_finite() || !y.is_finite() {
94            return None;
95        }
96        let Some((px, py)) = self.prev else {
97            // First level in each channel: store it, no return yet.
98            self.prev = Some((x, y));
99            return None;
100        };
101        self.prev = Some((x, y));
102        let (rx, ry) = (x - px, y - py);
103        if self.window.len() == self.period {
104            let (ox, oy) = self.window.pop_front().expect("non-empty");
105            self.sum_x -= ox;
106            self.sum_y -= oy;
107            self.sum_xx -= ox * ox;
108            self.sum_yy -= oy * oy;
109            self.sum_xy -= ox * oy;
110        }
111        self.window.push_back((rx, ry));
112        self.sum_x += rx;
113        self.sum_y += ry;
114        self.sum_xx += rx * rx;
115        self.sum_yy += ry * ry;
116        self.sum_xy += rx * ry;
117        if self.window.len() < self.period {
118            return None;
119        }
120        let n = self.period as f64;
121        let mean_x = self.sum_x / n;
122        let mean_y = self.sum_y / n;
123        let var_x = (self.sum_xx / n - mean_x * mean_x).max(0.0);
124        let var_y = (self.sum_yy / n - mean_y * mean_y).max(0.0);
125        let cov = self.sum_xy / n - mean_x * mean_y;
126        let denom = (var_x * var_y).sqrt();
127        if denom == 0.0 {
128            // At least one return channel is flat: correlation is undefined.
129            return Some(0.0);
130        }
131        Some((cov / denom).clamp(-1.0, 1.0))
132    }
133
134    fn reset(&mut self) {
135        self.prev = None;
136        self.window.clear();
137        self.sum_x = 0.0;
138        self.sum_y = 0.0;
139        self.sum_xx = 0.0;
140        self.sum_yy = 0.0;
141        self.sum_xy = 0.0;
142    }
143
144    fn warmup_period(&self) -> usize {
145        self.period + 1
146    }
147
148    fn is_ready(&self) -> bool {
149        self.window.len() == self.period
150    }
151
152    fn name(&self) -> &'static str {
153        "RollingCorrelation"
154    }
155}
156
#[cfg(test)]
mod tests {
    use super::*;
    use crate::traits::BatchExt;
    use approx::assert_relative_eq;

    /// Feed `pairs` through a fresh indicator of `period` and return the last
    /// value it emitted. Panics if it never produced one.
    fn final_value(period: usize, pairs: &[(f64, f64)]) -> f64 {
        RollingCorrelation::new(period)
            .unwrap()
            .batch(pairs)
            .into_iter()
            .flatten()
            .last()
            .unwrap()
    }

    /// Twenty samples of the sine path `x = 10·sin(0.5·i)`, each paired with
    /// `map(x)`.
    fn sine_pairs(map: impl Fn(f64) -> f64) -> Vec<(f64, f64)> {
        (0..20)
            .map(|i| (f64::from(i) * 0.5).sin() * 10.0)
            .map(|x| (x, map(x)))
            .collect()
    }

    #[test]
    fn rejects_period_below_two() {
        for bad in [0, 1] {
            assert!(RollingCorrelation::new(bad).is_err());
        }
        assert!(RollingCorrelation::new(2).is_ok());
    }

    #[test]
    fn accessors_and_metadata() {
        let indicator = RollingCorrelation::new(14).unwrap();
        assert_eq!(indicator.period(), 14);
        assert_eq!(indicator.warmup_period(), 15);
        assert_eq!(indicator.name(), "RollingCorrelation");
        assert!(!indicator.is_ready());
    }

    #[test]
    fn warmup_needs_period_plus_one() {
        let mut indicator = RollingCorrelation::new(3).unwrap();
        // The seed level, then one and two returns: not enough for a 3-return window.
        for tick in [(1.0, 1.0), (2.0, 3.0), (3.0, 5.0)] {
            assert_eq!(indicator.update(tick), None);
        }
        // The third return fills the window.
        assert!(indicator.update((4.0, 7.0)).is_some());
        assert!(indicator.is_ready());
    }

    #[test]
    fn comoving_returns_are_plus_one() {
        // y always moves by twice x's move ⇒ perfectly correlated returns.
        let pairs = sine_pairs(|x| 2.0 * x + 100.0);
        assert_relative_eq!(final_value(8, &pairs), 1.0, epsilon = 1e-9);
    }

    #[test]
    fn opposing_returns_are_minus_one() {
        let pairs = sine_pairs(|x| -1.5 * x + 50.0);
        assert_relative_eq!(final_value(8, &pairs), -1.0, epsilon = 1e-9);
    }

    #[test]
    fn flat_return_channel_yields_zero() {
        // A constant y has all-zero returns ⇒ undefined correlation ⇒ 0.
        let pairs: Vec<(f64, f64)> = (0..20).map(|i| (f64::from(i), 7.0)).collect();
        assert_relative_eq!(final_value(6, &pairs), 0.0, epsilon = 1e-12);
    }

    #[test]
    fn output_in_range() {
        let pairs: Vec<(f64, f64)> = (0..80)
            .map(f64::from)
            .map(|t| (100.0 + t.sin() * 5.0, 50.0 + (t * 0.3).cos() * 3.0))
            .collect();
        let outputs = RollingCorrelation::new(20).unwrap().batch(&pairs);
        assert!(outputs
            .into_iter()
            .flatten()
            .all(|v| (-1.0..=1.0).contains(&v)));
    }

    #[test]
    fn reset_clears_state() {
        let mut indicator = RollingCorrelation::new(4).unwrap();
        indicator.batch(&[(1.0, 2.0), (2.0, 4.0), (3.0, 6.0), (4.0, 8.0), (5.0, 10.0)]);
        assert!(indicator.is_ready());
        indicator.reset();
        assert!(!indicator.is_ready());
        assert_eq!(indicator.update((1.0, 1.0)), None);
    }

    #[test]
    fn batch_equals_streaming() {
        let pairs: Vec<(f64, f64)> = (0..60)
            .map(f64::from)
            .map(|t| (t.sin(), (t * 0.5).cos()))
            .collect();
        let from_batch = RollingCorrelation::new(14).unwrap().batch(&pairs);
        let mut live = RollingCorrelation::new(14).unwrap();
        let from_stream: Vec<_> = pairs.iter().map(|&p| live.update(p)).collect();
        assert_eq!(from_batch, from_stream);
    }

    #[test]
    fn non_finite_input_returns_none() {
        let mut indicator = RollingCorrelation::new(2).unwrap();
        // Non-finite ticks are rejected and do not seed the previous level.
        for junk in [(f64::NAN, 1.0), (1.0, f64::INFINITY)] {
            assert_eq!(indicator.update(junk), None);
        }
        // The first finite tick seeds the level; two more returns fill the window.
        for tick in [(1.0, 1.0), (2.0, 3.0)] {
            assert_eq!(indicator.update(tick), None);
        }
        assert!(indicator.update((3.0, 5.0)).is_some());
    }
}