wickra_core/indicators/
linreg.rs1use std::collections::VecDeque;
4
5use crate::error::{Error, Result};
6use crate::traits::Indicator;
7
8#[derive(Debug, Clone)]
50pub struct LinearRegression {
51 period: usize,
52 window: VecDeque<f64>,
53 sum_x: f64,
55 denom: f64,
58 sum_y: f64,
60 sum_xy: f64,
63}
64
65impl LinearRegression {
66 pub fn new(period: usize) -> Result<Self> {
72 if period < 2 {
73 return Err(Error::InvalidPeriod {
74 message: "linear regression needs period >= 2",
75 });
76 }
77 let n = period as f64;
78 let sum_x = n * (n - 1.0) / 2.0;
80 let sum_xx = (n - 1.0) * n * (2.0 * n - 1.0) / 6.0;
81 Ok(Self {
82 period,
83 window: VecDeque::with_capacity(period),
84 sum_x,
85 denom: n * sum_xx - sum_x * sum_x,
86 sum_y: 0.0,
87 sum_xy: 0.0,
88 })
89 }
90
91 pub const fn period(&self) -> usize {
93 self.period
94 }
95}
96
97impl Indicator for LinearRegression {
98 type Input = f64;
99 type Output = f64;
100
101 fn update(&mut self, value: f64) -> Option<f64> {
102 if !value.is_finite() {
103 return None;
104 }
105 if self.window.len() == self.period {
106 let y0 = self.window.pop_front().expect("non-empty");
111 self.sum_xy = self.sum_xy - self.sum_y + y0;
112 self.sum_y -= y0;
113 }
114 let k = self.window.len() as f64;
118 self.window.push_back(value);
119 self.sum_y += value;
120 self.sum_xy += k * value;
121
122 if self.window.len() < self.period {
123 return None;
124 }
125 let n = self.period as f64;
126 let slope = (n * self.sum_xy - self.sum_x * self.sum_y) / self.denom;
127 let intercept = (self.sum_y - slope * self.sum_x) / n;
128 Some(intercept + slope * (n - 1.0))
129 }
130
131 fn reset(&mut self) {
132 self.window.clear();
133 self.sum_y = 0.0;
134 self.sum_xy = 0.0;
135 }
136
137 fn warmup_period(&self) -> usize {
138 self.period
139 }
140
141 fn is_ready(&self) -> bool {
142 self.window.len() == self.period
143 }
144
145 fn name(&self) -> &'static str {
146 "LinearRegression"
147 }
148}
149
#[cfg(test)]
mod tests {
    use super::*;
    use crate::traits::BatchExt;
    use approx::assert_relative_eq;

    #[test]
    fn reference_values() {
        // Points (0, 1), (1, 2), (2, 9): slope 4, intercept 0, so the fitted
        // value at the newest bar (x = 2) is 8.
        let mut lr = LinearRegression::new(3).unwrap();
        let out = lr.batch(&[1.0, 2.0, 9.0]);
        assert!(out[..2].iter().all(Option::is_none));
        assert_relative_eq!(out[2].unwrap(), 8.0, epsilon = 1e-9);
    }

    #[test]
    fn perfect_line_returns_current_value() {
        // On an exact straight line the fit is exact, so the endpoint equals
        // the newest input.
        let line = |t: usize| 2.0 * t as f64 + 5.0;
        let prices: Vec<f64> = (0..40).map(line).collect();
        let mut lr = LinearRegression::new(10).unwrap();
        for (t, fitted) in lr.batch(&prices).into_iter().enumerate() {
            let Some(fitted) = fitted else { continue };
            assert_relative_eq!(fitted, line(t), epsilon = 1e-6);
        }
    }

    #[test]
    fn constant_series_returns_the_constant() {
        let mut lr = LinearRegression::new(8).unwrap();
        let flat = [42.0; 20];
        let fitted: Vec<f64> = lr.batch(&flat).into_iter().flatten().collect();
        for v in fitted {
            assert_relative_eq!(v, 42.0, epsilon = 1e-9);
        }
    }

    #[test]
    fn first_value_on_period_th_input() {
        let mut lr = LinearRegression::new(5).unwrap();
        let out = lr.batch(&[1.0, 3.0, 2.0, 5.0, 4.0, 6.0]);
        for (i, slot) in out[..4].iter().enumerate() {
            assert!(slot.is_none(), "index {i} must be None during warmup");
        }
        assert!(out[4].is_some(), "first value lands at index period - 1");
        assert_eq!(lr.warmup_period(), 5);
    }

    #[test]
    fn rejects_period_below_two() {
        for too_small in [0, 1] {
            assert!(LinearRegression::new(too_small).is_err());
        }
        assert!(LinearRegression::new(2).is_ok());
    }

    #[test]
    fn accessors_and_metadata() {
        let lr = LinearRegression::new(14).unwrap();
        assert_eq!((lr.period(), lr.name()), (14, "LinearRegression"));
    }

    #[test]
    fn reset_clears_state() {
        let mut lr = LinearRegression::new(5).unwrap();
        let _ = lr.batch(&[1.0, 2.0, 3.0, 4.0, 5.0]);
        assert!(lr.is_ready());
        lr.reset();
        assert!(!lr.is_ready());
        assert_eq!(lr.update(1.0), None);
    }

    #[test]
    fn batch_equals_streaming() {
        let prices: Vec<f64> = (0..60)
            .map(|i| 50.0 + (f64::from(i) * 0.3).sin() * 10.0)
            .collect();
        let mut batched = LinearRegression::new(14).unwrap();
        let mut streamed = LinearRegression::new(14).unwrap();
        let one_by_one: Vec<_> = prices.iter().map(|&p| streamed.update(p)).collect();
        assert_eq!(batched.batch(&prices), one_by_one);
    }

    #[test]
    fn incremental_matches_naive_fit_bar_by_bar() {
        /// From-scratch OLS endpoint over `window`, used as an independent oracle.
        fn naive_endpoint(window: &[f64]) -> f64 {
            let n = window.len() as f64;
            let (sum_y, sum_xy, sum_x, sum_xx) = window.iter().enumerate().fold(
                (0.0, 0.0, 0.0, 0.0),
                |(sy, sxy, sx, sxx), (i, &y)| {
                    let x = i as f64;
                    (sy + y, sxy + x * y, sx + x, sxx + x * x)
                },
            );
            let denom = n * sum_xx - sum_x * sum_x;
            let slope = (n * sum_xy - sum_x * sum_y) / denom;
            let intercept = (sum_y - slope * sum_x) / n;
            intercept + slope * (n - 1.0)
        }

        /// Streams `prices` and compares every warmed-up output with the oracle.
        fn check(prices: &[f64], period: usize) {
            let mut lr = LinearRegression::new(period).unwrap();
            for (t, &p) in prices.iter().enumerate() {
                let streaming = lr.update(p);
                let Some(lo) = (t + 1).checked_sub(period) else {
                    continue;
                };
                let expected = naive_endpoint(&prices[lo..=t]);
                let got = streaming.expect("warmed up");
                assert!(
                    (got - expected).abs() < 1e-9,
                    "endpoint diverges at t={t}, period={period}: got={got}, expected={expected}",
                );
            }
        }

        let noisy_ramp: Vec<f64> = (0..120)
            .map(|i| 100.0 + f64::from(i) * 0.5 + (f64::from(i) * 0.7).sin() * 3.0)
            .collect();
        for period in [5, 14, 30] {
            check(&noisy_ramp, period);
        }

        let step: Vec<f64> = std::iter::repeat_n(1.0, 30)
            .chain(std::iter::repeat_n(100.0, 30))
            .chain(std::iter::repeat_n(0.001, 30))
            .collect();
        for period in [5, 14] {
            check(&step, period);
        }

        let constant = [42.0; 50];
        for period in [8, 25] {
            check(&constant, period);
        }
    }
}