1use std::collections::VecDeque;
4
5use crate::error::{Error, Result};
6use crate::traits::Indicator;
7
/// A single reading emitted by the rolling cointegration indicator.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct CointegrationOutput {
    /// Least-squares slope of `a` regressed on `b` over the current window
    /// (`0.0` when `b` shows no variation in the window).
    pub hedge_ratio: f64,
    /// Residual of the newest pair: `a - (intercept + hedge_ratio * b)`.
    pub spread: f64,
    /// Augmented Dickey–Fuller t-statistic of the window's spread series (no constant term);
    /// more negative values indicate a more strongly mean-reverting spread. `0.0` when the
    /// regression is degenerate.
    pub adf_stat: f64,
}
21
22#[derive(Debug, Clone)]
62pub struct Cointegration {
63 period: usize,
64 adf_lags: usize,
65 window: VecDeque<(f64, f64)>,
66 sum_a: f64,
67 sum_b: f64,
68 sum_bb: f64,
69 sum_ab: f64,
70}
71
72impl Cointegration {
73 pub fn new(period: usize, adf_lags: usize) -> Result<Self> {
84 let min_period = 2 * adf_lags + 4;
85 if period < min_period {
86 return Err(Error::InvalidPeriod {
87 message: "cointegration needs period >= 2*adf_lags + 4",
88 });
89 }
90 Ok(Self {
91 period,
92 adf_lags,
93 window: VecDeque::with_capacity(period),
94 sum_a: 0.0,
95 sum_b: 0.0,
96 sum_bb: 0.0,
97 sum_ab: 0.0,
98 })
99 }
100
101 pub const fn period(&self) -> usize {
103 self.period
104 }
105
106 pub const fn adf_lags(&self) -> usize {
108 self.adf_lags
109 }
110}
111
112impl Indicator for Cointegration {
113 type Input = (f64, f64);
115 type Output = CointegrationOutput;
116
117 fn update(&mut self, input: (f64, f64)) -> Option<CointegrationOutput> {
118 let (a, b) = input;
119 if !a.is_finite() || !b.is_finite() {
120 return None;
121 }
122 if self.window.len() == self.period {
123 let (oa, ob) = self.window.pop_front().expect("non-empty");
124 self.sum_a -= oa;
125 self.sum_b -= ob;
126 self.sum_bb -= ob * ob;
127 self.sum_ab -= oa * ob;
128 }
129 self.window.push_back((a, b));
130 self.sum_a += a;
131 self.sum_b += b;
132 self.sum_bb += b * b;
133 self.sum_ab += a * b;
134 if self.window.len() < self.period {
135 return None;
136 }
137 let n = self.period as f64;
138 let mean_a = self.sum_a / n;
139 let mean_b = self.sum_b / n;
140 let var_b = (self.sum_bb / n - mean_b * mean_b).max(0.0);
141 let (hedge_ratio, intercept) = if var_b == 0.0 {
142 (0.0, mean_a)
144 } else {
145 let cov = self.sum_ab / n - mean_a * mean_b;
146 let beta = cov / var_b;
147 (beta, mean_a - beta * mean_b)
148 };
149 let spreads: Vec<f64> = self
151 .window
152 .iter()
153 .map(|&(ai, bi)| ai - (intercept + hedge_ratio * bi))
154 .collect();
155 let spread = *spreads.last().expect("window is full");
156 let adf_stat = adf_no_constant(&spreads, self.adf_lags);
157 Some(CointegrationOutput {
158 hedge_ratio,
159 spread,
160 adf_stat,
161 })
162 }
163
164 fn reset(&mut self) {
165 self.window.clear();
166 self.sum_a = 0.0;
167 self.sum_b = 0.0;
168 self.sum_bb = 0.0;
169 self.sum_ab = 0.0;
170 }
171
172 fn warmup_period(&self) -> usize {
173 self.period
174 }
175
176 fn is_ready(&self) -> bool {
177 self.window.len() == self.period
178 }
179
180 fn name(&self) -> &'static str {
181 "Cointegration"
182 }
183}
184
/// Solves the square linear system `mat · x = rhs` by Gaussian elimination without row
/// exchanges.
///
/// `mat` is expected to hold `rhs.len()` rows of at least that many columns. Returns `None`
/// as soon as a diagonal pivot has magnitude below `1e-12`. This is an absolute threshold,
/// not one relative to the matrix scale.
fn solve(mut mat: Vec<Vec<f64>>, mut rhs: Vec<f64>) -> Option<Vec<f64>> {
    let dim = rhs.len();

    // Forward elimination to upper-triangular form.
    for col in 0..dim {
        // Split so the pivot row can be read while the rows below it are rewritten.
        let (done, rest) = mat.split_at_mut(col + 1);
        let pivot_row = &done[col];
        let pivot = pivot_row[col];
        if pivot.abs() < 1e-12 {
            return None;
        }
        for row_idx in col + 1..dim {
            let row = &mut rest[row_idx - col - 1];
            let factor = row[col] / pivot;
            for (cell, &above) in row[col..].iter_mut().zip(&pivot_row[col..]) {
                *cell -= factor * above;
            }
            rhs[row_idx] -= factor * rhs[col];
        }
    }

    // Back substitution, last unknown first.
    let mut sol = vec![0.0; dim];
    for row_idx in (0..dim).rev() {
        let coeffs = &mat[row_idx];
        let known: f64 = coeffs[row_idx + 1..]
            .iter()
            .zip(&sol[row_idx + 1..])
            .map(|(coeff, value)| coeff * value)
            .sum();
        sol[row_idx] = (rhs[row_idx] - known) / coeffs[row_idx];
    }
    Some(sol)
}
217
218fn adf_no_constant(series: &[f64], lags: usize) -> f64 {
225 let len = series.len();
226 let num_reg = lags + 1; let first = lags + 1; if len <= first {
229 return 0.0;
230 }
231 let num_obs = len - first;
232 if num_obs <= num_reg {
233 return 0.0; }
235 let regressors = |idx: usize| -> Vec<f64> {
236 let mut row = vec![0.0; num_reg];
237 row[0] = series[idx - 1];
238 for lag in 1..=lags {
239 row[lag] = series[idx - lag] - series[idx - lag - 1];
240 }
241 row
242 };
243 let mut xtx = vec![vec![0.0; num_reg]; num_reg];
244 let mut xty = vec![0.0; num_reg];
245 for idx in first..len {
246 let diff = series[idx] - series[idx - 1];
247 let row = regressors(idx);
248 for (ri, &left) in row.iter().enumerate() {
249 xty[ri] += left * diff;
250 for (ci, &right) in row.iter().enumerate() {
251 xtx[ri][ci] += left * right;
252 }
253 }
254 }
255 let Some(theta) = solve(xtx.clone(), xty) else {
256 return 0.0;
257 };
258 let rho = theta[0];
259 let mut rss = 0.0;
260 for idx in first..len {
261 let diff = series[idx] - series[idx - 1];
262 let pred: f64 = regressors(idx)
263 .iter()
264 .zip(&theta)
265 .map(|(coeff, value)| coeff * value)
266 .sum();
267 let resid = diff - pred;
268 rss += resid * resid;
269 }
270 let dof = (num_obs - num_reg) as f64;
271 let sigma2 = rss / dof;
272 let mut unit = vec![0.0; num_reg];
275 unit[0] = 1.0;
276 let inverse = solve(xtx, unit).expect("xtx is non-singular: the coefficient solve succeeded");
277 let var_rho = sigma2 * inverse[0];
278 if var_rho <= 0.0 {
279 return 0.0;
280 }
281 rho / var_rho.sqrt()
282}
283
#[cfg(test)]
mod tests {
    use super::*;
    use crate::traits::BatchExt;
    use approx::assert_relative_eq;

    /// Feeds `pairs` through a fresh indicator and returns the last output it emitted.
    fn last_output(
        period: usize,
        adf_lags: usize,
        pairs: Vec<(f64, f64)>,
    ) -> CointegrationOutput {
        Cointegration::new(period, adf_lags)
            .unwrap()
            .batch(&pairs)
            .into_iter()
            .flatten()
            .last()
            .expect("input longer than the warmup period must produce an output")
    }

    #[test]
    fn rejects_too_small_period() {
        // The minimum period is 2 * adf_lags + 4.
        assert!(Cointegration::new(3, 0).is_err());
        assert!(Cointegration::new(4, 0).is_ok());
        assert!(Cointegration::new(5, 1).is_err());
        assert!(Cointegration::new(6, 1).is_ok());
    }

    #[test]
    fn accessors_and_metadata() {
        let c = Cointegration::new(30, 2).unwrap();
        assert_eq!(c.period(), 30);
        assert_eq!(c.adf_lags(), 2);
        assert_eq!(c.warmup_period(), 30);
        assert_eq!(c.name(), "Cointegration");
    }

    #[test]
    fn adf_guards_and_degenerate_spread() {
        // Too short to form a single observation.
        assert_eq!(adf_no_constant(&[1.0], 1), 0.0);
        // Too few observations to leave a residual degree of freedom.
        assert_eq!(adf_no_constant(&[1.0, 2.0, 3.0], 1), 0.0);
        // A geometric decay is fitted exactly, so the residual variance is zero.
        let geom: Vec<f64> = (0..8).map(|t| 0.5_f64.powi(t)).collect();
        assert_eq!(adf_no_constant(&geom, 0), 0.0);
    }

    #[test]
    fn adf_matches_closed_form_dickey_fuller() {
        // With no lags: rho = Σ y[t-1]·Δy[t] / Σ y[t-1]² = -2/15 for this series, and the
        // t-statistic rho / sqrt(sigma² / Σ y[t-1]²) works out to -sqrt(12/101).
        let stat = adf_no_constant(&[1.0, 2.0, 1.0, 3.0, 2.0], 0);
        assert_relative_eq!(stat, -(12.0_f64 / 101.0).sqrt(), epsilon = 1e-12);
    }

    #[test]
    fn solve_recovers_exact_solution() {
        let mat = vec![
            vec![4.0, 2.0, 0.0],
            vec![2.0, 5.0, 1.0],
            vec![0.0, 1.0, 3.0],
        ];
        let sol = solve(mat, vec![8.0, 15.0, 11.0]).expect("matrix is non-singular");
        assert_relative_eq!(sol[0], 1.0, epsilon = 1e-12);
        assert_relative_eq!(sol[1], 2.0, epsilon = 1e-12);
        assert_relative_eq!(sol[2], 3.0, epsilon = 1e-12);
    }

    #[test]
    fn solve_rejects_singular_matrix() {
        assert_eq!(
            solve(vec![vec![1.0, 2.0], vec![2.0, 4.0]], vec![1.0, 2.0]),
            None
        );
    }

    #[test]
    fn recovers_hedge_ratio() {
        let pairs: Vec<(f64, f64)> = (0..60)
            .map(|t| {
                let b = 100.0 + f64::from(t);
                let a = 2.0 * b + 5.0 + 0.4 * (f64::from(t) * 0.9).sin();
                (a, b)
            })
            .collect();
        let out = last_output(30, 1, pairs);
        assert!(
            (out.hedge_ratio - 2.0).abs() < 0.1,
            "beta {}",
            out.hedge_ratio
        );
    }

    #[test]
    fn stationary_spread_is_strongly_negative() {
        let pairs: Vec<(f64, f64)> = (0..80)
            .map(|t| {
                let b = 50.0 + 0.5 * f64::from(t);
                let a = 2.0 * b + 1.0 + 0.5 * (f64::from(t) * 0.6).sin();
                (a, b)
            })
            .collect();
        let out = last_output(40, 1, pairs);
        assert!(out.adf_stat < -2.0, "adf {}", out.adf_stat);
    }

    #[test]
    fn perfect_cointegration_has_zero_spread_and_defined_ratio() {
        let pairs: Vec<(f64, f64)> = (0..40)
            .map(|t| {
                let b = 100.0 + f64::from(t);
                (2.0 * b + 5.0, b)
            })
            .collect();
        let out = last_output(20, 1, pairs);
        assert_relative_eq!(out.hedge_ratio, 2.0, epsilon = 1e-9);
        assert_relative_eq!(out.spread, 0.0, epsilon = 1e-6);
        // An all-zero spread makes the ADF regression degenerate, reported as 0.0.
        assert_relative_eq!(out.adf_stat, 0.0, epsilon = 1e-12);
    }

    #[test]
    fn flat_b_falls_back_to_level() {
        let pairs: Vec<(f64, f64)> = (0..20)
            .map(|t| (10.0 + 0.3 * (f64::from(t) * 0.5).sin(), 7.0))
            .collect();
        let out = last_output(10, 0, pairs);
        assert_relative_eq!(out.hedge_ratio, 0.0, epsilon = 1e-12);
    }

    #[test]
    fn plain_dickey_fuller_lags_zero() {
        let pairs: Vec<(f64, f64)> = (0..40)
            .map(|t| {
                let b = 20.0 + 0.4 * f64::from(t);
                let a = 1.5 * b + 0.6 * (f64::from(t) * 0.7).sin();
                (a, b)
            })
            .collect();
        let out = last_output(20, 0, pairs);
        assert!((out.hedge_ratio - 1.5).abs() < 0.1);
        assert!(out.adf_stat < 0.0);
    }

    #[test]
    fn reset_clears_state() {
        let mut c = Cointegration::new(10, 1).unwrap();
        for t in 0..20 {
            let b = 100.0 + f64::from(t);
            c.update((2.0 * b + (f64::from(t) * 0.5).sin(), b));
        }
        assert!(c.is_ready());
        c.reset();
        assert!(!c.is_ready());
        assert_eq!(c.update((1.0, 1.0)), None);
    }

    #[test]
    fn batch_equals_streaming() {
        let pairs: Vec<(f64, f64)> = (0..80)
            .map(|t| {
                let b = 30.0 + 0.7 * f64::from(t);
                let a = 1.8 * b + 2.0 + 0.5 * (f64::from(t) * 0.4).sin();
                (a, b)
            })
            .collect();
        let batch = Cointegration::new(25, 2).unwrap().batch(&pairs);
        let mut c = Cointegration::new(25, 2).unwrap();
        let streamed: Vec<_> = pairs.iter().map(|p| c.update(*p)).collect();
        assert_eq!(batch, streamed);
    }

    #[test]
    fn non_finite_input_returns_none() {
        let mut c = Cointegration::new(4, 0).unwrap();
        // Non-finite pairs are skipped and do not count toward the warmup.
        assert_eq!(c.update((f64::NAN, 1.0)), None);
        assert_eq!(c.update((1.0, f64::INFINITY)), None);
        assert_eq!(c.update((1.0, 2.0)), None);
        assert_eq!(c.update((2.0, 5.0)), None);
        assert_eq!(c.update((3.0, 7.0)), None);
        assert!(c.update((4.0, 11.0)).is_some());
    }
}