wickra_core/indicators/
granger_causality.rs1use std::collections::VecDeque;
4
5use crate::error::{Error, Result};
6use crate::traits::Indicator;
7
8#[derive(Debug, Clone)]
49pub struct GrangerCausality {
50 period: usize,
51 lag: usize,
52 window: VecDeque<(f64, f64)>,
53}
54
55impl GrangerCausality {
56 pub fn new(period: usize, lag: usize) -> Result<Self> {
66 if lag < 1 {
67 return Err(Error::InvalidPeriod {
68 message: "granger causality needs lag >= 1",
69 });
70 }
71 if period < 3 * lag + 2 {
72 return Err(Error::InvalidPeriod {
73 message: "granger causality needs period >= 3*lag + 2",
74 });
75 }
76 Ok(Self {
77 period,
78 lag,
79 window: VecDeque::with_capacity(period),
80 })
81 }
82
83 pub const fn period(&self) -> usize {
85 self.period
86 }
87
88 pub const fn lag(&self) -> usize {
90 self.lag
91 }
92}
93
94impl Indicator for GrangerCausality {
95 type Input = (f64, f64);
96 type Output = f64;
97
98 fn update(&mut self, input: (f64, f64)) -> Option<f64> {
99 if !input.0.is_finite() || !input.1.is_finite() {
100 return None;
101 }
102 if self.window.len() == self.period {
103 self.window.pop_front();
104 }
105 self.window.push_back(input);
106 if self.window.len() < self.period {
107 return None;
108 }
109 let lag = self.lag;
110 let a: Vec<f64> = self.window.iter().map(|&(av, _)| av).collect();
111 let b: Vec<f64> = self.window.iter().map(|&(_, bv)| bv).collect();
112 let num_obs = self.period - lag;
113
114 let mut target = Vec::with_capacity(num_obs);
115 let mut restricted = Vec::with_capacity(num_obs);
116 let mut unrestricted = Vec::with_capacity(num_obs);
117 for k in 0..num_obs {
118 let now = lag + k;
119 target.push(a[now]);
120 let mut row_r = Vec::with_capacity(lag + 1);
121 row_r.push(1.0);
122 for back in 1..=lag {
123 row_r.push(a[now - back]);
124 }
125 let mut row_u = row_r.clone();
126 for back in 1..=lag {
127 row_u.push(b[now - back]);
128 }
129 restricted.push(row_r);
130 unrestricted.push(row_u);
131 }
132
133 let Some(rss_r) = ols_rss(&restricted, &target, lag + 1) else {
134 return Some(0.0);
135 };
136 let Some(rss_u) = ols_rss(&unrestricted, &target, 2 * lag + 1) else {
137 return Some(0.0);
138 };
139 let dof = (num_obs - (2 * lag + 1)) as f64;
140 let numerator = (rss_r - rss_u) / lag as f64;
141 let denominator = rss_u / dof;
142 Some((numerator / denominator).max(0.0))
143 }
144
145 fn reset(&mut self) {
146 self.window.clear();
147 }
148
149 fn warmup_period(&self) -> usize {
150 self.period
151 }
152
153 fn is_ready(&self) -> bool {
154 self.window.len() == self.period
155 }
156
157 fn name(&self) -> &'static str {
158 "GrangerCausality"
159 }
160}
161
162fn ols_rss(rows: &[Vec<f64>], target: &[f64], num_reg: usize) -> Option<f64> {
166 let mut xtx = vec![vec![0.0; num_reg]; num_reg];
167 let mut xty = vec![0.0; num_reg];
168 for (row, &observed) in rows.iter().zip(target) {
169 for (ri, &left) in row.iter().enumerate() {
170 xty[ri] += left * observed;
171 for (ci, &right) in row.iter().enumerate() {
172 xtx[ri][ci] += left * right;
173 }
174 }
175 }
176 let theta = solve(xtx, xty)?;
177 let mut rss = 0.0;
178 for (row, &observed) in rows.iter().zip(target) {
179 let pred: f64 = row
180 .iter()
181 .zip(&theta)
182 .map(|(coeff, value)| coeff * value)
183 .sum();
184 let resid = observed - pred;
185 rss += resid * resid;
186 }
187 Some(rss)
188}
189
/// Solves `mat * x = rhs` by Gaussian elimination without row exchanges,
/// followed by back substitution.
///
/// `ols_rss` passes the normal-equation matrix `X^T X`, which is symmetric
/// positive semi-definite. For such matrices, elimination needs no pivoting, and
/// each pivot is the part of that regressor's squared norm left unexplained by
/// the regressors before it.
///
/// Returns `None` when the system is numerically singular, i.e. when a pivot is
/// at most `REL_PIVOT_TOL` times the magnitude of the corresponding *original*
/// diagonal entry. Measuring pivots relative to the diagonal, rather than
/// against a fixed absolute cutoff, makes the verdict independent of the units
/// of the data. Uniformly scaling the system does not change (up to rounding)
/// whether it is solved.
fn solve(mut mat: Vec<Vec<f64>>, mut rhs: Vec<f64>) -> Option<Vec<f64>> {
    // Pivot-to-diagonal ratio below which a column counts as collinear with the
    // earlier ones. It is comfortably above the rounding residue left by exactly
    // collinear columns, and far below any genuinely independent regressor.
    const REL_PIVOT_TOL: f64 = 1e-12;

    let dim = rhs.len();
    // Snapshot of the diagonal before elimination overwrites it.
    let diag_scale: Vec<f64> = (0..dim).map(|i| mat[i][i].abs()).collect();
    for col in 0..dim {
        let pivot = mat[col][col];
        // `<=` so an all-zero column (pivot 0, scale 0) is reported as singular.
        if pivot.abs() <= REL_PIVOT_TOL * diag_scale[col] {
            return None;
        }
        let pivot_row = mat[col].clone();
        for row in (col + 1)..dim {
            let factor = mat[row][col] / pivot;
            for (cell, &above) in mat[row].iter_mut().zip(&pivot_row).skip(col) {
                *cell -= factor * above;
            }
            rhs[row] -= factor * rhs[col];
        }
    }
    let mut sol = vec![0.0; dim];
    for row in (0..dim).rev() {
        let known: f64 = mat[row]
            .iter()
            .zip(&sol)
            .skip(row + 1)
            .map(|(coeff, value)| coeff * value)
            .sum();
        sol[row] = (rhs[row] - known) / mat[row][row];
    }
    Some(sol)
}
220
#[cfg(test)]
mod tests {
    use super::*;
    use crate::traits::BatchExt;

    /// Runs a fresh indicator over `pairs` in batch mode and returns the last
    /// value it emitted.
    fn final_value(period: usize, lag: usize, pairs: &[(f64, f64)]) -> f64 {
        GrangerCausality::new(period, lag)
            .unwrap()
            .batch(pairs)
            .into_iter()
            .flatten()
            .last()
            .unwrap()
    }

    #[test]
    fn rejects_bad_parameters() {
        assert!(GrangerCausality::new(10, 0).is_err());
        assert!(GrangerCausality::new(4, 1).is_err());
        assert!(GrangerCausality::new(5, 1).is_ok());
    }

    #[test]
    fn accessors_and_metadata() {
        let g = GrangerCausality::new(60, 2).unwrap();
        assert_eq!((g.period(), g.lag()), (60, 2));
        assert_eq!(g.warmup_period(), 60);
        assert_eq!(g.name(), "GrangerCausality");
        assert!(!g.is_ready());
    }

    #[test]
    fn warmup_returns_none() {
        let mut g = GrangerCausality::new(5, 1).unwrap();
        for t in 0_i32..4 {
            let x = f64::from(t);
            assert_eq!(g.update((x, x * 0.5)), None);
        }
        assert!(g.update((4.0, 2.0)).is_some());
        assert!(g.is_ready());
    }

    #[test]
    fn b_leading_a_has_positive_statistic() {
        // `a` follows the previous step's `drive`, so lagged `b` predicts `a`.
        let mut pairs = Vec::with_capacity(120);
        let mut prev_drive = 0.0;
        for t in 0_i32..120 {
            let x = f64::from(t);
            let drive = (x * 0.3).sin() + 0.4 * (x * 0.11).cos();
            pairs.push((0.8 * prev_drive + 0.05 * (x * 0.7).sin(), drive));
            prev_drive = drive;
        }
        let last = final_value(60, 1, &pairs);
        assert!(last > 1.0, "F {last}");
    }

    #[test]
    fn constant_b_is_singular_and_returns_zero() {
        let pairs: Vec<(f64, f64)> = (0_i32..40)
            .map(f64::from)
            .map(|x| (x + (x * 0.6).sin(), 3.0))
            .collect();
        assert_eq!(final_value(20, 1, &pairs), 0.0);
    }

    #[test]
    fn constant_a_restricted_singular_returns_zero() {
        let pairs: Vec<(f64, f64)> = (0_i32..40)
            .map(f64::from)
            .map(|x| (5.0, (x * 0.4).sin()))
            .collect();
        assert_eq!(final_value(20, 1, &pairs), 0.0);
    }

    #[test]
    fn reset_clears_state() {
        let mut g = GrangerCausality::new(8, 1).unwrap();
        for x in (0_i32..12).map(f64::from) {
            g.update((x + (x * 0.7).sin(), (x * 0.3).cos()));
        }
        assert!(g.is_ready());
        g.reset();
        assert!(!g.is_ready());
        assert_eq!(g.update((1.0, 1.0)), None);
    }

    #[test]
    fn batch_equals_streaming() {
        let pairs: Vec<(f64, f64)> = (0_i32..80)
            .map(|t| {
                let lead = 0.6 * (f64::from(t.max(1) - 1) * 0.4).sin();
                let wobble = 0.1 * f64::from(t % 3);
                (lead + wobble, (f64::from(t) * 0.4).sin())
            })
            .collect();
        let batch = GrangerCausality::new(30, 2).unwrap().batch(&pairs);
        let mut g = GrangerCausality::new(30, 2).unwrap();
        let streamed: Vec<_> = pairs.iter().copied().map(|pair| g.update(pair)).collect();
        assert_eq!(batch, streamed);
    }

    #[test]
    fn non_finite_input_returns_none() {
        let mut g = GrangerCausality::new(5, 1).unwrap();
        for bad in [(f64::NAN, 1.0), (1.0, f64::INFINITY)] {
            assert_eq!(g.update(bad), None);
        }
        // The rejected pairs must not count toward the warmup.
        for t in 0_i32..4 {
            let x = f64::from(t);
            assert_eq!(g.update((x, x * 0.5)), None);
        }
        assert!(g.update((4.0, 2.0)).is_some());
    }
}