wickra_core/indicators/
linreg_slope.rs1use std::collections::VecDeque;
4
5use crate::error::{Error, Result};
6use crate::traits::Indicator;
7
8#[derive(Debug, Clone)]
40pub struct LinRegSlope {
41 period: usize,
42 window: VecDeque<f64>,
43 sum_x: f64,
45 denom: f64,
47 sum_y: f64,
49 sum_xy: f64,
51}
52
53impl LinRegSlope {
54 pub fn new(period: usize) -> Result<Self> {
60 if period < 2 {
61 return Err(Error::InvalidPeriod {
62 message: "linear regression slope needs period >= 2",
63 });
64 }
65 let n = period as f64;
66 let sum_x = n * (n - 1.0) / 2.0;
68 let sum_xx = (n - 1.0) * n * (2.0 * n - 1.0) / 6.0;
69 Ok(Self {
70 period,
71 window: VecDeque::with_capacity(period),
72 sum_x,
73 denom: n * sum_xx - sum_x * sum_x,
74 sum_y: 0.0,
75 sum_xy: 0.0,
76 })
77 }
78
79 pub const fn period(&self) -> usize {
81 self.period
82 }
83}
84
85impl Indicator for LinRegSlope {
86 type Input = f64;
87 type Output = f64;
88
89 fn update(&mut self, value: f64) -> Option<f64> {
90 if !value.is_finite() {
91 return None;
92 }
93 if self.window.len() == self.period {
94 let y0 = self.window.pop_front().expect("non-empty");
99 self.sum_xy = self.sum_xy - self.sum_y + y0;
100 self.sum_y -= y0;
101 }
102 let k = self.window.len() as f64;
103 self.window.push_back(value);
104 self.sum_y += value;
105 self.sum_xy += k * value;
106
107 if self.window.len() < self.period {
108 return None;
109 }
110 let n = self.period as f64;
111 Some((n * self.sum_xy - self.sum_x * self.sum_y) / self.denom)
112 }
113
114 fn reset(&mut self) {
115 self.window.clear();
116 self.sum_y = 0.0;
117 self.sum_xy = 0.0;
118 }
119
120 fn warmup_period(&self) -> usize {
121 self.period
122 }
123
124 fn is_ready(&self) -> bool {
125 self.window.len() == self.period
126 }
127
128 fn name(&self) -> &'static str {
129 "LinRegSlope"
130 }
131}
132
#[cfg(test)]
mod tests {
    use super::*;
    use crate::traits::BatchExt;
    use approx::assert_relative_eq;

    #[test]
    fn reference_values() {
        let mut ls = LinRegSlope::new(3).unwrap();
        // x = 0, 1, 2 and y = 1, 2, 9: x_bar = 1, y_bar = 4, so
        // slope = sum((x - x_bar) * (y - y_bar)) / sum((x - x_bar)^2) = (3 + 0 + 5) / 2 = 4.
        let out = ls.batch(&[1.0, 2.0, 9.0]);
        assert!(out[0].is_none());
        assert!(out[1].is_none());
        assert_relative_eq!(out[2].unwrap(), 4.0, epsilon = 1e-9);
    }

    #[test]
    fn perfect_line_returns_its_step() {
        // y = 2.5 * i + 7 is an exact line, so every window's slope is its step.
        let prices: Vec<f64> = (0..40).map(|i| 2.5 * f64::from(i) + 7.0).collect();
        let mut ls = LinRegSlope::new(10).unwrap();
        for v in ls.batch(&prices).into_iter().flatten() {
            assert_relative_eq!(v, 2.5, epsilon = 1e-6);
        }
    }

    #[test]
    fn constant_series_has_zero_slope() {
        let mut ls = LinRegSlope::new(8).unwrap();
        for v in ls.batch(&[42.0; 20]).into_iter().flatten() {
            assert_relative_eq!(v, 0.0, epsilon = 1e-9);
        }
    }

    #[test]
    fn falling_series_has_negative_slope() {
        let prices: Vec<f64> = (0..30).map(|i| 100.0 - f64::from(i)).collect();
        let mut ls = LinRegSlope::new(10).unwrap();
        for v in ls.batch(&prices).into_iter().flatten() {
            assert!(v < 0.0, "a falling series must have a negative slope");
        }
    }

    #[test]
    fn first_value_on_period_th_input() {
        let mut ls = LinRegSlope::new(5).unwrap();
        let out = ls.batch(&[1.0, 3.0, 2.0, 5.0, 4.0, 6.0]);
        for (i, v) in out.iter().enumerate().take(4) {
            assert!(v.is_none(), "index {i} must be None during warmup");
        }
        assert!(out[4].is_some(), "first value lands at index period - 1");
        assert_eq!(ls.warmup_period(), 5);
    }

    #[test]
    fn rejects_period_below_two() {
        assert!(LinRegSlope::new(0).is_err());
        assert!(LinRegSlope::new(1).is_err());
        assert!(LinRegSlope::new(2).is_ok());
    }

    #[test]
    fn accessors_and_metadata() {
        let ls = LinRegSlope::new(14).unwrap();
        assert_eq!(ls.period(), 14);
        assert_eq!(ls.name(), "LinRegSlope");
    }

    #[test]
    fn reset_clears_state() {
        let mut ls = LinRegSlope::new(5).unwrap();
        ls.batch(&[1.0, 2.0, 3.0, 4.0, 5.0]);
        assert!(ls.is_ready());
        ls.reset();
        assert!(!ls.is_ready());
        assert_eq!(ls.update(1.0), None);
    }

    #[test]
    fn reset_then_rewarm_matches_fresh_instance() {
        let prices: Vec<f64> = (0..50)
            .map(|i| 20.0 + (f64::from(i) * 0.45).cos() * 4.0)
            .collect();
        let mut reused = LinRegSlope::new(6).unwrap();
        reused.batch(&prices);
        reused.reset();
        let mut fresh = LinRegSlope::new(6).unwrap();
        // After a reset no trace of the earlier stream may remain.
        assert_eq!(reused.batch(&prices), fresh.batch(&prices));
    }

    #[test]
    fn skips_non_finite_inputs() {
        let mut ls = LinRegSlope::new(3).unwrap();
        assert_eq!(ls.update(1.0), None);
        assert_eq!(ls.update(f64::NAN), None);
        assert_eq!(ls.update(f64::INFINITY), None);
        assert_eq!(ls.update(f64::NEG_INFINITY), None);
        assert_eq!(ls.update(2.0), None);
        // The non-finite inputs never entered the window, so this is the
        // `reference_values` window [1, 2, 9].
        assert_relative_eq!(ls.update(9.0).unwrap(), 4.0, epsilon = 1e-9);
        // Once warmed up, a non-finite input still yields None and changes nothing.
        assert_eq!(ls.update(f64::NAN), None);
        assert!(ls.is_ready());
        // Window is now [2, 9, 3]: sum((x - 1) * y) = -2 + 0 + 3 = 1, over 2.
        assert_relative_eq!(ls.update(3.0).unwrap(), 0.5, epsilon = 1e-9);
    }

    #[test]
    fn batch_equals_streaming() {
        let prices: Vec<f64> = (0..60)
            .map(|i| 50.0 + (f64::from(i) * 0.3).sin() * 10.0)
            .collect();
        let mut a = LinRegSlope::new(14).unwrap();
        let mut b = LinRegSlope::new(14).unwrap();
        assert_eq!(
            a.batch(&prices),
            prices.iter().map(|x| b.update(*x)).collect::<Vec<_>>()
        );
    }

    #[test]
    fn incremental_matches_naive_slope_bar_by_bar() {
        // Reference: textbook OLS slope recomputed from scratch for one window.
        fn naive_slope(window: &[f64]) -> f64 {
            let n = window.len() as f64;
            let mut sum_y = 0.0;
            let mut sum_xy = 0.0;
            let mut sum_x = 0.0;
            let mut sum_xx = 0.0;
            for (i, &y) in window.iter().enumerate() {
                let x = i as f64;
                sum_y += y;
                sum_xy += x * y;
                sum_x += x;
                sum_xx += x * x;
            }
            (n * sum_xy - sum_x * sum_y) / (n * sum_xx - sum_x * sum_x)
        }

        // Streams `prices` and compares every warmed-up output to the naive slope.
        fn check(prices: &[f64], period: usize) {
            let mut ls = LinRegSlope::new(period).unwrap();
            for (t, p) in prices.iter().enumerate() {
                let streaming = ls.update(*p);
                if t + 1 >= period {
                    let lo = t + 1 - period;
                    let expected = naive_slope(&prices[lo..=t]);
                    let got = streaming.expect("warmed up");
                    assert!(
                        (got - expected).abs() < 1e-9,
                        "slope diverges at t={t}, period={period}: got={got}, expected={expected}",
                    );
                }
            }
        }

        let noisy_ramp: Vec<f64> = (0..120)
            .map(|i| 100.0 + f64::from(i) * 0.5 + (f64::from(i) * 0.7).sin() * 3.0)
            .collect();
        check(&noisy_ramp, 5);
        check(&noisy_ramp, 14);

        let mut step = vec![1.0; 30];
        step.extend(std::iter::repeat_n(100.0, 30));
        check(&step, 7);

        // A long stream passes through the window many times over.
        let long: Vec<f64> = (0..2_000)
            .map(|i| 1_000.0 + (f64::from(i) * 0.013).sin() * 250.0 + f64::from(i % 7))
            .collect();
        check(&long, 20);
    }
}