wickra_core/indicators/
rolling_correlation.rs1use std::collections::VecDeque;
4
5use crate::error::{Error, Result};
6use crate::traits::Indicator;
7
/// Rolling Pearson correlation between the first differences of two series.
///
/// Each input is an `(x, y)` observation. The indicator correlates the
/// differences `(x[t] - x[t-1], y[t] - y[t-1])` over the most recent `period`
/// of them (absolute differences, not percentage returns). Window moments are
/// kept as running sums, updated as differences enter and leave the window.
#[derive(Debug, Clone)]
pub struct RollingCorrelation {
    /// Number of difference pairs per correlation window (`new` rejects values below 2).
    period: usize,
    /// Previous `(x, y)` observation; `None` before the first input and after `reset`.
    prev: Option<(f64, f64)>,
    /// Most recent difference pairs, oldest at the front; never longer than `period`.
    window: VecDeque<(f64, f64)>,
    /// Running sum of the x differences currently in `window`.
    sum_x: f64,
    /// Running sum of the y differences currently in `window`.
    sum_y: f64,
    /// Running sum of squared x differences currently in `window`.
    sum_xx: f64,
    /// Running sum of squared y differences currently in `window`.
    sum_yy: f64,
    /// Running sum of paired x*y difference products currently in `window`.
    sum_xy: f64,
}
56
57impl RollingCorrelation {
58 pub fn new(period: usize) -> Result<Self> {
64 if period < 2 {
65 return Err(Error::InvalidPeriod {
66 message: "rolling correlation needs period >= 2",
67 });
68 }
69 Ok(Self {
70 period,
71 prev: None,
72 window: VecDeque::with_capacity(period),
73 sum_x: 0.0,
74 sum_y: 0.0,
75 sum_xx: 0.0,
76 sum_yy: 0.0,
77 sum_xy: 0.0,
78 })
79 }
80
81 pub const fn period(&self) -> usize {
83 self.period
84 }
85}
86
87impl Indicator for RollingCorrelation {
88 type Input = (f64, f64);
89 type Output = f64;
90
91 fn update(&mut self, input: (f64, f64)) -> Option<f64> {
92 let (x, y) = input;
93 let Some((px, py)) = self.prev else {
94 self.prev = Some((x, y));
96 return None;
97 };
98 self.prev = Some((x, y));
99 let (rx, ry) = (x - px, y - py);
100 if self.window.len() == self.period {
101 let (ox, oy) = self.window.pop_front().expect("non-empty");
102 self.sum_x -= ox;
103 self.sum_y -= oy;
104 self.sum_xx -= ox * ox;
105 self.sum_yy -= oy * oy;
106 self.sum_xy -= ox * oy;
107 }
108 self.window.push_back((rx, ry));
109 self.sum_x += rx;
110 self.sum_y += ry;
111 self.sum_xx += rx * rx;
112 self.sum_yy += ry * ry;
113 self.sum_xy += rx * ry;
114 if self.window.len() < self.period {
115 return None;
116 }
117 let n = self.period as f64;
118 let mean_x = self.sum_x / n;
119 let mean_y = self.sum_y / n;
120 let var_x = (self.sum_xx / n - mean_x * mean_x).max(0.0);
121 let var_y = (self.sum_yy / n - mean_y * mean_y).max(0.0);
122 let cov = self.sum_xy / n - mean_x * mean_y;
123 let denom = (var_x * var_y).sqrt();
124 if denom == 0.0 {
125 return Some(0.0);
127 }
128 Some((cov / denom).clamp(-1.0, 1.0))
129 }
130
131 fn reset(&mut self) {
132 self.prev = None;
133 self.window.clear();
134 self.sum_x = 0.0;
135 self.sum_y = 0.0;
136 self.sum_xx = 0.0;
137 self.sum_yy = 0.0;
138 self.sum_xy = 0.0;
139 }
140
141 fn warmup_period(&self) -> usize {
142 self.period + 1
143 }
144
145 fn is_ready(&self) -> bool {
146 self.window.len() == self.period
147 }
148
149 fn name(&self) -> &'static str {
150 "RollingCorrelation"
151 }
152}
153
#[cfg(test)]
mod tests {
    use super::*;
    use crate::traits::BatchExt;
    use approx::assert_relative_eq;

    /// Runs `pairs` through a fresh indicator of the given `period` and returns
    /// the last value it emitted (panics if it never became ready).
    fn final_value(period: usize, pairs: &[(f64, f64)]) -> f64 {
        let mut rc = RollingCorrelation::new(period).unwrap();
        rc.batch(pairs).into_iter().flatten().last().unwrap()
    }

    /// `len` pairs whose x follows `10 * sin(0.5 * i)` and whose y is `y_of(x)`.
    fn sine_pairs(len: i32, y_of: impl Fn(f64) -> f64) -> Vec<(f64, f64)> {
        (0..len)
            .map(|i| {
                let x = (f64::from(i) * 0.5).sin() * 10.0;
                (x, y_of(x))
            })
            .collect()
    }

    #[test]
    fn rejects_period_below_two() {
        for bad in [0_usize, 1] {
            assert!(RollingCorrelation::new(bad).is_err());
        }
        assert!(RollingCorrelation::new(2).is_ok());
    }

    #[test]
    fn accessors_and_metadata() {
        let rc = RollingCorrelation::new(14).unwrap();
        assert_eq!(rc.period(), 14);
        assert_eq!(rc.warmup_period(), 15);
        assert_eq!(rc.name(), "RollingCorrelation");
        assert!(!rc.is_ready());
    }

    #[test]
    fn warmup_needs_period_plus_one() {
        let mut rc = RollingCorrelation::new(3).unwrap();
        let warmup = [(1.0, 1.0), (2.0, 3.0), (3.0, 5.0)];
        for &pair in &warmup {
            assert_eq!(rc.update(pair), None);
        }
        assert!(rc.update((4.0, 7.0)).is_some());
        assert!(rc.is_ready());
    }

    #[test]
    fn comoving_returns_are_plus_one() {
        let pairs = sine_pairs(20, |x| 2.0 * x + 100.0);
        assert_relative_eq!(final_value(8, &pairs), 1.0, epsilon = 1e-9);
    }

    #[test]
    fn opposing_returns_are_minus_one() {
        let pairs = sine_pairs(20, |x| -1.5 * x + 50.0);
        assert_relative_eq!(final_value(8, &pairs), -1.0, epsilon = 1e-9);
    }

    #[test]
    fn flat_return_channel_yields_zero() {
        let pairs: Vec<(f64, f64)> = (0..20).map(|i| (f64::from(i), 7.0)).collect();
        assert_relative_eq!(final_value(6, &pairs), 0.0, epsilon = 1e-12);
    }

    #[test]
    fn output_in_range() {
        let pairs: Vec<(f64, f64)> = (0..80)
            .map(f64::from)
            .map(|t| (100.0 + t.sin() * 5.0, 50.0 + (t * 0.3).cos() * 3.0))
            .collect();
        let mut rc = RollingCorrelation::new(20).unwrap();
        assert!(rc
            .batch(&pairs)
            .into_iter()
            .flatten()
            .all(|v| (-1.0..=1.0).contains(&v)));
    }

    #[test]
    fn reset_clears_state() {
        let mut rc = RollingCorrelation::new(4).unwrap();
        rc.batch(&[(1.0, 2.0), (2.0, 4.0), (3.0, 6.0), (4.0, 8.0), (5.0, 10.0)]);
        assert!(rc.is_ready());
        rc.reset();
        assert!(!rc.is_ready());
        assert_eq!(rc.update((1.0, 1.0)), None);
    }

    #[test]
    fn batch_equals_streaming() {
        let pairs: Vec<(f64, f64)> = (0..60)
            .map(f64::from)
            .map(|t| (t.sin(), (t * 0.5).cos()))
            .collect();
        let mut streaming = RollingCorrelation::new(14).unwrap();
        let streamed: Vec<_> = pairs.iter().map(|&p| streaming.update(p)).collect();
        let batch = RollingCorrelation::new(14).unwrap().batch(&pairs);
        assert_eq!(batch, streamed);
    }
}