wickra_core/indicators/
rolling_correlation.rs1use std::collections::VecDeque;
4
5use crate::error::{Error, Result};
6use crate::traits::Indicator;
7
/// Rolling Pearson correlation between the first differences of two series.
///
/// Each finite `(x, y)` input after the first is converted into a difference
/// `(x - prev_x, y - prev_y)` (the tests call these "returns"); the newest
/// `period` differences form the window. Running sums over that window are
/// adjusted as differences enter and leave it, so an update does a fixed
/// amount of arithmetic regardless of `period`.
///
/// The first value appears after `period + 1` finite inputs. Inputs containing
/// NaN or infinity are skipped and yield `None`. Outputs are clamped to
/// `[-1, 1]`, and a window in which either channel's computed variance is zero
/// yields `0.0`.
#[derive(Debug, Clone)]
pub struct RollingCorrelation {
    /// Number of differences in the correlation window (validated to be >= 2).
    period: usize,
    /// Last finite input pair, the base of the next difference; `None` before
    /// the first finite input and after a reset.
    prev: Option<(f64, f64)>,
    /// Most recent differences, oldest at the front; never more than `period`.
    window: VecDeque<(f64, f64)>,
    /// Running sum of the x differences currently in `window`.
    sum_x: f64,
    /// Running sum of the y differences currently in `window`.
    sum_y: f64,
    /// Running sum of the squared x differences currently in `window`.
    sum_xx: f64,
    /// Running sum of the squared y differences currently in `window`.
    sum_yy: f64,
    /// Running sum of the x·y difference products currently in `window`.
    sum_xy: f64,
}
56
57impl RollingCorrelation {
58 pub fn new(period: usize) -> Result<Self> {
64 if period < 2 {
65 return Err(Error::InvalidPeriod {
66 message: "rolling correlation needs period >= 2",
67 });
68 }
69 Ok(Self {
70 period,
71 prev: None,
72 window: VecDeque::with_capacity(period),
73 sum_x: 0.0,
74 sum_y: 0.0,
75 sum_xx: 0.0,
76 sum_yy: 0.0,
77 sum_xy: 0.0,
78 })
79 }
80
81 pub const fn period(&self) -> usize {
83 self.period
84 }
85}
86
87impl Indicator for RollingCorrelation {
88 type Input = (f64, f64);
89 type Output = f64;
90
91 fn update(&mut self, input: (f64, f64)) -> Option<f64> {
92 let (x, y) = input;
93 if !x.is_finite() || !y.is_finite() {
94 return None;
95 }
96 let Some((px, py)) = self.prev else {
97 self.prev = Some((x, y));
99 return None;
100 };
101 self.prev = Some((x, y));
102 let (rx, ry) = (x - px, y - py);
103 if self.window.len() == self.period {
104 let (ox, oy) = self.window.pop_front().expect("non-empty");
105 self.sum_x -= ox;
106 self.sum_y -= oy;
107 self.sum_xx -= ox * ox;
108 self.sum_yy -= oy * oy;
109 self.sum_xy -= ox * oy;
110 }
111 self.window.push_back((rx, ry));
112 self.sum_x += rx;
113 self.sum_y += ry;
114 self.sum_xx += rx * rx;
115 self.sum_yy += ry * ry;
116 self.sum_xy += rx * ry;
117 if self.window.len() < self.period {
118 return None;
119 }
120 let n = self.period as f64;
121 let mean_x = self.sum_x / n;
122 let mean_y = self.sum_y / n;
123 let var_x = (self.sum_xx / n - mean_x * mean_x).max(0.0);
124 let var_y = (self.sum_yy / n - mean_y * mean_y).max(0.0);
125 let cov = self.sum_xy / n - mean_x * mean_y;
126 let denom = (var_x * var_y).sqrt();
127 if denom == 0.0 {
128 return Some(0.0);
130 }
131 Some((cov / denom).clamp(-1.0, 1.0))
132 }
133
134 fn reset(&mut self) {
135 self.prev = None;
136 self.window.clear();
137 self.sum_x = 0.0;
138 self.sum_y = 0.0;
139 self.sum_xx = 0.0;
140 self.sum_yy = 0.0;
141 self.sum_xy = 0.0;
142 }
143
144 fn warmup_period(&self) -> usize {
145 self.period + 1
146 }
147
148 fn is_ready(&self) -> bool {
149 self.window.len() == self.period
150 }
151
152 fn name(&self) -> &'static str {
153 "RollingCorrelation"
154 }
155}
156
#[cfg(test)]
mod tests {
    use super::*;
    use crate::traits::BatchExt;
    use approx::assert_relative_eq;

    /// Runs `pairs` through a fresh indicator and returns its final output.
    fn final_output(period: usize, pairs: &[(f64, f64)]) -> f64 {
        RollingCorrelation::new(period)
            .unwrap()
            .batch(pairs)
            .into_iter()
            .flatten()
            .last()
            .unwrap()
    }

    /// Twenty points where `y = slope * x + intercept` and `x` follows a sine.
    fn affine_of_sine(slope: f64, intercept: f64) -> Vec<(f64, f64)> {
        (0..20)
            .map(|i| {
                let x = (f64::from(i) * 0.5).sin() * 10.0;
                (x, slope * x + intercept)
            })
            .collect()
    }

    #[test]
    fn rejects_period_below_two() {
        for bad in [0, 1] {
            assert!(RollingCorrelation::new(bad).is_err());
        }
        assert!(RollingCorrelation::new(2).is_ok());
    }

    #[test]
    fn accessors_and_metadata() {
        let indicator = RollingCorrelation::new(14).unwrap();
        assert_eq!(indicator.period(), 14);
        assert_eq!(indicator.warmup_period(), 15);
        assert_eq!(indicator.name(), "RollingCorrelation");
        assert!(!indicator.is_ready());
    }

    #[test]
    fn warmup_needs_period_plus_one() {
        let mut indicator = RollingCorrelation::new(3).unwrap();
        for pair in [(1.0, 1.0), (2.0, 3.0), (3.0, 5.0)] {
            assert_eq!(indicator.update(pair), None);
        }
        assert!(indicator.update((4.0, 7.0)).is_some());
        assert!(indicator.is_ready());
    }

    #[test]
    fn comoving_returns_are_plus_one() {
        let value = final_output(8, &affine_of_sine(2.0, 100.0));
        assert_relative_eq!(value, 1.0, epsilon = 1e-9);
    }

    #[test]
    fn opposing_returns_are_minus_one() {
        let value = final_output(8, &affine_of_sine(-1.5, 50.0));
        assert_relative_eq!(value, -1.0, epsilon = 1e-9);
    }

    #[test]
    fn flat_return_channel_yields_zero() {
        let ramp_vs_constant: Vec<(f64, f64)> = (0..20).map(|i| (f64::from(i), 7.0)).collect();
        assert_relative_eq!(final_output(6, &ramp_vs_constant), 0.0, epsilon = 1e-12);
    }

    #[test]
    fn output_in_range() {
        let waves: Vec<(f64, f64)> = (0..80)
            .map(|i| {
                let t = f64::from(i);
                (100.0 + t.sin() * 5.0, 50.0 + (t * 0.3).cos() * 3.0)
            })
            .collect();
        let mut indicator = RollingCorrelation::new(20).unwrap();
        for value in indicator.batch(&waves).into_iter().flatten() {
            assert!((-1.0..=1.0).contains(&value));
        }
    }

    #[test]
    fn reset_clears_state() {
        let mut indicator = RollingCorrelation::new(4).unwrap();
        indicator.batch(&[(1.0, 2.0), (2.0, 4.0), (3.0, 6.0), (4.0, 8.0), (5.0, 10.0)]);
        assert!(indicator.is_ready());
        indicator.reset();
        assert!(!indicator.is_ready());
        assert_eq!(indicator.update((1.0, 1.0)), None);
    }

    #[test]
    fn batch_equals_streaming() {
        let series: Vec<(f64, f64)> = (0..60)
            .map(|i| {
                let t = f64::from(i);
                (t.sin(), (t * 0.5).cos())
            })
            .collect();
        let batched = RollingCorrelation::new(14).unwrap().batch(&series);
        let mut streaming = RollingCorrelation::new(14).unwrap();
        let streamed: Vec<_> = series.iter().map(|&pair| streaming.update(pair)).collect();
        assert_eq!(batched, streamed);
    }

    #[test]
    fn non_finite_input_returns_none() {
        let mut indicator = RollingCorrelation::new(2).unwrap();
        assert_eq!(indicator.update((f64::NAN, 1.0)), None);
        assert_eq!(indicator.update((1.0, f64::INFINITY)), None);
        assert_eq!(indicator.update((1.0, 1.0)), None);
        assert_eq!(indicator.update((2.0, 3.0)), None);
        assert!(indicator.update((3.0, 5.0)).is_some());
    }
}