wickra_core/indicators/
linreg_intercept.rs1use std::collections::VecDeque;
4
5use crate::error::{Error, Result};
6use crate::traits::Indicator;
7
/// Rolling linear-regression intercept over the last `period` values.
///
/// An ordinary least-squares line `y = a + b·x` is fitted to the samples in
/// the window, indexed `x = 0` for the oldest sample up to `x = period - 1`
/// for the newest. The indicator emits `a`, i.e. the fitted value at the
/// oldest sample of the window.
///
/// The sums that depend only on `x` are fixed by `period` and are computed
/// once at construction; `Σy` and `Σxy` are maintained as samples enter and
/// leave the window.
#[derive(Debug, Clone)]
pub struct LinRegIntercept {
    /// Number of samples in the regression window (the constructor rejects
    /// values below 2).
    period: usize,
    /// Most recent samples, oldest at the front (`x = 0`).
    window: VecDeque<f64>,
    /// `Σx` over `x = 0..period`, i.e. `period·(period − 1)/2`.
    sum_x: f64,
    /// Regression denominator `n·Σx² − (Σx)²`; constant for a given period.
    denom: f64,
    /// Running `Σy` over the samples currently in `window`.
    sum_y: f64,
    /// Running `Σ(x·y)` over the samples currently in `window`.
    sum_xy: f64,
}

44impl LinRegIntercept {
45 pub fn new(period: usize) -> Result<Self> {
51 if period < 2 {
52 return Err(Error::InvalidPeriod {
53 message: "linear regression intercept needs period >= 2",
54 });
55 }
56 let n = period as f64;
57 let sum_x = n * (n - 1.0) / 2.0;
58 let sum_xx = (n - 1.0) * n * (2.0 * n - 1.0) / 6.0;
59 Ok(Self {
60 period,
61 window: VecDeque::with_capacity(period),
62 sum_x,
63 denom: n * sum_xx - sum_x * sum_x,
64 sum_y: 0.0,
65 sum_xy: 0.0,
66 })
67 }
68
69 pub const fn period(&self) -> usize {
71 self.period
72 }
73}
74
75impl Indicator for LinRegIntercept {
76 type Input = f64;
77 type Output = f64;
78
79 fn update(&mut self, value: f64) -> Option<f64> {
80 if self.window.len() == self.period {
81 let y0 = self.window.pop_front().expect("non-empty");
82 self.sum_xy = self.sum_xy - self.sum_y + y0;
83 self.sum_y -= y0;
84 }
85 let k = self.window.len() as f64;
86 self.window.push_back(value);
87 self.sum_y += value;
88 self.sum_xy += k * value;
89
90 if self.window.len() < self.period {
91 return None;
92 }
93 let n = self.period as f64;
94 let slope = (n * self.sum_xy - self.sum_x * self.sum_y) / self.denom;
95 let intercept = (self.sum_y - slope * self.sum_x) / n;
96 Some(intercept)
97 }
98
99 fn reset(&mut self) {
100 self.window.clear();
101 self.sum_y = 0.0;
102 self.sum_xy = 0.0;
103 }
104
105 fn warmup_period(&self) -> usize {
106 self.period
107 }
108
109 fn is_ready(&self) -> bool {
110 self.window.len() == self.period
111 }
112
113 fn name(&self) -> &'static str {
114 "LINEARREG_INTERCEPT"
115 }
116}
117
#[cfg(test)]
mod tests {
    use super::*;
    use crate::traits::BatchExt;
    use approx::assert_relative_eq;

    #[test]
    fn rejects_short_period() {
        let built = LinRegIntercept::new(1);
        assert!(matches!(built, Err(Error::InvalidPeriod { .. })));
    }

    #[test]
    fn accessors_report_config() {
        let ind = LinRegIntercept::new(5).unwrap();
        assert_eq!(ind.period(), 5);
        assert_eq!(ind.name(), "LINEARREG_INTERCEPT");
        assert_eq!(ind.warmup_period(), 5);
        assert!(!ind.is_ready());
    }

    #[test]
    fn reference_value() {
        // With x = 0, 1, 2 the OLS fit of [1, 2, 9] is y = 0 + 4x
        // (residuals 1, -2, 1), so the intercept is 0.
        let mut ind = LinRegIntercept::new(3).unwrap();
        let outputs: Vec<Option<f64>> = ind.batch(&[1.0, 2.0, 9.0]);
        assert!(outputs[..2].iter().all(Option::is_none));
        assert_relative_eq!(outputs[2].unwrap(), 0.0, epsilon = 1e-9);
        assert!(ind.is_ready());
    }

    #[test]
    fn slides_and_tracks_a_shifted_line() {
        // Once the leading 1.0 slides out, the window [10, 12, 14] lies
        // exactly on y = 10 + 2x, so the intercept is 10.
        let mut ind = LinRegIntercept::new(3).unwrap();
        let outputs: Vec<Option<f64>> = ind.batch(&[1.0, 10.0, 12.0, 14.0]);
        assert_relative_eq!(outputs[3].unwrap(), 10.0, epsilon = 1e-9);
    }

    #[test]
    fn reset_clears_state() {
        let mut ind = LinRegIntercept::new(3).unwrap();
        let _ = ind.batch(&[1.0, 2.0, 9.0]);
        assert!(ind.is_ready());
        ind.reset();
        assert!(!ind.is_ready());
        // A fresh warm-up is required after reset.
        assert_eq!(ind.update(1.0), None);
    }
}