wickra_core/indicators/
rolling_covariance.rs1use std::collections::VecDeque;
4
5use crate::error::{Error, Result};
6use crate::traits::Indicator;
7
8#[derive(Debug, Clone)]
46pub struct RollingCovariance {
47 period: usize,
48 prev: Option<(f64, f64)>,
49 window: VecDeque<(f64, f64)>,
50 sum_x: f64,
51 sum_y: f64,
52 sum_xy: f64,
53}
54
55impl RollingCovariance {
56 pub fn new(period: usize) -> Result<Self> {
62 if period < 2 {
63 return Err(Error::InvalidPeriod {
64 message: "rolling covariance needs period >= 2",
65 });
66 }
67 Ok(Self {
68 period,
69 prev: None,
70 window: VecDeque::with_capacity(period),
71 sum_x: 0.0,
72 sum_y: 0.0,
73 sum_xy: 0.0,
74 })
75 }
76
77 pub const fn period(&self) -> usize {
79 self.period
80 }
81}
82
83impl Indicator for RollingCovariance {
84 type Input = (f64, f64);
85 type Output = f64;
86
87 fn update(&mut self, input: (f64, f64)) -> Option<f64> {
88 let (x, y) = input;
89 if !x.is_finite() || !y.is_finite() {
90 return None;
91 }
92 let Some((px, py)) = self.prev else {
93 self.prev = Some((x, y));
94 return None;
95 };
96 self.prev = Some((x, y));
97 let (rx, ry) = (x - px, y - py);
98 if self.window.len() == self.period {
99 let (ox, oy) = self.window.pop_front().expect("non-empty");
100 self.sum_x -= ox;
101 self.sum_y -= oy;
102 self.sum_xy -= ox * oy;
103 }
104 self.window.push_back((rx, ry));
105 self.sum_x += rx;
106 self.sum_y += ry;
107 self.sum_xy += rx * ry;
108 if self.window.len() < self.period {
109 return None;
110 }
111 let n = self.period as f64;
112 let mean_x = self.sum_x / n;
113 let mean_y = self.sum_y / n;
114 Some(self.sum_xy / n - mean_x * mean_y)
115 }
116
117 fn reset(&mut self) {
118 self.prev = None;
119 self.window.clear();
120 self.sum_x = 0.0;
121 self.sum_y = 0.0;
122 self.sum_xy = 0.0;
123 }
124
125 fn warmup_period(&self) -> usize {
126 self.period + 1
127 }
128
129 fn is_ready(&self) -> bool {
130 self.window.len() == self.period
131 }
132
133 fn name(&self) -> &'static str {
134 "RollingCovariance"
135 }
136}
137
#[cfg(test)]
mod tests {
    use super::*;
    use crate::traits::BatchExt;
    use approx::assert_relative_eq;

    /// Runs `pairs` through a fresh indicator and returns the last value it emitted.
    fn final_value(period: usize, pairs: &[(f64, f64)]) -> f64 {
        RollingCovariance::new(period)
            .unwrap()
            .batch(pairs)
            .into_iter()
            .flatten()
            .last()
            .unwrap()
    }

    #[test]
    fn rejects_period_below_two() {
        for too_small in [0, 1] {
            assert!(RollingCovariance::new(too_small).is_err());
        }
        assert!(RollingCovariance::new(2).is_ok());
    }

    #[test]
    fn accessors_and_metadata() {
        let indicator = RollingCovariance::new(14).unwrap();
        assert!(!indicator.is_ready());
        assert_eq!(indicator.name(), "RollingCovariance");
        assert_eq!(indicator.warmup_period(), 15);
        assert_eq!(indicator.period(), 14);
    }

    #[test]
    fn warmup_needs_period_plus_one() {
        let mut indicator = RollingCovariance::new(3).unwrap();
        // Three inputs yield only two changes, one short of a full window.
        for pair in [(1.0, 1.0), (2.0, 3.0), (3.0, 5.0)] {
            assert_eq!(indicator.update(pair), None);
        }
        assert!(indicator.update((4.0, 7.0)).is_some());
        assert!(indicator.is_ready());
    }

    #[test]
    fn hand_computed_value() {
        // Changes are (1,2), (2,4), (3,6), (4,8). The final window of three is
        // x = [2, 3, 4], y = [4, 6, 8] with means 3 and 6, so the population
        // covariance is ((-1)(-2) + 0 + (1)(2)) / 3 = 4/3.
        let pairs = [
            (0.0, 0.0),
            (1.0, 2.0),
            (3.0, 6.0),
            (6.0, 12.0),
            (10.0, 20.0),
        ];
        assert_relative_eq!(final_value(3, &pairs), 4.0 / 3.0, epsilon = 1e-9);
    }

    #[test]
    fn opposing_returns_give_negative_covariance() {
        let pairs: Vec<(f64, f64)> = (0..30)
            .map(|i| (f64::from(i) * 0.4).sin() * 10.0)
            .map(|x| (x, -x))
            .collect();
        let last = final_value(10, &pairs);
        assert!(last < 0.0, "cov {last}");
    }

    #[test]
    fn flat_channel_gives_zero() {
        // y never moves, so every y change is zero and the covariance vanishes.
        let pairs: Vec<(f64, f64)> = (0..20).map(|i| (f64::from(i), 7.0)).collect();
        assert_relative_eq!(final_value(6, &pairs), 0.0, epsilon = 1e-12);
    }

    #[test]
    fn reset_clears_state() {
        let mut indicator = RollingCovariance::new(4).unwrap();
        indicator.batch(&[(1.0, 2.0), (2.0, 4.0), (3.0, 1.0), (4.0, 9.0), (5.0, 2.0)]);
        assert!(indicator.is_ready());
        indicator.reset();
        assert!(!indicator.is_ready());
        // The first input after a reset only re-seeds the previous value.
        assert_eq!(indicator.update((1.0, 1.0)), None);
    }

    #[test]
    fn batch_equals_streaming() {
        let pairs: Vec<(f64, f64)> = (0..60)
            .map(|i| {
                let t = f64::from(i);
                (t.sin() * 4.0, (t * 0.5).cos() * 2.0)
            })
            .collect();
        let batched = RollingCovariance::new(12).unwrap().batch(&pairs);
        let mut streaming = RollingCovariance::new(12).unwrap();
        let streamed: Vec<_> = pairs.iter().map(|&pair| streaming.update(pair)).collect();
        assert_eq!(batched, streamed);
    }

    #[test]
    fn non_finite_input_returns_none() {
        let mut indicator = RollingCovariance::new(2).unwrap();
        assert_eq!(indicator.update((f64::NAN, 1.0)), None);
        assert_eq!(indicator.update((1.0, f64::INFINITY)), None);
        // The rejected inputs left no trace, so warm-up starts here:
        // one seeding input, then two changes to fill the window.
        assert_eq!(indicator.update((1.0, 1.0)), None);
        assert_eq!(indicator.update((2.0, 3.0)), None);
        assert!(indicator.update((3.0, 5.0)).is_some());
    }
}