Skip to main content

wickra_core/indicators/
granger_causality.rs

//! Granger causality F-statistic: does series `b` help predict series `a`?
2
3use std::collections::VecDeque;
4
5use crate::error::{Error, Result};
6use crate::traits::Indicator;
7
/// Granger causality of `b` on `a` over a rolling window, as an F-statistic.
///
/// Each `update` takes one `(a, b)` pair. Over the trailing window of `period`
/// observations the indicator fits two autoregressions of `a` and compares them
/// with an F-test:
///
/// ```text
/// restricted:    aₜ = c + Σ φᵢ·aₜ₋ᵢ                       (a's own lags only)
/// unrestricted:  aₜ = c + Σ φᵢ·aₜ₋ᵢ + Σ ψᵢ·bₜ₋ᵢ          (+ b's lags)
/// F = ((RSSᵣ − RSSᵤ) / lag) / (RSSᵤ / (n − 2·lag − 1))
/// ```
///
/// Both sums run over `i = 1..=lag`. `n = period − lag` is the number of
/// regression rows, because the first `lag` points of the window only supply
/// lagged regressors. `n − 2·lag − 1` is therefore the unrestricted model's
/// residual degrees of freedom. [`GrangerCausality::new`] guarantees it is at
/// least `1`.
///
/// If adding `b`'s lags significantly reduces the residual sum of squares, `b`
/// **Granger-causes** `a`: past values of `b` carry information about the future
/// of `a` beyond what `a`'s own past holds. A **larger** F means stronger
/// predictive causality (lead–lag structure a stat-arb model can trade). A
/// value near `0` means `b` adds nothing. Note that Granger causality is purely
/// predictive — it is not structural cause and effect.
///
/// The statistic is `0` when a regression is degenerate — a collinear or flat
/// window makes the normal equations singular. The output is always `≥ 0`. It
/// grows without bound as `b`'s lags fit `a` ever more exactly.
///
/// Each `update` is `O(period · lag² + lag³)`, bounded by the fixed parameters.
///
/// # Example
///
/// ```
/// use wickra_core::{GrangerCausality, Indicator};
///
/// let mut g = GrangerCausality::new(60, 1).unwrap();
/// let mut last = None;
/// for t in 0..120 {
///     let drive = (f64::from(t) * 0.3).sin();
///     // a echoes b's previous value plus a deterministic wobble ⇒ b Granger-causes a.
///     let b = drive;
///     let a = 0.5 * (f64::from(t.max(1) - 1) * 0.3).sin() + 0.1 * (f64::from(t) * 0.9).cos();
///     last = g.update((a, b));
/// }
/// assert!(last.unwrap() >= 0.0);
/// ```
#[derive(Debug, Clone)]
pub struct GrangerCausality {
    /// Number of `(a, b)` pairs in the rolling window.
    period: usize,
    /// Autoregressive order: how many lags of each series enter the models.
    lag: usize,
    /// The most recent accepted (finite) `(a, b)` pairs, oldest first; holds
    /// at most `period` entries.
    window: VecDeque<(f64, f64)>,
}
54
55impl GrangerCausality {
56    /// Construct a new Granger causality test.
57    ///
58    /// `period` is the look-back window; `lag` is the autoregressive order
59    /// (number of own/cross lags in each model).
60    ///
61    /// # Errors
62    /// Returns [`Error::InvalidPeriod`] if `lag < 1` or if `period < 3·lag + 2`
63    /// (the smallest window that leaves the unrestricted regression at least one
64    /// residual degree of freedom).
65    pub fn new(period: usize, lag: usize) -> Result<Self> {
66        if lag < 1 {
67            return Err(Error::InvalidPeriod {
68                message: "granger causality needs lag >= 1",
69            });
70        }
71        if period < 3 * lag + 2 {
72            return Err(Error::InvalidPeriod {
73                message: "granger causality needs period >= 3*lag + 2",
74            });
75        }
76        Ok(Self {
77            period,
78            lag,
79            window: VecDeque::with_capacity(period),
80        })
81    }
82
83    /// Configured look-back window.
84    pub const fn period(&self) -> usize {
85        self.period
86    }
87
88    /// Configured autoregressive order.
89    pub const fn lag(&self) -> usize {
90        self.lag
91    }
92}
93
94impl Indicator for GrangerCausality {
95    type Input = (f64, f64);
96    type Output = f64;
97
98    fn update(&mut self, input: (f64, f64)) -> Option<f64> {
99        if !input.0.is_finite() || !input.1.is_finite() {
100            return None;
101        }
102        if self.window.len() == self.period {
103            self.window.pop_front();
104        }
105        self.window.push_back(input);
106        if self.window.len() < self.period {
107            return None;
108        }
109        let lag = self.lag;
110        let a: Vec<f64> = self.window.iter().map(|&(av, _)| av).collect();
111        let b: Vec<f64> = self.window.iter().map(|&(_, bv)| bv).collect();
112        let num_obs = self.period - lag;
113
114        let mut target = Vec::with_capacity(num_obs);
115        let mut restricted = Vec::with_capacity(num_obs);
116        let mut unrestricted = Vec::with_capacity(num_obs);
117        for k in 0..num_obs {
118            let now = lag + k;
119            target.push(a[now]);
120            let mut row_r = Vec::with_capacity(lag + 1);
121            row_r.push(1.0);
122            for back in 1..=lag {
123                row_r.push(a[now - back]);
124            }
125            let mut row_u = row_r.clone();
126            for back in 1..=lag {
127                row_u.push(b[now - back]);
128            }
129            restricted.push(row_r);
130            unrestricted.push(row_u);
131        }
132
133        let Some(rss_r) = ols_rss(&restricted, &target, lag + 1) else {
134            return Some(0.0);
135        };
136        let Some(rss_u) = ols_rss(&unrestricted, &target, 2 * lag + 1) else {
137            return Some(0.0);
138        };
139        let dof = (num_obs - (2 * lag + 1)) as f64;
140        let numerator = (rss_r - rss_u) / lag as f64;
141        let denominator = rss_u / dof;
142        Some((numerator / denominator).max(0.0))
143    }
144
145    fn reset(&mut self) {
146        self.window.clear();
147    }
148
149    fn warmup_period(&self) -> usize {
150        self.period
151    }
152
153    fn is_ready(&self) -> bool {
154        self.window.len() == self.period
155    }
156
157    fn name(&self) -> &'static str {
158        "GrangerCausality"
159    }
160}
161
162/// Residual sum of squares of the OLS fit of `target` on the design `rows`
163/// (each a length-`num_reg` regressor vector). Returns `None` if the normal
164/// equations are singular.
165fn ols_rss(rows: &[Vec<f64>], target: &[f64], num_reg: usize) -> Option<f64> {
166    let mut xtx = vec![vec![0.0; num_reg]; num_reg];
167    let mut xty = vec![0.0; num_reg];
168    for (row, &observed) in rows.iter().zip(target) {
169        for (ri, &left) in row.iter().enumerate() {
170            xty[ri] += left * observed;
171            for (ci, &right) in row.iter().enumerate() {
172                xtx[ri][ci] += left * right;
173            }
174        }
175    }
176    let theta = solve(xtx, xty)?;
177    let mut rss = 0.0;
178    for (row, &observed) in rows.iter().zip(target) {
179        let pred: f64 = row
180            .iter()
181            .zip(&theta)
182            .map(|(coeff, value)| coeff * value)
183            .sum();
184        let resid = observed - pred;
185        rss += resid * resid;
186    }
187    Some(rss)
188}
189
/// Solve the linear system `mat·x = rhs` by Gaussian elimination, returning
/// `None` if the matrix is (numerically) singular. `mat` is row-major and
/// square, with `rhs.len()` rows and columns.
///
/// Elimination runs without row exchanges. That is sound for the symmetric
/// positive-(semi)definite normal-equation matrices `XᵀX` this module builds:
/// each pivot is the residual sum of squares of that regressor after
/// projecting out the earlier ones.
///
/// Singularity is judged *relative* to the column's original diagonal entry,
/// so the verdict — and hence the F-statistic — does not change when the input
/// series are rescaled. An absolute cutoff would wrongly reject
/// well-conditioned systems built from tiny values. It would also accept the
/// rounding residue left by exactly collinear systems built from large values.
/// A non-finite pivot is treated as singular too.
fn solve(mut mat: Vec<Vec<f64>>, mut rhs: Vec<f64>) -> Option<Vec<f64>> {
    /// Smallest pivot, as a fraction of its column's original diagonal entry,
    /// that is still considered non-singular.
    const RELATIVE_PIVOT_TOL: f64 = 1e-12;

    let dim = rhs.len();
    // Diagonal magnitudes before elimination: the scale each pivot is judged against.
    let scale: Vec<f64> = mat
        .iter()
        .take(dim)
        .enumerate()
        .map(|(i, row)| row[i].abs())
        .collect();
    for col in 0..dim {
        let pivot = mat[col][col];
        if !pivot.is_finite() || pivot.abs() <= RELATIVE_PIVOT_TOL * scale[col] {
            return None;
        }
        // Borrow the pivot row immutably and the rows below it mutably,
        // instead of cloning the pivot row on every step.
        let (upper, lower) = mat.split_at_mut(col + 1);
        let pivot_row = &upper[col];
        for (offset, row) in lower.iter_mut().enumerate() {
            let factor = row[col] / pivot;
            for (cell, &above) in row.iter_mut().zip(pivot_row).skip(col) {
                *cell -= factor * above;
            }
            rhs[col + 1 + offset] -= factor * rhs[col];
        }
    }
    // Back-substitution on the now upper-triangular system.
    let mut sol = vec![0.0; dim];
    for row in (0..dim).rev() {
        let known: f64 = mat[row]
            .iter()
            .zip(&sol)
            .skip(row + 1)
            .map(|(coeff, value)| coeff * value)
            .sum();
        sol[row] = (rhs[row] - known) / mat[row][row];
    }
    Some(sol)
}
220
#[cfg(test)]
mod tests {
    use super::*;
    use crate::traits::BatchExt;

    /// Run `pairs` through a fresh `GrangerCausality::new(period, lag)` in batch
    /// mode and return the last statistic it emitted.
    fn final_statistic(period: usize, lag: usize, pairs: &[(f64, f64)]) -> f64 {
        GrangerCausality::new(period, lag)
            .unwrap()
            .batch(pairs)
            .into_iter()
            .flatten()
            .last()
            .unwrap()
    }

    /// 120 pairs where `a[t]` is driven by `b[t-1]` plus a small own wobble.
    fn b_leads_a_pairs() -> Vec<(f64, f64)> {
        let mut prev_drive = 0.0;
        (0..120)
            .map(|t| {
                let drive = (f64::from(t) * 0.3).sin() + 0.4 * (f64::from(t) * 0.11).cos();
                let a = 0.8 * prev_drive + 0.05 * (f64::from(t) * 0.7).sin();
                prev_drive = drive;
                (a, drive)
            })
            .collect()
    }

    #[test]
    fn rejects_bad_parameters() {
        assert!(GrangerCausality::new(10, 0).is_err()); // lag must be >= 1
        assert!(GrangerCausality::new(4, 1).is_err()); // period must be >= 3*lag + 2
        assert!(GrangerCausality::new(5, 1).is_ok());
    }

    #[test]
    fn accessors_and_metadata() {
        let g = GrangerCausality::new(60, 2).unwrap();
        assert_eq!(g.period(), 60);
        assert_eq!(g.lag(), 2);
        assert_eq!(g.warmup_period(), 60);
        assert_eq!(g.name(), "GrangerCausality");
        assert!(!g.is_ready());
    }

    #[test]
    fn warmup_returns_none() {
        let mut g = GrangerCausality::new(5, 1).unwrap();
        for t in 0..4 {
            assert_eq!(g.update((f64::from(t), f64::from(t) * 0.5)), None);
        }
        assert!(g.update((4.0, 2.0)).is_some());
        assert!(g.is_ready());
    }

    #[test]
    fn b_leading_a_has_positive_statistic() {
        // a[t] is driven by b[t-1] plus a little of its own past ⇒ b helps.
        let last = final_statistic(60, 1, &b_leads_a_pairs());
        assert!(last > 1.0, "F {last}");
    }

    #[test]
    fn statistic_is_invariant_to_rescaling() {
        // Quoting both series in different units must not change how much b's
        // lags help predict a: the F-statistic is scale-invariant.
        let pairs = b_leads_a_pairs();
        let base = final_statistic(60, 1, &pairs);
        for &scale in &[1e-3_f64, 1e3] {
            let scaled: Vec<(f64, f64)> = pairs
                .iter()
                .map(|&(a, b)| (a * scale, b * scale))
                .collect();
            let stat = final_statistic(60, 1, &scaled);
            assert!(
                (stat - base).abs() <= 1e-6 * base.abs().max(1.0),
                "scale {scale}: F {stat} vs {base}"
            );
        }
    }

    #[test]
    fn constant_b_is_singular_and_returns_zero() {
        // b is constant ⇒ its lag columns are collinear with the intercept ⇒
        // the unrestricted normal equations are singular ⇒ 0.
        let pairs: Vec<(f64, f64)> = (0..40)
            .map(|t| (f64::from(t) + (f64::from(t) * 0.6).sin(), 3.0))
            .collect();
        assert_eq!(final_statistic(20, 1, &pairs), 0.0);
    }

    #[test]
    fn constant_a_restricted_singular_returns_zero() {
        // a is constant ⇒ its own lag columns are collinear with the intercept
        // ⇒ the restricted normal equations are singular ⇒ 0.
        let pairs: Vec<(f64, f64)> = (0..40).map(|t| (5.0, (f64::from(t) * 0.4).sin())).collect();
        assert_eq!(final_statistic(20, 1, &pairs), 0.0);
    }

    #[test]
    fn reset_clears_state() {
        let mut g = GrangerCausality::new(8, 1).unwrap();
        for t in 0..12 {
            g.update((
                f64::from(t) + (f64::from(t) * 0.7).sin(),
                (f64::from(t) * 0.3).cos(),
            ));
        }
        assert!(g.is_ready());
        g.reset();
        assert!(!g.is_ready());
        assert_eq!(g.update((1.0, 1.0)), None);
    }

    #[test]
    fn batch_equals_streaming() {
        let pairs: Vec<(f64, f64)> = (0..80)
            .map(|t| {
                let b = (f64::from(t) * 0.4).sin();
                (
                    0.6 * (f64::from(t.max(1) - 1) * 0.4).sin() + 0.1 * f64::from(t % 3),
                    b,
                )
            })
            .collect();
        let batch = GrangerCausality::new(30, 2).unwrap().batch(&pairs);
        let mut g = GrangerCausality::new(30, 2).unwrap();
        let streamed: Vec<_> = pairs.iter().map(|p| g.update(*p)).collect();
        assert_eq!(batch, streamed);
    }

    #[test]
    fn non_finite_input_returns_none() {
        let mut g = GrangerCausality::new(5, 1).unwrap();
        assert_eq!(g.update((f64::NAN, 1.0)), None);
        assert_eq!(g.update((1.0, f64::INFINITY)), None);
        // The rejected ticks leave no trace: a fresh window still warms up.
        for t in 0..4 {
            assert_eq!(g.update((f64::from(t), f64::from(t) * 0.5)), None);
        }
        assert!(g.update((4.0, 2.0)).is_some());
    }
}