wickra_core/indicators/
linreg.rs1use std::collections::VecDeque;
4
5use crate::error::{Error, Result};
6use crate::traits::Indicator;
7
8#[derive(Debug, Clone)]
50pub struct LinearRegression {
51 period: usize,
52 window: VecDeque<f64>,
53 sum_x: f64,
55 denom: f64,
58 sum_y: f64,
60 sum_xy: f64,
63}
64
65impl LinearRegression {
66 pub fn new(period: usize) -> Result<Self> {
72 if period < 2 {
73 return Err(Error::InvalidPeriod {
74 message: "linear regression needs period >= 2",
75 });
76 }
77 let n = period as f64;
78 let sum_x = n * (n - 1.0) / 2.0;
80 let sum_xx = (n - 1.0) * n * (2.0 * n - 1.0) / 6.0;
81 Ok(Self {
82 period,
83 window: VecDeque::with_capacity(period),
84 sum_x,
85 denom: n * sum_xx - sum_x * sum_x,
86 sum_y: 0.0,
87 sum_xy: 0.0,
88 })
89 }
90
91 pub const fn period(&self) -> usize {
93 self.period
94 }
95}
96
97impl Indicator for LinearRegression {
98 type Input = f64;
99 type Output = f64;
100
101 fn update(&mut self, value: f64) -> Option<f64> {
102 if self.window.len() == self.period {
103 let y0 = self.window.pop_front().expect("non-empty");
108 self.sum_xy = self.sum_xy - self.sum_y + y0;
109 self.sum_y -= y0;
110 }
111 let k = self.window.len() as f64;
115 self.window.push_back(value);
116 self.sum_y += value;
117 self.sum_xy += k * value;
118
119 if self.window.len() < self.period {
120 return None;
121 }
122 let n = self.period as f64;
123 let slope = (n * self.sum_xy - self.sum_x * self.sum_y) / self.denom;
124 let intercept = (self.sum_y - slope * self.sum_x) / n;
125 Some(intercept + slope * (n - 1.0))
126 }
127
128 fn reset(&mut self) {
129 self.window.clear();
130 self.sum_y = 0.0;
131 self.sum_xy = 0.0;
132 }
133
134 fn warmup_period(&self) -> usize {
135 self.period
136 }
137
138 fn is_ready(&self) -> bool {
139 self.window.len() == self.period
140 }
141
142 fn name(&self) -> &'static str {
143 "LinearRegression"
144 }
145}
146
#[cfg(test)]
mod tests {
    use super::*;
    use crate::traits::BatchExt;
    use approx::assert_relative_eq;

    /// Independent O(n) least-squares fit of `window` (x = 0, 1, …) evaluated at the
    /// last x. This is the oracle the streaming implementation is compared against.
    fn naive_endpoint(window: &[f64]) -> f64 {
        let n = window.len() as f64;
        let (sum_x, sum_xx, sum_y, sum_xy) = window.iter().enumerate().fold(
            (0.0, 0.0, 0.0, 0.0),
            |(sx, sxx, sy, sxy), (i, &y)| {
                let x = i as f64;
                (sx + x, sxx + x * x, sy + y, sxy + x * y)
            },
        );
        let denom = n * sum_xx - sum_x * sum_x;
        let slope = (n * sum_xy - sum_x * sum_y) / denom;
        let intercept = (sum_y - slope * sum_x) / n;
        intercept + slope * (n - 1.0)
    }

    /// Streams `prices` one at a time. From the first full window onward, it asserts that
    /// each output is within 1e-9 of a naive refit of that same window.
    fn assert_streaming_matches_naive(prices: &[f64], period: usize) {
        let mut lr = LinearRegression::new(period).unwrap();
        for (t, &p) in prices.iter().enumerate() {
            let streaming = lr.update(p);
            // Still warming up: there is no full window to compare against yet.
            let Some(lo) = (t + 1).checked_sub(period) else {
                continue;
            };
            let expected = naive_endpoint(&prices[lo..=t]);
            let got = streaming.expect("warmed up");
            assert!(
                (got - expected).abs() < 1e-9,
                "endpoint diverges at t={t}, period={period}: got={got}, expected={expected}",
            );
        }
    }

    #[test]
    fn reference_values() {
        // Points (0,1), (1,2), (2,9): slope 4, intercept 0, endpoint 4 * 2 = 8.
        let mut lr = LinearRegression::new(3).unwrap();
        let out = lr.batch(&[1.0, 2.0, 9.0]);
        assert!(out[..2].iter().all(Option::is_none));
        assert_relative_eq!(out[2].unwrap(), 8.0, epsilon = 1e-9);
    }

    #[test]
    fn perfect_line_returns_current_value() {
        let line = |i: i32| 2.0 * f64::from(i) + 5.0;
        let prices: Vec<f64> = (0..40).map(line).collect();
        let mut lr = LinearRegression::new(10).unwrap();
        let ready = lr
            .batch(&prices)
            .into_iter()
            .enumerate()
            .filter_map(|(i, v)| v.map(|v| (i, v)));
        for (i, v) in ready {
            assert_relative_eq!(v, 2.0 * i as f64 + 5.0, epsilon = 1e-6);
        }
    }

    #[test]
    fn constant_series_returns_the_constant() {
        let mut lr = LinearRegression::new(8).unwrap();
        let series = [42.0; 20];
        lr.batch(&series).into_iter().flatten().for_each(|v| {
            assert_relative_eq!(v, 42.0, epsilon = 1e-9);
        });
    }

    #[test]
    fn first_value_on_period_th_input() {
        let mut lr = LinearRegression::new(5).unwrap();
        let out = lr.batch(&[1.0, 3.0, 2.0, 5.0, 4.0, 6.0]);
        for (i, v) in out.iter().take(4).enumerate() {
            assert!(v.is_none(), "index {i} must be None during warmup");
        }
        assert!(out[4].is_some(), "first value lands at index period - 1");
        assert_eq!(lr.warmup_period(), 5);
    }

    #[test]
    fn rejects_period_below_two() {
        for too_small in [0, 1] {
            assert!(LinearRegression::new(too_small).is_err());
        }
        assert!(LinearRegression::new(2).is_ok());
    }

    #[test]
    fn accessors_and_metadata() {
        let lr = LinearRegression::new(14).unwrap();
        assert_eq!(lr.period(), 14);
        assert_eq!(lr.name(), "LinearRegression");
    }

    #[test]
    fn reset_clears_state() {
        let mut lr = LinearRegression::new(5).unwrap();
        lr.batch(&[1.0, 2.0, 3.0, 4.0, 5.0]);
        assert!(lr.is_ready());

        lr.reset();
        assert!(!lr.is_ready());
        assert_eq!(lr.update(1.0), None);
    }

    #[test]
    fn batch_equals_streaming() {
        let prices: Vec<f64> = (0..60)
            .map(|i| 50.0 + (f64::from(i) * 0.3).sin() * 10.0)
            .collect();
        let mut batch_lr = LinearRegression::new(14).unwrap();
        let mut stream_lr = LinearRegression::new(14).unwrap();
        let batched = batch_lr.batch(&prices);
        let streamed: Vec<_> = prices.iter().map(|&x| stream_lr.update(x)).collect();
        assert_eq!(batched, streamed);
    }

    #[test]
    fn incremental_matches_naive_fit_bar_by_bar() {
        let noisy_ramp: Vec<f64> = (0..120)
            .map(|i| 100.0 + f64::from(i) * 0.5 + (f64::from(i) * 0.7).sin() * 3.0)
            .collect();
        for period in [5, 14, 30] {
            assert_streaming_matches_naive(&noisy_ramp, period);
        }

        // Three 30-bar plateaus: 1.0, then 100.0, then 0.001.
        let step: Vec<f64> = [1.0, 100.0, 0.001]
            .iter()
            .flat_map(|&level| std::iter::repeat_n(level, 30))
            .collect();
        for period in [5, 14] {
            assert_streaming_matches_naive(&step, period);
        }

        let constant = vec![42.0; 50];
        for period in [8, 25] {
            assert_streaming_matches_naive(&constant, period);
        }
    }
}