Skip to main content

wickra_core/indicators/
linreg_slope.rs

1//! Linear Regression Slope.
2
3use std::collections::VecDeque;
4
5use crate::error::{Error, Result};
6use crate::traits::Indicator;
7
8/// Linear Regression Slope — the slope of a rolling least-squares fit.
9///
10/// Over the last `period` inputs, indexed `x = 0, 1, …, period − 1`, it fits
11/// the line `y = a + b·x` by ordinary least squares and reports the slope:
12///
13/// ```text
14/// b = (n·Σxy − Σx·Σy) / (n·Σxx − (Σx)²)
15/// ```
16///
17/// This is TA-Lib's `LINEARREG_SLOPE`: a momentum-like reading of how steeply
18/// price is trending over the window — positive while it rises, negative
19/// while it falls, near zero when it is flat — without the band-pass quirks
20/// of a difference-based oscillator.
21///
22/// Each `update` is O(1): the same incremental OLS state as
23/// [`LinearRegression`](crate::LinearRegression) is maintained — `Σx` and
24/// `Σxx` are precomputed once from `period`, while `Σy` and `Σxy` are slid
25/// forward in closed form on every push.
26///
27/// # Example
28///
29/// ```
30/// use wickra_core::{Indicator, LinRegSlope};
31///
32/// let mut indicator = LinRegSlope::new(14).unwrap();
33/// let mut last = None;
34/// for i in 0..80 {
35///     last = indicator.update(f64::from(i));
36/// }
37/// assert!(last.is_some());
38/// ```
39#[derive(Debug, Clone)]
40pub struct LinRegSlope {
41    period: usize,
42    window: VecDeque<f64>,
43    /// Closed form of `Σx` over `x = 0, 1, …, period − 1` — constant in `period`.
44    sum_x: f64,
45    /// Closed form of `n · Σxx − (Σx)²` — constant in `period`.
46    denom: f64,
47    /// Running sum of the values currently in the window.
48    sum_y: f64,
49    /// Running `Σ(x · y)` where `x` is the position within the trailing window.
50    sum_xy: f64,
51}
52
53impl LinRegSlope {
54    /// Construct a new rolling linear-regression slope over `period` inputs.
55    ///
56    /// # Errors
57    /// Returns [`Error::InvalidPeriod`] if `period < 2` — a regression line is
58    /// undefined for fewer than two points.
59    pub fn new(period: usize) -> Result<Self> {
60        if period < 2 {
61            return Err(Error::InvalidPeriod {
62                message: "linear regression slope needs period >= 2",
63            });
64        }
65        let n = period as f64;
66        // Closed forms for x = 0, 1, …, period − 1.
67        let sum_x = n * (n - 1.0) / 2.0;
68        let sum_xx = (n - 1.0) * n * (2.0 * n - 1.0) / 6.0;
69        Ok(Self {
70            period,
71            window: VecDeque::with_capacity(period),
72            sum_x,
73            denom: n * sum_xx - sum_x * sum_x,
74            sum_y: 0.0,
75            sum_xy: 0.0,
76        })
77    }
78
79    /// Configured period.
80    pub const fn period(&self) -> usize {
81        self.period
82    }
83}
84
85impl Indicator for LinRegSlope {
86    type Input = f64;
87    type Output = f64;
88
89    fn update(&mut self, value: f64) -> Option<f64> {
90        if !value.is_finite() {
91            return None;
92        }
93        if self.window.len() == self.period {
94            // Sliding-window identity: when the window slides one step forward
95            // the indices `x` for every kept entry shift down by 1, so
96            //   new_sum_xy = old_sum_xy − old_sum_y + y0
97            // (`y0` is the popped front value).
98            let y0 = self.window.pop_front().expect("non-empty");
99            self.sum_xy = self.sum_xy - self.sum_y + y0;
100            self.sum_y -= y0;
101        }
102        let k = self.window.len() as f64;
103        self.window.push_back(value);
104        self.sum_y += value;
105        self.sum_xy += k * value;
106
107        if self.window.len() < self.period {
108            return None;
109        }
110        let n = self.period as f64;
111        Some((n * self.sum_xy - self.sum_x * self.sum_y) / self.denom)
112    }
113
114    fn reset(&mut self) {
115        self.window.clear();
116        self.sum_y = 0.0;
117        self.sum_xy = 0.0;
118    }
119
120    fn warmup_period(&self) -> usize {
121        self.period
122    }
123
124    fn is_ready(&self) -> bool {
125        self.window.len() == self.period
126    }
127
128    fn name(&self) -> &'static str {
129        "LinRegSlope"
130    }
131}
132
#[cfg(test)]
mod tests {
    use super::*;
    use crate::traits::BatchExt;
    use approx::assert_relative_eq;

    #[test]
    fn reference_values() {
        // period 3 over [1, 2, 9]: fit y = 0 + 4x, so the slope is 4.
        let mut ls = LinRegSlope::new(3).unwrap();
        let out = ls.batch(&[1.0, 2.0, 9.0]);
        assert!(out[0].is_none());
        assert!(out[1].is_none());
        assert_relative_eq!(out[2].unwrap(), 4.0, epsilon = 1e-9);
    }

    #[test]
    fn perfect_line_returns_its_step() {
        // A series rising by a fixed step has exactly that slope.
        let prices: Vec<f64> = (0..40).map(|i| 2.5 * f64::from(i) + 7.0).collect();
        let mut ls = LinRegSlope::new(10).unwrap();
        for v in ls.batch(&prices).into_iter().flatten() {
            assert_relative_eq!(v, 2.5, epsilon = 1e-6);
        }
    }

    #[test]
    fn constant_series_has_zero_slope() {
        let mut ls = LinRegSlope::new(8).unwrap();
        for v in ls.batch(&[42.0; 20]).into_iter().flatten() {
            assert_relative_eq!(v, 0.0, epsilon = 1e-9);
        }
    }

    #[test]
    fn falling_series_has_negative_slope() {
        let prices: Vec<f64> = (0..30).map(|i| 100.0 - f64::from(i)).collect();
        let mut ls = LinRegSlope::new(10).unwrap();
        for v in ls.batch(&prices).into_iter().flatten() {
            assert!(v < 0.0, "a falling series must have a negative slope");
        }
    }

    #[test]
    fn first_value_on_period_th_input() {
        let mut ls = LinRegSlope::new(5).unwrap();
        let out = ls.batch(&[1.0, 3.0, 2.0, 5.0, 4.0, 6.0]);
        for (i, v) in out.iter().enumerate().take(4) {
            assert!(v.is_none(), "index {i} must be None during warmup");
        }
        assert!(out[4].is_some(), "first value lands at index period - 1");
        assert_eq!(ls.warmup_period(), 5);
    }

    #[test]
    fn rejects_period_below_two() {
        assert!(LinRegSlope::new(0).is_err());
        assert!(LinRegSlope::new(1).is_err());
        assert!(LinRegSlope::new(2).is_ok());
    }

    /// Cover the const accessor `period` and the `Indicator::name` body.
    /// `warmup_period` is exercised by `first_value_on_period_th_input`.
    #[test]
    fn accessors_and_metadata() {
        let ls = LinRegSlope::new(14).unwrap();
        assert_eq!(ls.period(), 14);
        assert_eq!(ls.name(), "LinRegSlope");
    }

    #[test]
    fn reset_clears_state() {
        let mut ls = LinRegSlope::new(5).unwrap();
        ls.batch(&[1.0, 2.0, 3.0, 4.0, 5.0]);
        assert!(ls.is_ready());
        ls.reset();
        assert!(!ls.is_ready());
        assert_eq!(ls.update(1.0), None);
    }

    /// Non-finite inputs are skipped: they yield `None` and leave the window
    /// exactly as it was, so the next finite value slides it normally.
    #[test]
    fn non_finite_input_is_ignored() {
        let mut ls = LinRegSlope::new(3).unwrap();
        ls.batch(&[1.0, 2.0, 9.0]);
        assert_eq!(ls.update(f64::NAN), None);
        assert_eq!(ls.update(f64::INFINITY), None);
        assert_eq!(ls.update(f64::NEG_INFINITY), None);
        assert!(ls.is_ready(), "skipped inputs must not disturb a full window");
        // Window is now [2, 9, 10]: centred fit gives 8 / 2 = 4.
        assert_relative_eq!(ls.update(10.0).unwrap(), 4.0, epsilon = 1e-9);
    }

    #[test]
    fn batch_equals_streaming() {
        let prices: Vec<f64> = (0..60)
            .map(|i| 50.0 + (f64::from(i) * 0.3).sin() * 10.0)
            .collect();
        let mut a = LinRegSlope::new(14).unwrap();
        let mut b = LinRegSlope::new(14).unwrap();
        assert_eq!(
            a.batch(&prices),
            prices.iter().map(|x| b.update(*x)).collect::<Vec<_>>()
        );
    }

    /// Incremental OLS equivalence for the slope: the O(1) implementation must
    /// agree bar-by-bar with a fresh-from-scratch O(n) refit, on a noisy ramp
    /// (sliding-phase dominated) and a step function (large pop/push deltas).
    #[test]
    fn incremental_matches_naive_slope_bar_by_bar() {
        fn naive_slope(window: &[f64]) -> f64 {
            let n = window.len() as f64;
            let mut sum_y = 0.0;
            let mut sum_xy = 0.0;
            let mut sum_x = 0.0;
            let mut sum_xx = 0.0;
            for (i, &y) in window.iter().enumerate() {
                let x = i as f64;
                sum_y += y;
                sum_xy += x * y;
                sum_x += x;
                sum_xx += x * x;
            }
            (n * sum_xy - sum_x * sum_y) / (n * sum_xx - sum_x * sum_x)
        }

        fn check(prices: &[f64], period: usize) {
            let mut ls = LinRegSlope::new(period).unwrap();
            for (t, p) in prices.iter().enumerate() {
                let streaming = ls.update(*p);
                if t + 1 >= period {
                    let lo = t + 1 - period;
                    let expected = naive_slope(&prices[lo..=t]);
                    let got = streaming.expect("warmed up");
                    assert!(
                        (got - expected).abs() < 1e-9,
                        "slope diverges at t={t}, period={period}: got={got}, expected={expected}",
                    );
                }
            }
        }

        let noisy_ramp: Vec<f64> = (0..120)
            .map(|i| 100.0 + f64::from(i) * 0.5 + (f64::from(i) * 0.7).sin() * 3.0)
            .collect();
        check(&noisy_ramp, 5);
        check(&noisy_ramp, 14);

        let mut step = vec![1.0; 30];
        step.extend(std::iter::repeat_n(100.0, 30));
        check(&step, 7);
    }
}