1use std::collections::VecDeque;
4
5use crate::error::{Error, Result};
6use crate::traits::Indicator;
7
/// One reading produced by the rolling cointegration indicator once its window is full.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct CointegrationOutput {
    /// Slope of the least-squares fit `a ≈ intercept + hedge_ratio * b` over the
    /// window; `0.0` when `b` shows no variance in the window.
    pub hedge_ratio: f64,
    /// Residual of the newest pair in the window:
    /// `a - (intercept + hedge_ratio * b)`.
    pub spread: f64,
    /// t-statistic of the lagged-level coefficient in a no-constant augmented
    /// Dickey–Fuller regression on the window's spreads. More negative values
    /// point to a mean-reverting spread; `0.0` is reported when that regression
    /// is degenerate or cannot be formed.
    pub adf_stat: f64,
}
21
22#[derive(Debug, Clone)]
62pub struct Cointegration {
63 period: usize,
64 adf_lags: usize,
65 window: VecDeque<(f64, f64)>,
66 sum_a: f64,
67 sum_b: f64,
68 sum_bb: f64,
69 sum_ab: f64,
70}
71
72impl Cointegration {
73 pub fn new(period: usize, adf_lags: usize) -> Result<Self> {
84 let min_period = 2 * adf_lags + 4;
85 if period < min_period {
86 return Err(Error::InvalidPeriod {
87 message: "cointegration needs period >= 2*adf_lags + 4",
88 });
89 }
90 Ok(Self {
91 period,
92 adf_lags,
93 window: VecDeque::with_capacity(period),
94 sum_a: 0.0,
95 sum_b: 0.0,
96 sum_bb: 0.0,
97 sum_ab: 0.0,
98 })
99 }
100
101 pub const fn period(&self) -> usize {
103 self.period
104 }
105
106 pub const fn adf_lags(&self) -> usize {
108 self.adf_lags
109 }
110}
111
112impl Indicator for Cointegration {
113 type Input = (f64, f64);
115 type Output = CointegrationOutput;
116
117 fn update(&mut self, input: (f64, f64)) -> Option<CointegrationOutput> {
118 let (a, b) = input;
119 if self.window.len() == self.period {
120 let (oa, ob) = self.window.pop_front().expect("non-empty");
121 self.sum_a -= oa;
122 self.sum_b -= ob;
123 self.sum_bb -= ob * ob;
124 self.sum_ab -= oa * ob;
125 }
126 self.window.push_back((a, b));
127 self.sum_a += a;
128 self.sum_b += b;
129 self.sum_bb += b * b;
130 self.sum_ab += a * b;
131 if self.window.len() < self.period {
132 return None;
133 }
134 let n = self.period as f64;
135 let mean_a = self.sum_a / n;
136 let mean_b = self.sum_b / n;
137 let var_b = (self.sum_bb / n - mean_b * mean_b).max(0.0);
138 let (hedge_ratio, intercept) = if var_b == 0.0 {
139 (0.0, mean_a)
141 } else {
142 let cov = self.sum_ab / n - mean_a * mean_b;
143 let beta = cov / var_b;
144 (beta, mean_a - beta * mean_b)
145 };
146 let spreads: Vec<f64> = self
148 .window
149 .iter()
150 .map(|&(ai, bi)| ai - (intercept + hedge_ratio * bi))
151 .collect();
152 let spread = *spreads.last().expect("window is full");
153 let adf_stat = adf_no_constant(&spreads, self.adf_lags);
154 Some(CointegrationOutput {
155 hedge_ratio,
156 spread,
157 adf_stat,
158 })
159 }
160
161 fn reset(&mut self) {
162 self.window.clear();
163 self.sum_a = 0.0;
164 self.sum_b = 0.0;
165 self.sum_bb = 0.0;
166 self.sum_ab = 0.0;
167 }
168
169 fn warmup_period(&self) -> usize {
170 self.period
171 }
172
173 fn is_ready(&self) -> bool {
174 self.window.len() == self.period
175 }
176
177 fn name(&self) -> &'static str {
178 "Cointegration"
179 }
180}
181
/// Solves the square linear system `mat · x = rhs` by Gaussian elimination
/// without row exchanges, followed by back substitution.
///
/// `mat` is expected to be `rhs.len()` × `rhs.len()`. Both arguments are taken by
/// value and used as scratch space. Returns `None` as soon as an eliminated
/// diagonal entry has magnitude below `1e-12`, which callers treat as a singular
/// system.
///
/// NOTE(review): the `1e-12` threshold is absolute rather than relative to the
/// matrix scale, so a well-conditioned but tiny-valued system is also rejected.
fn solve(mut mat: Vec<Vec<f64>>, mut rhs: Vec<f64>) -> Option<Vec<f64>> {
    let dim = rhs.len();

    // Forward elimination: clear every entry below the diagonal, column by column.
    for k in 0..dim {
        // Borrow the pivot row immutably and the rows beneath it mutably, so the
        // pivot row never needs to be copied.
        let (through_pivot, below) = mat.split_at_mut(k + 1);
        let pivot_row = &through_pivot[k];
        let pivot = pivot_row[k];
        if pivot.abs() < 1e-12 {
            return None;
        }
        let pivot_rhs = rhs[k];
        for (offset, row) in below.iter_mut().take(dim - k - 1).enumerate() {
            let factor = row[k] / pivot;
            for (cell, &above) in row[k..].iter_mut().zip(&pivot_row[k..]) {
                *cell -= factor * above;
            }
            rhs[k + 1 + offset] -= factor * pivot_rhs;
        }
    }

    // Back substitution on the resulting upper-triangular system.
    let mut sol = vec![0.0; dim];
    for i in (0..dim).rev() {
        let known: f64 = mat[i][i + 1..]
            .iter()
            .zip(&sol[i + 1..])
            .map(|(coeff, value)| coeff * value)
            .sum();
        sol[i] = (rhs[i] - known) / mat[i][i];
    }
    Some(sol)
}
214
215fn adf_no_constant(series: &[f64], lags: usize) -> f64 {
222 let len = series.len();
223 let num_reg = lags + 1; let first = lags + 1; if len <= first {
226 return 0.0;
227 }
228 let num_obs = len - first;
229 if num_obs <= num_reg {
230 return 0.0; }
232 let regressors = |idx: usize| -> Vec<f64> {
233 let mut row = vec![0.0; num_reg];
234 row[0] = series[idx - 1];
235 for lag in 1..=lags {
236 row[lag] = series[idx - lag] - series[idx - lag - 1];
237 }
238 row
239 };
240 let mut xtx = vec![vec![0.0; num_reg]; num_reg];
241 let mut xty = vec![0.0; num_reg];
242 for idx in first..len {
243 let diff = series[idx] - series[idx - 1];
244 let row = regressors(idx);
245 for (ri, &left) in row.iter().enumerate() {
246 xty[ri] += left * diff;
247 for (ci, &right) in row.iter().enumerate() {
248 xtx[ri][ci] += left * right;
249 }
250 }
251 }
252 let Some(theta) = solve(xtx.clone(), xty) else {
253 return 0.0;
254 };
255 let rho = theta[0];
256 let mut rss = 0.0;
257 for idx in first..len {
258 let diff = series[idx] - series[idx - 1];
259 let pred: f64 = regressors(idx)
260 .iter()
261 .zip(&theta)
262 .map(|(coeff, value)| coeff * value)
263 .sum();
264 let resid = diff - pred;
265 rss += resid * resid;
266 }
267 let dof = (num_obs - num_reg) as f64;
268 let sigma2 = rss / dof;
269 let mut unit = vec![0.0; num_reg];
272 unit[0] = 1.0;
273 let inverse = solve(xtx, unit).expect("xtx is non-singular: the coefficient solve succeeded");
274 let var_rho = sigma2 * inverse[0];
275 if var_rho <= 0.0 {
276 return 0.0;
277 }
278 rho / var_rho.sqrt()
279}
280
#[cfg(test)]
mod tests {
    use super::*;
    use crate::traits::BatchExt;
    use approx::assert_relative_eq;

    /// Builds `count` input pairs by evaluating `pair_at` at `t = 0, 1, …, count - 1`.
    fn make_pairs(count: i32, pair_at: impl Fn(f64) -> (f64, f64)) -> Vec<(f64, f64)> {
        (0..count).map(|t| pair_at(f64::from(t))).collect()
    }

    /// Runs a fresh indicator over `pairs` in batch mode and returns its last output.
    ///
    /// Panics if construction fails or the indicator never becomes ready.
    fn final_output(period: usize, adf_lags: usize, pairs: Vec<(f64, f64)>) -> CointegrationOutput {
        Cointegration::new(period, adf_lags)
            .unwrap()
            .batch(&pairs)
            .into_iter()
            .flatten()
            .last()
            .expect("indicator should produce at least one output")
    }

    #[test]
    fn rejects_too_small_period() {
        assert!(Cointegration::new(3, 0).is_err());
        assert!(Cointegration::new(4, 0).is_ok());
        assert!(Cointegration::new(5, 1).is_err());
        assert!(Cointegration::new(6, 1).is_ok());
    }

    #[test]
    fn accessors_and_metadata() {
        let c = Cointegration::new(30, 2).unwrap();
        assert_eq!(c.period(), 30);
        assert_eq!(c.adf_lags(), 2);
        assert_eq!(c.warmup_period(), 30);
        assert_eq!(c.name(), "Cointegration");
    }

    #[test]
    fn adf_guards_and_degenerate_spread() {
        assert_eq!(adf_no_constant(&[1.0], 1), 0.0);
        assert_eq!(adf_no_constant(&[1.0, 2.0, 3.0], 1), 0.0);
        // A geometric series is fitted exactly, so the residual variance is zero.
        let geom: Vec<f64> = (0..8).map(|t| 0.5_f64.powi(t)).collect();
        assert_eq!(adf_no_constant(&geom, 0), 0.0);
    }

    #[test]
    fn recovers_hedge_ratio() {
        let pairs = make_pairs(60, |t| {
            let b = 100.0 + t;
            (2.0 * b + 5.0 + 0.4 * (t * 0.9).sin(), b)
        });
        let out = final_output(30, 1, pairs);
        assert!(
            (out.hedge_ratio - 2.0).abs() < 0.1,
            "beta {}",
            out.hedge_ratio
        );
    }

    #[test]
    fn stationary_spread_is_strongly_negative() {
        let pairs = make_pairs(80, |t| {
            let b = 50.0 + 0.5 * t;
            (2.0 * b + 1.0 + 0.5 * (t * 0.6).sin(), b)
        });
        let out = final_output(40, 1, pairs);
        assert!(out.adf_stat < -2.0, "adf {}", out.adf_stat);
    }

    #[test]
    fn perfect_cointegration_has_zero_spread_and_defined_ratio() {
        let pairs = make_pairs(40, |t| {
            let b = 100.0 + t;
            (2.0 * b + 5.0, b)
        });
        let out = final_output(20, 1, pairs);
        assert_relative_eq!(out.hedge_ratio, 2.0, epsilon = 1e-9);
        assert_relative_eq!(out.spread, 0.0, epsilon = 1e-6);
        assert_relative_eq!(out.adf_stat, 0.0, epsilon = 1e-12);
    }

    #[test]
    fn flat_b_falls_back_to_level() {
        let pairs = make_pairs(20, |t| (10.0 + 0.3 * (t * 0.5).sin(), 7.0));
        let out = final_output(10, 0, pairs);
        assert_relative_eq!(out.hedge_ratio, 0.0, epsilon = 1e-12);
    }

    #[test]
    fn plain_dickey_fuller_lags_zero() {
        let pairs = make_pairs(40, |t| {
            let b = 20.0 + 0.4 * t;
            (1.5 * b + 0.6 * (t * 0.7).sin(), b)
        });
        let out = final_output(20, 0, pairs);
        assert!((out.hedge_ratio - 1.5).abs() < 0.1);
        assert!(out.adf_stat < 0.0);
    }

    #[test]
    fn emits_only_once_window_is_full() {
        let mut c = Cointegration::new(6, 1).unwrap();
        for t in 0..5 {
            let b = f64::from(t);
            assert_eq!(c.update((2.0 * b + 1.0, b)), None);
            assert!(!c.is_ready());
        }
        assert!(c.update((13.0, 6.0)).is_some());
        assert!(c.is_ready());
    }

    #[test]
    fn reset_clears_state() {
        let mut c = Cointegration::new(10, 1).unwrap();
        for t in 0..20 {
            let b = 100.0 + f64::from(t);
            c.update((2.0 * b + (f64::from(t) * 0.5).sin(), b));
        }
        assert!(c.is_ready());
        c.reset();
        assert!(!c.is_ready());
        assert_eq!(c.update((1.0, 1.0)), None);
    }

    #[test]
    fn batch_equals_streaming() {
        let pairs = make_pairs(80, |t| {
            let b = 30.0 + 0.7 * t;
            (1.8 * b + 2.0 + 0.5 * (t * 0.4).sin(), b)
        });
        let batch = Cointegration::new(25, 2).unwrap().batch(&pairs);
        let mut c = Cointegration::new(25, 2).unwrap();
        let streamed: Vec<_> = pairs.iter().map(|p| c.update(*p)).collect();
        assert_eq!(batch, streamed);
    }
}