wickra_core/indicators/
granger_causality.rs1use std::collections::VecDeque;
4
5use crate::error::{Error, Result};
6use crate::traits::Indicator;
7
/// Rolling-window Granger-causality F-statistic.
///
/// Each update consumes an `(a, b)` pair. Once `period` pairs are buffered,
/// `a` is regressed on a constant plus its own `lag` previous values (the
/// restricted model) and again with `lag` previous values of `b` added (the
/// unrestricted model); the indicator reports the F-statistic comparing the
/// two fits. Larger values mean lagged `b` adds more explanatory power for `a`.
#[derive(Clone, Debug)]
pub struct GrangerCausality {
    /// Number of `(a, b)` pairs held in the rolling window.
    period: usize,
    /// Number of past values of each series used as regressors.
    lag: usize,
    /// Most recent `(a, b)` observations, oldest at the front.
    window: VecDeque<(f64, f64)>,
}

55impl GrangerCausality {
56 pub fn new(period: usize, lag: usize) -> Result<Self> {
66 if lag < 1 {
67 return Err(Error::InvalidPeriod {
68 message: "granger causality needs lag >= 1",
69 });
70 }
71 if period < 3 * lag + 2 {
72 return Err(Error::InvalidPeriod {
73 message: "granger causality needs period >= 3*lag + 2",
74 });
75 }
76 Ok(Self {
77 period,
78 lag,
79 window: VecDeque::with_capacity(period),
80 })
81 }
82
83 pub const fn period(&self) -> usize {
85 self.period
86 }
87
88 pub const fn lag(&self) -> usize {
90 self.lag
91 }
92}
93
94impl Indicator for GrangerCausality {
95 type Input = (f64, f64);
96 type Output = f64;
97
98 fn update(&mut self, input: (f64, f64)) -> Option<f64> {
99 if self.window.len() == self.period {
100 self.window.pop_front();
101 }
102 self.window.push_back(input);
103 if self.window.len() < self.period {
104 return None;
105 }
106 let lag = self.lag;
107 let a: Vec<f64> = self.window.iter().map(|&(av, _)| av).collect();
108 let b: Vec<f64> = self.window.iter().map(|&(_, bv)| bv).collect();
109 let num_obs = self.period - lag;
110
111 let mut target = Vec::with_capacity(num_obs);
112 let mut restricted = Vec::with_capacity(num_obs);
113 let mut unrestricted = Vec::with_capacity(num_obs);
114 for k in 0..num_obs {
115 let now = lag + k;
116 target.push(a[now]);
117 let mut row_r = Vec::with_capacity(lag + 1);
118 row_r.push(1.0);
119 for back in 1..=lag {
120 row_r.push(a[now - back]);
121 }
122 let mut row_u = row_r.clone();
123 for back in 1..=lag {
124 row_u.push(b[now - back]);
125 }
126 restricted.push(row_r);
127 unrestricted.push(row_u);
128 }
129
130 let Some(rss_r) = ols_rss(&restricted, &target, lag + 1) else {
131 return Some(0.0);
132 };
133 let Some(rss_u) = ols_rss(&unrestricted, &target, 2 * lag + 1) else {
134 return Some(0.0);
135 };
136 let dof = (num_obs - (2 * lag + 1)) as f64;
137 let numerator = (rss_r - rss_u) / lag as f64;
138 let denominator = rss_u / dof;
139 Some((numerator / denominator).max(0.0))
140 }
141
142 fn reset(&mut self) {
143 self.window.clear();
144 }
145
146 fn warmup_period(&self) -> usize {
147 self.period
148 }
149
150 fn is_ready(&self) -> bool {
151 self.window.len() == self.period
152 }
153
154 fn name(&self) -> &'static str {
155 "GrangerCausality"
156 }
157}
158
159fn ols_rss(rows: &[Vec<f64>], target: &[f64], num_reg: usize) -> Option<f64> {
163 let mut xtx = vec![vec![0.0; num_reg]; num_reg];
164 let mut xty = vec![0.0; num_reg];
165 for (row, &observed) in rows.iter().zip(target) {
166 for (ri, &left) in row.iter().enumerate() {
167 xty[ri] += left * observed;
168 for (ci, &right) in row.iter().enumerate() {
169 xtx[ri][ci] += left * right;
170 }
171 }
172 }
173 let theta = solve(xtx, xty)?;
174 let mut rss = 0.0;
175 for (row, &observed) in rows.iter().zip(target) {
176 let pred: f64 = row
177 .iter()
178 .zip(&theta)
179 .map(|(coeff, value)| coeff * value)
180 .sum();
181 let resid = observed - pred;
182 rss += resid * resid;
183 }
184 Some(rss)
185}
186
/// Solves the square linear system `mat · x = rhs` by Gaussian elimination
/// (without row exchanges) followed by back-substitution.
///
/// `ols_rss` calls this with normal-equation (Gram) matrices, which are
/// symmetric positive semi-definite. Elimination without row exchanges is
/// numerically well behaved for such matrices.
///
/// Returns `None` when a pivot is negligible relative to the original diagonal
/// entry of its column, i.e. the system is numerically singular.
fn solve(mut mat: Vec<Vec<f64>>, mut rhs: Vec<f64>) -> Option<Vec<f64>> {
    // A pivot is treated as zero when it is at most this fraction of its
    // column's original diagonal entry. For a Gram matrix that ratio is the
    // share of the column's squared norm not explained by the preceding
    // columns, so the test does not depend on the units of each regressor.
    // An absolute threshold would reject well-posed tiny-scale data and miss
    // exact collinearity in large-scale data.
    const RELATIVE_PIVOT_TOLERANCE: f64 = 1e-12;

    let dim = rhs.len();
    // Capture the diagonal before elimination overwrites it.
    let diag_scale: Vec<f64> = (0..dim).map(|i| mat[i][i].abs()).collect();

    for col in 0..dim {
        let pivot = mat[col][col];
        if pivot.abs() <= RELATIVE_PIVOT_TOLERANCE * diag_scale[col] {
            return None;
        }
        // Copy the pivot row so rows below it can be mutated while reading it.
        let pivot_row = mat[col].clone();
        for row in (col + 1)..dim {
            let factor = mat[row][col] / pivot;
            for (cell, &above) in mat[row].iter_mut().zip(&pivot_row).skip(col) {
                *cell -= factor * above;
            }
            rhs[row] -= factor * rhs[col];
        }
    }

    // Back-substitution on the resulting upper-triangular system.
    let mut sol = vec![0.0; dim];
    for row in (0..dim).rev() {
        let known: f64 = mat[row]
            .iter()
            .zip(&sol)
            .skip(row + 1)
            .map(|(coeff, value)| coeff * value)
            .sum();
        sol[row] = (rhs[row] - known) / mat[row][row];
    }
    Some(sol)
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::traits::BatchExt;

    /// `lag` must be positive and `period` must be at least `3 * lag + 2`.
    #[test]
    fn rejects_bad_parameters() {
        assert!(GrangerCausality::new(10, 0).is_err());
        assert!(GrangerCausality::new(4, 1).is_err());
        assert!(GrangerCausality::new(5, 1).is_ok());
    }

    /// Accessors echo the constructor arguments; a fresh indicator is not ready.
    #[test]
    fn accessors_and_metadata() {
        let indicator = GrangerCausality::new(60, 2).unwrap();
        assert!(!indicator.is_ready());
        assert_eq!(indicator.name(), "GrangerCausality");
        assert_eq!(indicator.warmup_period(), 60);
        assert_eq!(indicator.lag(), 2);
        assert_eq!(indicator.period(), 60);
    }

    /// The first `period - 1` updates yield `None`; the `period`-th yields a value.
    #[test]
    fn warmup_returns_none() {
        let mut indicator = GrangerCausality::new(5, 1).unwrap();
        for step in 0..4 {
            let x = f64::from(step);
            assert_eq!(indicator.update((x, x * 0.5)), None);
        }
        assert!(indicator.update((4.0, 2.0)).is_some());
        assert!(indicator.is_ready());
    }

    /// When `a` is driven by the previous value of `b`, the F-statistic is large.
    #[test]
    fn b_leading_a_has_positive_statistic() {
        let mut pairs: Vec<(f64, f64)> = Vec::with_capacity(120);
        let mut prev_drive = 0.0;
        for step in 0..120 {
            let x = f64::from(step);
            let drive = (x * 0.3).sin() + 0.4 * (x * 0.11).cos();
            let a = 0.8 * prev_drive + 0.05 * (x * 0.7).sin();
            pairs.push((a, drive));
            prev_drive = drive;
        }
        let last = GrangerCausality::new(60, 1)
            .unwrap()
            .batch(&pairs)
            .into_iter()
            .flatten()
            .last()
            .unwrap();
        assert!(last > 1.0, "F {last}");
    }

    /// A constant `b` is collinear with the intercept, so the statistic is 0.
    #[test]
    fn constant_b_is_singular_and_returns_zero() {
        let pairs: Vec<(f64, f64)> = (0_i32..40)
            .map(f64::from)
            .map(|x| (x + (x * 0.6).sin(), 3.0))
            .collect();
        let last = GrangerCausality::new(20, 1)
            .unwrap()
            .batch(&pairs)
            .into_iter()
            .flatten()
            .last()
            .unwrap();
        assert_eq!(last, 0.0);
    }

    /// A constant `a` makes the restricted regression singular, so the statistic is 0.
    #[test]
    fn constant_a_restricted_singular_returns_zero() {
        let pairs: Vec<(f64, f64)> = (0_i32..40)
            .map(f64::from)
            .map(|x| (5.0, (x * 0.4).sin()))
            .collect();
        let last = GrangerCausality::new(20, 1)
            .unwrap()
            .batch(&pairs)
            .into_iter()
            .flatten()
            .last()
            .unwrap();
        assert_eq!(last, 0.0);
    }

    /// `reset` empties the window, so the indicator warms up again.
    #[test]
    fn reset_clears_state() {
        let mut indicator = GrangerCausality::new(8, 1).unwrap();
        for x in (0_i32..12).map(f64::from) {
            let _ = indicator.update((x + (x * 0.7).sin(), (x * 0.3).cos()));
        }
        assert!(indicator.is_ready());
        indicator.reset();
        assert!(!indicator.is_ready());
        assert_eq!(indicator.update((1.0, 1.0)), None);
    }

    /// Batch processing produces exactly the same outputs as streaming updates.
    #[test]
    fn batch_equals_streaming() {
        let pairs: Vec<(f64, f64)> = (0_i32..80)
            .map(|step| {
                let lagged = f64::from(step.max(1) - 1);
                let a = 0.6 * (lagged * 0.4).sin() + 0.1 * f64::from(step % 3);
                let b = (f64::from(step) * 0.4).sin();
                (a, b)
            })
            .collect();
        let batch = GrangerCausality::new(30, 2).unwrap().batch(&pairs);
        let mut streaming = GrangerCausality::new(30, 2).unwrap();
        let streamed: Vec<_> = pairs.iter().map(|&pair| streaming.update(pair)).collect();
        assert_eq!(batch, streamed);
    }
}