Skip to main content

wickra_core/indicators/
linreg.rs

//! Linear Regression (rolling least-squares endpoint).
2
3use std::collections::VecDeque;
4
5use crate::error::{Error, Result};
6use crate::traits::Indicator;
7
8/// Linear Regression — the endpoint of a rolling least-squares fit.
9///
10/// Over the last `period` inputs, indexed `x = 0, 1, …, period − 1`, it fits
11/// the line `y = a + b·x` by ordinary least squares and reports the line's
12/// value at the most recent point:
13///
14/// ```text
15/// b (slope)     = (n·Σxy − Σx·Σy) / (n·Σxx − (Σx)²)
16/// a (intercept) = (Σy − b·Σx) / n
17/// LinearReg     = a + b·(period − 1)
18/// ```
19///
20/// This is TA-Lib's `LINEARREG`: a smoothed price that lags less than an SMA
21/// because it extrapolates the *local trend* forward to the current bar
22/// instead of averaging it away.
23///
24/// Each `update` is O(1): the `Σx` and `Σxx` terms depend only on `period` and
25/// are precomputed once, while `Σy` and `Σxy` are maintained incrementally as
26/// the window slides. The closed-form sliding-window identity for
27/// `x = 0, 1, …, period − 1` is
28///
29/// ```text
30/// new_sum_xy = old_sum_xy − old_sum_y + popped_y0    // index shift by −1
31/// new_sum_y  = old_sum_y  − popped_y0
32/// // then push the new value at index n−1:
33/// sum_xy += (n − 1) · new_value
34/// sum_y  += new_value
35/// ```
36///
37/// # Example
38///
39/// ```
40/// use wickra_core::{Indicator, LinearRegression};
41///
42/// let mut indicator = LinearRegression::new(14).unwrap();
43/// let mut last = None;
44/// for i in 0..80 {
45///     last = indicator.update(f64::from(i));
46/// }
47/// assert!(last.is_some());
48/// ```
49#[derive(Debug, Clone)]
50pub struct LinearRegression {
51    period: usize,
52    window: VecDeque<f64>,
53    /// Closed form of `Σx` over `x = 0, 1, …, period − 1` — constant in `period`.
54    sum_x: f64,
55    /// Closed form of `n · Σxx − (Σx)²` — constant in `period`, the OLS
56    /// denominator.
57    denom: f64,
58    /// Running sum of the values currently in the window.
59    sum_y: f64,
60    /// Running `Σ(x · y)` where `x` is the position of each value within the
61    /// trailing window (`0` for the oldest, `period − 1` for the newest).
62    sum_xy: f64,
63}
64
65impl LinearRegression {
66    /// Construct a new rolling linear regression over `period` inputs.
67    ///
68    /// # Errors
69    /// Returns [`Error::InvalidPeriod`] if `period < 2` — a regression line is
70    /// undefined for fewer than two points.
71    pub fn new(period: usize) -> Result<Self> {
72        if period < 2 {
73            return Err(Error::InvalidPeriod {
74                message: "linear regression needs period >= 2",
75            });
76        }
77        let n = period as f64;
78        // Closed forms for x = 0, 1, …, period − 1.
79        let sum_x = n * (n - 1.0) / 2.0;
80        let sum_xx = (n - 1.0) * n * (2.0 * n - 1.0) / 6.0;
81        Ok(Self {
82            period,
83            window: VecDeque::with_capacity(period),
84            sum_x,
85            denom: n * sum_xx - sum_x * sum_x,
86            sum_y: 0.0,
87            sum_xy: 0.0,
88        })
89    }
90
91    /// Configured period.
92    pub const fn period(&self) -> usize {
93        self.period
94    }
95}
96
97impl Indicator for LinearRegression {
98    type Input = f64;
99    type Output = f64;
100
101    fn update(&mut self, value: f64) -> Option<f64> {
102        if !value.is_finite() {
103            return None;
104        }
105        if self.window.len() == self.period {
106            // Sliding phase: pop the oldest, then shift every remaining index
107            // down by 1 in the running `sum_xy`. The identity
108            //   Σ((i − 1) · y_i for i = 1..n−1) = Σ(i · y_i) − Σ(y_i) + y_0
109            // gives the closed-form update below.
110            let y0 = self.window.pop_front().expect("non-empty");
111            self.sum_xy = self.sum_xy - self.sum_y + y0;
112            self.sum_y -= y0;
113        }
114        // Append at position `k = current length` before the push. During
115        // warmup `k` ranges over `0..period − 1`; once the window is full it
116        // is always `period − 1`.
117        let k = self.window.len() as f64;
118        self.window.push_back(value);
119        self.sum_y += value;
120        self.sum_xy += k * value;
121
122        if self.window.len() < self.period {
123            return None;
124        }
125        let n = self.period as f64;
126        let slope = (n * self.sum_xy - self.sum_x * self.sum_y) / self.denom;
127        let intercept = (self.sum_y - slope * self.sum_x) / n;
128        Some(intercept + slope * (n - 1.0))
129    }
130
131    fn reset(&mut self) {
132        self.window.clear();
133        self.sum_y = 0.0;
134        self.sum_xy = 0.0;
135    }
136
137    fn warmup_period(&self) -> usize {
138        self.period
139    }
140
141    fn is_ready(&self) -> bool {
142        self.window.len() == self.period
143    }
144
145    fn name(&self) -> &'static str {
146        "LinearRegression"
147    }
148}
149
#[cfg(test)]
mod tests {
    use super::*;
    use crate::traits::BatchExt;
    use approx::assert_relative_eq;

    #[test]
    fn reference_values() {
        // period 3 over [1, 2, 9]: fit y = 0 + 4x, endpoint = 0 + 4·2 = 8.
        let mut lr = LinearRegression::new(3).unwrap();
        let out = lr.batch(&[1.0, 2.0, 9.0]);
        assert!(out[0].is_none());
        assert!(out[1].is_none());
        assert_relative_eq!(out[2].unwrap(), 8.0, epsilon = 1e-9);
    }

    #[test]
    fn perfect_line_returns_current_value() {
        // The regression of a perfectly linear series is that line itself, so
        // its endpoint equals the current value.
        let prices: Vec<f64> = (0..40).map(|i| 2.0 * f64::from(i) + 5.0).collect();
        let mut lr = LinearRegression::new(10).unwrap();
        let out = lr.batch(&prices);
        // Guard against a vacuous pass: 40 inputs with period 10 must yield
        // 40 − 10 + 1 values.
        assert_eq!(out.iter().flatten().count(), 31);
        for (i, v) in out.into_iter().enumerate() {
            if let Some(v) = v {
                assert_relative_eq!(v, 2.0 * i as f64 + 5.0, epsilon = 1e-6);
            }
        }
    }

    #[test]
    fn constant_series_returns_the_constant() {
        let mut lr = LinearRegression::new(8).unwrap();
        let out = lr.batch(&[42.0; 20]);
        // Guard against a vacuous pass: 20 inputs with period 8 must yield
        // 20 − 8 + 1 values.
        assert_eq!(out.iter().flatten().count(), 13);
        for v in out.into_iter().flatten() {
            assert_relative_eq!(v, 42.0, epsilon = 1e-9);
        }
    }

    #[test]
    fn first_value_on_period_th_input() {
        let mut lr = LinearRegression::new(5).unwrap();
        let out = lr.batch(&[1.0, 3.0, 2.0, 5.0, 4.0, 6.0]);
        for (i, v) in out.iter().enumerate().take(4) {
            assert!(v.is_none(), "index {i} must be None during warmup");
        }
        assert!(out[4].is_some(), "first value lands at index period - 1");
        assert_eq!(lr.warmup_period(), 5);
    }

    #[test]
    fn rejects_period_below_two() {
        assert!(LinearRegression::new(0).is_err());
        assert!(LinearRegression::new(1).is_err());
        assert!(LinearRegression::new(2).is_ok());
    }

    /// Cover the inherent const accessor `period` and the `Indicator::name`
    /// body. `warmup_period` is exercised by `first_value_on_period_th_input`.
    #[test]
    fn accessors_and_metadata() {
        let lr = LinearRegression::new(14).unwrap();
        assert_eq!(lr.period(), 14);
        assert_eq!(lr.name(), "LinearRegression");
    }

    #[test]
    fn reset_clears_state() {
        let mut lr = LinearRegression::new(5).unwrap();
        lr.batch(&[1.0, 2.0, 3.0, 4.0, 5.0]);
        assert!(lr.is_ready());
        lr.reset();
        assert!(!lr.is_ready());
        assert_eq!(lr.update(1.0), None);
    }

    /// Non-finite inputs are dropped: they produce no output and must not
    /// disturb the window, so the next finite value continues the series as
    /// if they had never arrived.
    #[test]
    fn non_finite_input_is_ignored() {
        let mut lr = LinearRegression::new(3).unwrap();
        lr.batch(&[1.0, 2.0, 9.0]);
        assert_eq!(lr.update(f64::NAN), None);
        assert_eq!(lr.update(f64::INFINITY), None);
        assert!(lr.is_ready());
        // Window is now [2, 9, 4]: slope 1, intercept 4, endpoint 4 + 1·2 = 6.
        assert_relative_eq!(lr.update(4.0).unwrap(), 6.0, epsilon = 1e-9);
    }

    #[test]
    fn batch_equals_streaming() {
        let prices: Vec<f64> = (0..60)
            .map(|i| 50.0 + (f64::from(i) * 0.3).sin() * 10.0)
            .collect();
        let mut a = LinearRegression::new(14).unwrap();
        let mut b = LinearRegression::new(14).unwrap();
        assert_eq!(
            a.batch(&prices),
            prices.iter().map(|x| b.update(*x)).collect::<Vec<_>>()
        );
    }

    /// Incremental OLS equivalence: the O(1) implementation must agree to
    /// `1e-9` with a fresh-from-scratch O(n) refit on every bar, on inputs
    /// chosen to stress every code path: a noisy ramp (sliding phase
    /// dominates), a step function (the new value differs sharply from the
    /// popped one), and constants (the floating-point accumulators must not
    /// drift).
    #[test]
    fn incremental_matches_naive_fit_bar_by_bar() {
        fn naive_endpoint(window: &[f64]) -> f64 {
            let n = window.len() as f64;
            let mut sum_y = 0.0;
            let mut sum_xy = 0.0;
            let mut sum_x = 0.0;
            let mut sum_xx = 0.0;
            for (i, &y) in window.iter().enumerate() {
                let x = i as f64;
                sum_y += y;
                sum_xy += x * y;
                sum_x += x;
                sum_xx += x * x;
            }
            let denom = n * sum_xx - sum_x * sum_x;
            let slope = (n * sum_xy - sum_x * sum_y) / denom;
            let intercept = (sum_y - slope * sum_x) / n;
            intercept + slope * (n - 1.0)
        }

        fn check(prices: &[f64], period: usize) {
            let mut lr = LinearRegression::new(period).unwrap();
            for (t, p) in prices.iter().enumerate() {
                let streaming = lr.update(*p);
                if t + 1 >= period {
                    let lo = t + 1 - period;
                    let expected = naive_endpoint(&prices[lo..=t]);
                    let got = streaming.expect("warmed up");
                    assert!(
                        (got - expected).abs() < 1e-9,
                        "endpoint diverges at t={t}, period={period}: got={got}, expected={expected}",
                    );
                }
            }
        }

        let noisy_ramp: Vec<f64> = (0..120)
            .map(|i| 100.0 + f64::from(i) * 0.5 + (f64::from(i) * 0.7).sin() * 3.0)
            .collect();
        check(&noisy_ramp, 5);
        check(&noisy_ramp, 14);
        check(&noisy_ramp, 30);

        let mut step = vec![1.0; 30];
        step.extend(std::iter::repeat_n(100.0, 30));
        step.extend(std::iter::repeat_n(0.001, 30));
        check(&step, 5);
        check(&step, 14);

        let constant = vec![42.0; 50];
        check(&constant, 8);
        check(&constant, 25);
    }
}