Skip to main content

wickra_core/indicators/
cointegration.rs

1//! Cointegration — rolling Engle–Granger hedge ratio plus an ADF stationarity test.
2
3use std::collections::VecDeque;
4
5use crate::error::{Error, Result};
6use crate::traits::Indicator;
7
/// One reading of [`Cointegration`] for the current window.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct CointegrationOutput {
    /// Hedge ratio `β` from the Engle–Granger first step: the slope of the
    /// rolling OLS regression of `a` on `b`. Falls back to `0` when `b` has no
    /// usable variance over the window.
    pub hedge_ratio: f64,
    /// Spread of the newest pair, i.e. its regression residual
    /// `a − (α + β·b)`.
    pub spread: f64,
    /// Augmented Dickey–Fuller `t`-statistic of the window's spread series.
    /// The **more negative** it is, the stronger the evidence of mean
    /// reversion (cointegration). Because the spread is an estimated residual,
    /// compare against Engle–Granger (MacKinnon) critical values — about
    /// `−3.34` at 5 % asymptotically for a pair, and more negative for short
    /// windows — rather than the ordinary ADF table. Reported as `0` whenever
    /// the statistic is undefined, e.g. an all-zero spread, a singular
    /// regression, or an exact fit with zero residual variance.
    pub adf_stat: f64,
}
21
22/// Rolling cointegration test for a pair of assets (Engle–Granger two-step).
23///
24/// Each `update` receives one `(a, b)` pair (price levels, or log-levels if you
25/// prefer). Over the trailing window of `period` pairs the indicator:
26///
27/// 1. fits the **hedge ratio** `β` (and intercept `α`) by ordinary least
28///    squares of `a` on `b`, and forms the **spread** `eₜ = aₜ − (α + β·bₜ)`;
29/// 2. runs an **augmented Dickey–Fuller** test (no constant, no trend, with
30///    `adf_lags` lagged differences) on the spread series and reports its
31///    `t`-statistic.
32///
33/// A strongly negative ADF statistic means the spread reverts to its mean — the
34/// pair is cointegrated and the spread is tradeable. A statistic near zero
35/// means the spread wanders like a random walk (no cointegration). This is the
36/// classic pairs-trading screen: `β` tells you the hedge size, the spread is
37/// what you trade, and the ADF statistic tells you whether it is worth trading.
38///
39/// Each `update` is `O(period + adf_lags³)`: the hedge ratio is maintained from
40/// running sums, while the spread series and the small ADF regression are
41/// recomputed over the window — both bounded by the fixed parameters, not the
42/// series length.
43///
44/// # Example
45///
46/// ```
47/// use wickra_core::{Cointegration, Indicator};
48///
49/// let mut c = Cointegration::new(30, 1).unwrap();
50/// let mut last = None;
51/// for t in 0..60 {
52///     let b = 100.0 + f64::from(t);
53///     // `a` tracks 2·b with a small mean-reverting wobble ⇒ cointegrated.
54///     let a = 2.0 * b + 5.0 + 0.5 * (f64::from(t) * 0.7).sin();
55///     last = c.update((a, b));
56/// }
57/// let out = last.unwrap();
58/// assert!((out.hedge_ratio - 2.0).abs() < 0.1);
59/// assert!(out.adf_stat < 0.0); // mean-reverting spread
60/// ```
61#[derive(Debug, Clone)]
62pub struct Cointegration {
63    period: usize,
64    adf_lags: usize,
65    window: VecDeque<(f64, f64)>,
66    sum_a: f64,
67    sum_b: f64,
68    sum_bb: f64,
69    sum_ab: f64,
70}
71
72impl Cointegration {
73    /// Construct a new rolling cointegration test.
74    ///
75    /// `period` is the look-back window; `adf_lags` is the number of lagged
76    /// differences in the augmented Dickey–Fuller regression (`0` is the plain
77    /// Dickey–Fuller test).
78    ///
79    /// # Errors
80    /// Returns [`Error::InvalidPeriod`] if `period < 2·adf_lags + 4`, which is
81    /// the smallest window that leaves the ADF regression at least one degree
82    /// of freedom.
83    pub fn new(period: usize, adf_lags: usize) -> Result<Self> {
84        let min_period = 2 * adf_lags + 4;
85        if period < min_period {
86            return Err(Error::InvalidPeriod {
87                message: "cointegration needs period >= 2*adf_lags + 4",
88            });
89        }
90        Ok(Self {
91            period,
92            adf_lags,
93            window: VecDeque::with_capacity(period),
94            sum_a: 0.0,
95            sum_b: 0.0,
96            sum_bb: 0.0,
97            sum_ab: 0.0,
98        })
99    }
100
101    /// Look-back window length.
102    pub const fn period(&self) -> usize {
103        self.period
104    }
105
106    /// Number of lagged differences in the ADF regression.
107    pub const fn adf_lags(&self) -> usize {
108        self.adf_lags
109    }
110}
111
112impl Indicator for Cointegration {
113    /// `(a, b)` price pair.
114    type Input = (f64, f64);
115    type Output = CointegrationOutput;
116
117    fn update(&mut self, input: (f64, f64)) -> Option<CointegrationOutput> {
118        let (a, b) = input;
119        if self.window.len() == self.period {
120            let (oa, ob) = self.window.pop_front().expect("non-empty");
121            self.sum_a -= oa;
122            self.sum_b -= ob;
123            self.sum_bb -= ob * ob;
124            self.sum_ab -= oa * ob;
125        }
126        self.window.push_back((a, b));
127        self.sum_a += a;
128        self.sum_b += b;
129        self.sum_bb += b * b;
130        self.sum_ab += a * b;
131        if self.window.len() < self.period {
132            return None;
133        }
134        let n = self.period as f64;
135        let mean_a = self.sum_a / n;
136        let mean_b = self.sum_b / n;
137        let var_b = (self.sum_bb / n - mean_b * mean_b).max(0.0);
138        let (hedge_ratio, intercept) = if var_b == 0.0 {
139            // A flat `b` window has no defined slope; fall back to a level shift.
140            (0.0, mean_a)
141        } else {
142            let cov = self.sum_ab / n - mean_a * mean_b;
143            let beta = cov / var_b;
144            (beta, mean_a - beta * mean_b)
145        };
146        // Build the spread (residual) series over the window, oldest → newest.
147        let spreads: Vec<f64> = self
148            .window
149            .iter()
150            .map(|&(ai, bi)| ai - (intercept + hedge_ratio * bi))
151            .collect();
152        let spread = *spreads.last().expect("window is full");
153        let adf_stat = adf_no_constant(&spreads, self.adf_lags);
154        Some(CointegrationOutput {
155            hedge_ratio,
156            spread,
157            adf_stat,
158        })
159    }
160
161    fn reset(&mut self) {
162        self.window.clear();
163        self.sum_a = 0.0;
164        self.sum_b = 0.0;
165        self.sum_bb = 0.0;
166        self.sum_ab = 0.0;
167    }
168
169    fn warmup_period(&self) -> usize {
170        self.period
171    }
172
173    fn is_ready(&self) -> bool {
174        self.window.len() == self.period
175    }
176
177    fn name(&self) -> &'static str {
178        "Cointegration"
179    }
180}
181
/// Solve the small square linear system `mat · x = rhs` by Gaussian
/// elimination without row exchanges, followed by back substitution.
///
/// `mat` is row-major and is consumed (it is overwritten while eliminating).
/// Returns `None` as soon as a pivot has magnitude below `1e-12`, i.e. when the
/// system is treated as numerically singular; note the threshold is absolute,
/// not relative to the scale of `mat`.
fn solve(mut mat: Vec<Vec<f64>>, mut rhs: Vec<f64>) -> Option<Vec<f64>> {
    let dim = rhs.len();

    // Forward elimination: clear every entry below the diagonal.
    for k in 0..dim {
        let diag = mat[k][k];
        if diag.abs() < 1e-12 {
            return None;
        }
        // Split so the pivot row can be read while the rows below are edited.
        let (settled, below) = mat.split_at_mut(k + 1);
        let pivot = &settled[k];
        for (offset, target) in below.iter_mut().take(dim - (k + 1)).enumerate() {
            let factor = target[k] / diag;
            target
                .iter_mut()
                .zip(pivot.iter())
                .skip(k)
                .for_each(|(cell, &p)| *cell -= factor * p);
            rhs[k + 1 + offset] -= factor * rhs[k];
        }
    }

    // Back substitution on the upper-triangular system.
    let mut x = vec![0.0; dim];
    for i in (0..dim).rev() {
        let tail: f64 = mat[i][i + 1..]
            .iter()
            .zip(&x[i + 1..])
            .map(|(coeff, value)| coeff * value)
            .sum();
        x[i] = (rhs[i] - tail) / mat[i][i];
    }
    Some(x)
}
214
215/// Augmented Dickey–Fuller `t`-statistic on `series`, with `lags` lagged
216/// differences and **no** constant or trend term (the Engle–Granger residual
217/// form). Returns `0.0` when the regression is degenerate.
218///
219/// The regression is `Δeₜ = ρ·eₜ₋₁ + Σ γᵢ·Δeₜ₋ᵢ + εₜ`; the reported statistic
220/// is `ρ̂ / se(ρ̂)`.
221fn adf_no_constant(series: &[f64], lags: usize) -> f64 {
222    let len = series.len();
223    let num_reg = lags + 1; // regressors: eₜ₋₁ plus `lags` lagged differences
224    let first = lags + 1; // first usable observation index
225    if len <= first {
226        return 0.0;
227    }
228    let num_obs = len - first;
229    if num_obs <= num_reg {
230        return 0.0; // need at least one residual degree of freedom
231    }
232    let regressors = |idx: usize| -> Vec<f64> {
233        let mut row = vec![0.0; num_reg];
234        row[0] = series[idx - 1];
235        for lag in 1..=lags {
236            row[lag] = series[idx - lag] - series[idx - lag - 1];
237        }
238        row
239    };
240    let mut xtx = vec![vec![0.0; num_reg]; num_reg];
241    let mut xty = vec![0.0; num_reg];
242    for idx in first..len {
243        let diff = series[idx] - series[idx - 1];
244        let row = regressors(idx);
245        for (ri, &left) in row.iter().enumerate() {
246            xty[ri] += left * diff;
247            for (ci, &right) in row.iter().enumerate() {
248                xtx[ri][ci] += left * right;
249            }
250        }
251    }
252    let Some(theta) = solve(xtx.clone(), xty) else {
253        return 0.0;
254    };
255    let rho = theta[0];
256    let mut rss = 0.0;
257    for idx in first..len {
258        let diff = series[idx] - series[idx - 1];
259        let pred: f64 = regressors(idx)
260            .iter()
261            .zip(&theta)
262            .map(|(coeff, value)| coeff * value)
263            .sum();
264        let resid = diff - pred;
265        rss += resid * resid;
266    }
267    let dof = (num_obs - num_reg) as f64;
268    let sigma2 = rss / dof;
269    // (XᵀX)⁻¹₀₀ from solving XᵀX·x = e₀. `xtx` is the same matrix the first
270    // solve already factored successfully, so this one cannot be singular.
271    let mut unit = vec![0.0; num_reg];
272    unit[0] = 1.0;
273    let inverse = solve(xtx, unit).expect("xtx is non-singular: the coefficient solve succeeded");
274    let var_rho = sigma2 * inverse[0];
275    if var_rho <= 0.0 {
276        return 0.0;
277    }
278    rho / var_rho.sqrt()
279}
280
#[cfg(test)]
mod tests {
    use super::*;
    use crate::traits::BatchExt;
    use approx::assert_relative_eq;

    #[test]
    fn rejects_too_small_period() {
        // period must be >= 2*lags + 4.
        assert!(Cointegration::new(3, 0).is_err()); // needs >= 4
        assert!(Cointegration::new(4, 0).is_ok());
        assert!(Cointegration::new(5, 1).is_err()); // needs >= 6
        assert!(Cointegration::new(6, 1).is_ok());
    }

    #[test]
    fn accessors_and_metadata() {
        let c = Cointegration::new(30, 2).unwrap();
        assert_eq!(c.period(), 30);
        assert_eq!(c.adf_lags(), 2);
        assert_eq!(c.warmup_period(), 30);
        assert_eq!(c.name(), "Cointegration");
    }

    #[test]
    fn warmup_emits_none_until_window_is_full() {
        // The first `period - 1` updates produce nothing; the next one does.
        let mut c = Cointegration::new(6, 1).unwrap();
        for t in 0..5 {
            let b = 10.0 + f64::from(t);
            assert_eq!(c.update((2.0 * b + f64::from(t).sin(), b)), None);
            assert!(!c.is_ready());
        }
        let b = 15.0;
        assert!(c.update((2.0 * b + 5.0_f64.sin(), b)).is_some());
        assert!(c.is_ready());
    }

    #[test]
    fn solve_handles_regular_and_singular_systems() {
        // [[4,2,0],[2,5,1],[0,1,3]] · [1,-1,2] = [2,-1,5].
        let x = solve(
            vec![vec![4.0, 2.0, 0.0], vec![2.0, 5.0, 1.0], vec![0.0, 1.0, 3.0]],
            vec![2.0, -1.0, 5.0],
        )
        .unwrap();
        for (got, want) in x.iter().zip([1.0, -1.0, 2.0]) {
            assert!((got - want).abs() < 1e-12, "got {got}, want {want}");
        }
        // Rank-deficient matrix ⇒ a zero pivot ⇒ None.
        assert!(solve(vec![vec![1.0, 2.0], vec![2.0, 4.0]], vec![1.0, 2.0]).is_none());
    }

    #[test]
    fn adf_guards_and_degenerate_spread() {
        // Series too short for any observation ⇒ 0.
        assert_eq!(adf_no_constant(&[1.0], 1), 0.0);
        // Long enough but too few degrees of freedom ⇒ 0.
        assert_eq!(adf_no_constant(&[1.0, 2.0, 3.0], 1), 0.0);
        // A perfect deterministic AR(1) spread (eₜ = 0.5·eₜ₋₁) is fit exactly,
        // so the residual variance — and hence the t-statistic — is 0.
        let geom: Vec<f64> = (0..8).map(|t| 0.5_f64.powi(t)).collect();
        assert_eq!(adf_no_constant(&geom, 0), 0.0);
    }

    #[test]
    fn recovers_hedge_ratio() {
        // a = 2·b + 5 + small wobble ⇒ β ≈ 2.
        let pairs: Vec<(f64, f64)> = (0..60)
            .map(|t| {
                let b = 100.0 + f64::from(t);
                let a = 2.0 * b + 5.0 + 0.4 * (f64::from(t) * 0.9).sin();
                (a, b)
            })
            .collect();
        let out = Cointegration::new(30, 1)
            .unwrap()
            .batch(&pairs)
            .into_iter()
            .flatten()
            .last()
            .unwrap();
        assert!(
            (out.hedge_ratio - 2.0).abs() < 0.1,
            "beta {}",
            out.hedge_ratio
        );
    }

    #[test]
    fn stationary_spread_is_strongly_negative() {
        // A clean mean-reverting (sinusoidal) spread ⇒ very negative ADF.
        let pairs: Vec<(f64, f64)> = (0..80)
            .map(|t| {
                let b = 50.0 + 0.5 * f64::from(t);
                let a = 2.0 * b + 1.0 + 0.5 * (f64::from(t) * 0.6).sin();
                (a, b)
            })
            .collect();
        let out = Cointegration::new(40, 1)
            .unwrap()
            .batch(&pairs)
            .into_iter()
            .flatten()
            .last()
            .unwrap();
        assert!(out.adf_stat < -2.0, "adf {}", out.adf_stat);
    }

    #[test]
    fn perfect_cointegration_has_zero_spread_and_defined_ratio() {
        // a = 2·b + 5 exactly ⇒ residuals all zero ⇒ ADF degenerate ⇒ 0.
        let pairs: Vec<(f64, f64)> = (0..40)
            .map(|t| {
                let b = 100.0 + f64::from(t);
                (2.0 * b + 5.0, b)
            })
            .collect();
        let out = Cointegration::new(20, 1)
            .unwrap()
            .batch(&pairs)
            .into_iter()
            .flatten()
            .last()
            .unwrap();
        assert_relative_eq!(out.hedge_ratio, 2.0, epsilon = 1e-9);
        assert_relative_eq!(out.spread, 0.0, epsilon = 1e-6);
        assert_relative_eq!(out.adf_stat, 0.0, epsilon = 1e-12);
    }

    #[test]
    fn flat_b_falls_back_to_level() {
        // Constant b ⇒ no slope ⇒ hedge ratio 0, spread = a − mean(a).
        let pairs: Vec<(f64, f64)> = (0..20)
            .map(|t| (10.0 + 0.3 * (f64::from(t) * 0.5).sin(), 7.0))
            .collect();
        let out = Cointegration::new(10, 0)
            .unwrap()
            .batch(&pairs)
            .into_iter()
            .flatten()
            .last()
            .unwrap();
        assert_relative_eq!(out.hedge_ratio, 0.0, epsilon = 1e-12);
        // The final window is pairs[10..20]; its intercept is the mean of `a`.
        let window_mean = pairs[10..].iter().map(|&(a, _)| a).sum::<f64>() / 10.0;
        assert_relative_eq!(out.spread, pairs[19].0 - window_mean, epsilon = 1e-9);
    }

    #[test]
    fn plain_dickey_fuller_lags_zero() {
        // Exercise the lags = 0 path (1×1 ADF system).
        let pairs: Vec<(f64, f64)> = (0..40)
            .map(|t| {
                let b = 20.0 + 0.4 * f64::from(t);
                let a = 1.5 * b + 0.6 * (f64::from(t) * 0.7).sin();
                (a, b)
            })
            .collect();
        let out = Cointegration::new(20, 0)
            .unwrap()
            .batch(&pairs)
            .into_iter()
            .flatten()
            .last()
            .unwrap();
        assert!((out.hedge_ratio - 1.5).abs() < 0.1);
        assert!(out.adf_stat < 0.0);
    }

    #[test]
    fn reset_clears_state() {
        let mut c = Cointegration::new(10, 1).unwrap();
        for t in 0..20 {
            let b = 100.0 + f64::from(t);
            c.update((2.0 * b + (f64::from(t) * 0.5).sin(), b));
        }
        assert!(c.is_ready());
        c.reset();
        assert!(!c.is_ready());
        assert_eq!(c.update((1.0, 1.0)), None);
    }

    #[test]
    fn batch_equals_streaming() {
        let pairs: Vec<(f64, f64)> = (0..80)
            .map(|t| {
                let b = 30.0 + 0.7 * f64::from(t);
                let a = 1.8 * b + 2.0 + 0.5 * (f64::from(t) * 0.4).sin();
                (a, b)
            })
            .collect();
        let batch = Cointegration::new(25, 2).unwrap().batch(&pairs);
        let mut c = Cointegration::new(25, 2).unwrap();
        let streamed: Vec<_> = pairs.iter().map(|p| c.update(*p)).collect();
        assert_eq!(batch, streamed);
    }
}