wickra_core/indicators/
rolling_covariance.rs1use std::collections::VecDeque;
4
5use crate::error::{Error, Result};
6use crate::traits::Indicator;
7
/// Rolling covariance of the one-step differences of two input series.
///
/// Each update takes an `(x, y)` pair and differences it against the previous
/// pair (`x - prev_x`, `y - prev_y`). The last `period` such difference pairs are
/// kept in a window. Once the window is full, the indicator reports the
/// population covariance of those differences: it divides by `period`, not
/// `period - 1`.
#[derive(Debug, Clone)]
pub struct RollingCovariance {
    /// Number of difference pairs in the window (`new` rejects values below 2).
    period: usize,
    /// Most recent raw `(x, y)` input, used to difference the next one.
    /// `None` before the first update and after a reset.
    prev: Option<(f64, f64)>,
    /// The last `period` `(dx, dy)` difference pairs, oldest at the front.
    window: VecDeque<(f64, f64)>,
    /// Running sum of the `dx` values currently in `window`.
    sum_x: f64,
    /// Running sum of the `dy` values currently in `window`.
    sum_y: f64,
    /// Running sum of `dx * dy` over the pairs currently in `window`.
    sum_xy: f64,
}
54
55impl RollingCovariance {
56 pub fn new(period: usize) -> Result<Self> {
62 if period < 2 {
63 return Err(Error::InvalidPeriod {
64 message: "rolling covariance needs period >= 2",
65 });
66 }
67 Ok(Self {
68 period,
69 prev: None,
70 window: VecDeque::with_capacity(period),
71 sum_x: 0.0,
72 sum_y: 0.0,
73 sum_xy: 0.0,
74 })
75 }
76
77 pub const fn period(&self) -> usize {
79 self.period
80 }
81}
82
83impl Indicator for RollingCovariance {
84 type Input = (f64, f64);
85 type Output = f64;
86
87 fn update(&mut self, input: (f64, f64)) -> Option<f64> {
88 let (x, y) = input;
89 let Some((px, py)) = self.prev else {
90 self.prev = Some((x, y));
91 return None;
92 };
93 self.prev = Some((x, y));
94 let (rx, ry) = (x - px, y - py);
95 if self.window.len() == self.period {
96 let (ox, oy) = self.window.pop_front().expect("non-empty");
97 self.sum_x -= ox;
98 self.sum_y -= oy;
99 self.sum_xy -= ox * oy;
100 }
101 self.window.push_back((rx, ry));
102 self.sum_x += rx;
103 self.sum_y += ry;
104 self.sum_xy += rx * ry;
105 if self.window.len() < self.period {
106 return None;
107 }
108 let n = self.period as f64;
109 let mean_x = self.sum_x / n;
110 let mean_y = self.sum_y / n;
111 Some(self.sum_xy / n - mean_x * mean_y)
112 }
113
114 fn reset(&mut self) {
115 self.prev = None;
116 self.window.clear();
117 self.sum_x = 0.0;
118 self.sum_y = 0.0;
119 self.sum_xy = 0.0;
120 }
121
122 fn warmup_period(&self) -> usize {
123 self.period + 1
124 }
125
126 fn is_ready(&self) -> bool {
127 self.window.len() == self.period
128 }
129
130 fn name(&self) -> &'static str {
131 "RollingCovariance"
132 }
133}
134
#[cfg(test)]
mod tests {
    use super::*;
    use crate::traits::BatchExt;
    use approx::assert_relative_eq;

    /// Batch-runs a fresh indicator of length `period` over `pairs` and
    /// returns the last value it emitted.
    fn last_emitted(period: usize, pairs: &[(f64, f64)]) -> f64 {
        RollingCovariance::new(period)
            .unwrap()
            .batch(pairs)
            .into_iter()
            .flatten()
            .last()
            .unwrap()
    }

    #[test]
    fn rejects_period_below_two() {
        for bad in 0..2 {
            assert!(RollingCovariance::new(bad).is_err());
        }
        assert!(RollingCovariance::new(2).is_ok());
    }

    #[test]
    fn accessors_and_metadata() {
        let rc = RollingCovariance::new(14).unwrap();
        assert!(!rc.is_ready());
        assert_eq!(rc.name(), "RollingCovariance");
        assert_eq!(rc.warmup_period(), 15);
        assert_eq!(rc.period(), 14);
    }

    #[test]
    fn warmup_needs_period_plus_one() {
        let mut rc = RollingCovariance::new(3).unwrap();
        // The first input only seeds the previous value. The next two fill
        // just two of the three window slots, so nothing is emitted yet.
        for &pair in &[(1.0, 1.0), (2.0, 3.0), (3.0, 5.0)] {
            assert_eq!(rc.update(pair), None);
        }
        assert!(rc.update((4.0, 7.0)).is_some());
        assert!(rc.is_ready());
    }

    #[test]
    fn hand_computed_value() {
        // The differences are (1, 2), (2, 4), (3, 6), (4, 8). The final window
        // of three holds (2, 4), (3, 6), (4, 8), with means 3 and 6. The
        // population covariance is therefore (8 + 18 + 32) / 3 - 3 * 6 = 4 / 3.
        let pairs = [(0.0, 0.0), (1.0, 2.0), (3.0, 6.0), (6.0, 12.0), (10.0, 20.0)];
        assert_relative_eq!(last_emitted(3, &pairs), 4.0 / 3.0, epsilon = 1e-9);
    }

    #[test]
    fn opposing_returns_give_negative_covariance() {
        let pairs: Vec<(f64, f64)> = (0..30_i32)
            .map(|i| (f64::from(i) * 0.4).sin() * 10.0)
            .map(|x| (x, -x))
            .collect();
        let last = last_emitted(10, &pairs);
        assert!(last < 0.0, "cov {last}");
    }

    #[test]
    fn flat_channel_gives_zero() {
        // y never moves, so every dy is zero and the covariance must vanish.
        let pairs: Vec<(f64, f64)> = (0..20_i32).map(|i| (f64::from(i), 7.0)).collect();
        assert_relative_eq!(last_emitted(6, &pairs), 0.0, epsilon = 1e-12);
    }

    #[test]
    fn reset_clears_state() {
        let mut rc = RollingCovariance::new(4).unwrap();
        let pairs = [(1.0, 2.0), (2.0, 4.0), (3.0, 1.0), (4.0, 9.0), (5.0, 2.0)];
        rc.batch(&pairs);
        assert!(rc.is_ready());

        rc.reset();
        assert!(!rc.is_ready());
        assert_eq!(rc.update((1.0, 1.0)), None);
    }

    #[test]
    fn batch_equals_streaming() {
        let pairs: Vec<(f64, f64)> = (0..60_i32)
            .map(f64::from)
            .map(|t| (t.sin() * 4.0, (t * 0.5).cos() * 2.0))
            .collect();
        let from_batch = RollingCovariance::new(12).unwrap().batch(&pairs);

        let mut streaming = RollingCovariance::new(12).unwrap();
        let from_stream: Vec<_> = pairs.iter().map(|&p| streaming.update(p)).collect();

        assert_eq!(from_batch, from_stream);
    }
}