Skip to main content

wickra_core/indicators/
linreg_channel.rs

1//! Linear Regression Channel — OLS endpoint ± k · stddev of residuals.
2
3use std::collections::VecDeque;
4
5use crate::error::{Error, Result};
6use crate::traits::Indicator;
7
/// One reading of the [`LinRegChannel`]: the regression endpoint and the two
/// bands placed symmetrically around it.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct LinRegChannelOutput {
    /// Upper band: `middle + multiplier · stddev` of the residuals.
    pub upper: f64,
    /// Centre line: the OLS fit evaluated at the newest value in the window.
    pub middle: f64,
    /// Lower band: `middle − multiplier · stddev` of the residuals.
    pub lower: f64,
}
20
/// Linear Regression Channel: a least-squares line refitted over a rolling
/// window, wrapped in `±k·σ` bands whose width comes from the residuals about
/// that line.
///
/// ```text
/// over the last `period` values, with x = 0, 1, …, period − 1:
///   (a, b)     = OLS fit of y = a + b·x
///   residual_i = y_i − (a + b · x_i)
///   sigma      = sqrt( Σ residual_i² / period )      // population stddev
///   middle     = a + b · (period − 1)                // line at the newest bar
///   upper      = middle + multiplier · sigma
///   lower      = middle − multiplier · sigma
/// ```
///
/// [`BollingerBands`](crate::BollingerBands) measures dispersion about the
/// *mean*; this channel measures it about the *trend*. Because the residuals
/// are detrended, a steady drift up or down does not widen the bands, so the
/// envelope follows the trend without flaring on momentum bursts and a
/// breakout is judged relative to the trend rather than to absolute price.
///
/// # Example
///
/// ```
/// use wickra_core::{Indicator, LinRegChannel};
///
/// let mut channel = LinRegChannel::new(20, 2.0).unwrap();
/// let last = (0..40)
///     .map(|i| channel.update(100.0 + f64::from(i)))
///     .last()
///     .flatten();
/// assert!(last.is_some());
/// ```
#[derive(Clone, Debug)]
pub struct LinRegChannel {
    /// Number of values in the regression window.
    period: usize,
    /// Band half-width expressed in residual standard deviations.
    multiplier: f64,
    /// The most recent values, oldest at the front; holds at most `period`.
    window: VecDeque<f64>,
    /// `Σ x` for `x = 0..period`, fixed at construction.
    sum_x: f64,
    /// `Σ x²` for `x = 0..period`, fixed at construction.
    sum_xx: f64,
}
60
61impl LinRegChannel {
62    /// # Errors
63    /// Returns [`Error::InvalidPeriod`] if `period < 2` and
64    /// [`Error::NonPositiveMultiplier`] if `multiplier` is not strictly
65    /// positive and finite.
66    pub fn new(period: usize, multiplier: f64) -> Result<Self> {
67        if period < 2 {
68            return Err(Error::InvalidPeriod {
69                message: "linear regression channel needs period >= 2",
70            });
71        }
72        if !multiplier.is_finite() || multiplier <= 0.0 {
73            return Err(Error::NonPositiveMultiplier);
74        }
75        let n = period as f64;
76        Ok(Self {
77            period,
78            multiplier,
79            window: VecDeque::with_capacity(period),
80            sum_x: n * (n - 1.0) / 2.0,
81            sum_xx: (n - 1.0) * n * (2.0 * n - 1.0) / 6.0,
82        })
83    }
84
85    /// Configured period.
86    pub const fn period(&self) -> usize {
87        self.period
88    }
89
90    /// Configured multiplier.
91    pub const fn multiplier(&self) -> f64 {
92        self.multiplier
93    }
94}
95
96impl Indicator for LinRegChannel {
97    type Input = f64;
98    type Output = LinRegChannelOutput;
99
100    fn update(&mut self, value: f64) -> Option<LinRegChannelOutput> {
101        if self.window.len() == self.period {
102            self.window.pop_front();
103        }
104        self.window.push_back(value);
105        if self.window.len() < self.period {
106            return None;
107        }
108        // Recompute over the live window every bar. The OLS endpoint *could*
109        // be maintained incrementally (see `LinearRegression`) but the
110        // residual-stddev cannot be slid in closed form without storing each
111        // residual; recomputing both keeps the code simple and is O(period)
112        // per update — entirely acceptable for the periods used in practice.
113        let n = self.period as f64;
114        let mut sum_y = 0.0;
115        let mut sum_xy = 0.0;
116        for (i, &y) in self.window.iter().enumerate() {
117            let x = i as f64;
118            sum_y += y;
119            sum_xy += x * y;
120        }
121        let denom = n * self.sum_xx - self.sum_x * self.sum_x;
122        let slope = (n * sum_xy - self.sum_x * sum_y) / denom;
123        let intercept = (sum_y - slope * self.sum_x) / n;
124
125        // Residuals about the fitted line.
126        let mut sum_sq = 0.0;
127        for (i, &y) in self.window.iter().enumerate() {
128            let fitted = intercept + slope * (i as f64);
129            let r = y - fitted;
130            sum_sq += r * r;
131        }
132        let sigma = (sum_sq / n).sqrt();
133        let middle = intercept + slope * (n - 1.0);
134        Some(LinRegChannelOutput {
135            upper: middle + self.multiplier * sigma,
136            middle,
137            lower: middle - self.multiplier * sigma,
138        })
139    }
140
141    fn reset(&mut self) {
142        self.window.clear();
143    }
144
145    fn warmup_period(&self) -> usize {
146        self.period
147    }
148
149    fn is_ready(&self) -> bool {
150        self.window.len() == self.period
151    }
152
153    fn name(&self) -> &'static str {
154        "LinRegChannel"
155    }
156}
157
#[cfg(test)]
mod tests {
    use super::*;
    use crate::traits::BatchExt;
    use approx::assert_relative_eq;

    /// A line needs at least two points, so periods 0 and 1 are rejected.
    #[test]
    fn rejects_period_below_two() {
        assert!(LinRegChannel::new(0, 2.0).is_err());
        assert!(LinRegChannel::new(1, 2.0).is_err());
        assert!(LinRegChannel::new(2, 2.0).is_ok());
    }

    /// The multiplier must be strictly positive *and* finite: zero, negative
    /// values, NaN and both infinities all map to `NonPositiveMultiplier`.
    #[test]
    fn rejects_non_positive_multiplier() {
        for &bad in &[0.0, -1.0, f64::NAN, f64::INFINITY, f64::NEG_INFINITY] {
            assert!(
                matches!(
                    LinRegChannel::new(20, bad),
                    Err(Error::NonPositiveMultiplier)
                ),
                "multiplier {} should be rejected",
                bad
            );
        }
    }

    #[test]
    fn accessors_and_metadata() {
        let lc = LinRegChannel::new(20, 2.0).unwrap();
        assert_eq!(lc.period(), 20);
        assert_relative_eq!(lc.multiplier(), 2.0, epsilon = 1e-12);
        assert_eq!(lc.warmup_period(), 20);
        assert_eq!(lc.name(), "LinRegChannel");
    }

    #[test]
    fn perfect_line_collapses_channel() {
        // A perfectly linear series has zero residuals, so upper == middle == lower.
        let prices: Vec<f64> = (0..40).map(|i| 2.0 * f64::from(i) + 5.0).collect();
        let mut lc = LinRegChannel::new(10, 2.0).unwrap();
        for o in lc.batch(&prices).into_iter().flatten() {
            assert_relative_eq!(o.upper, o.middle, epsilon = 1e-9);
            assert_relative_eq!(o.middle, o.lower, epsilon = 1e-9);
        }
    }

    #[test]
    fn constant_series_collapses_channel() {
        // A flat series is a zero-slope line: all three outputs equal the constant.
        let mut lc = LinRegChannel::new(8, 2.0).unwrap();
        let out = lc.batch(&[42.0; 20]);
        let v = out.iter().rev().flatten().next().unwrap();
        assert_relative_eq!(v.middle, 42.0, epsilon = 1e-9);
        assert_relative_eq!(v.upper, 42.0, epsilon = 1e-9);
        assert_relative_eq!(v.lower, 42.0, epsilon = 1e-9);
    }

    #[test]
    fn upper_above_middle_above_lower() {
        let prices: Vec<f64> = (0..80)
            .map(|i| 100.0 + (f64::from(i) * 0.3).sin() * 10.0)
            .collect();
        let mut lc = LinRegChannel::new(20, 2.0).unwrap();
        for o in lc.batch(&prices).into_iter().flatten() {
            assert!(o.upper >= o.middle);
            assert!(o.middle >= o.lower);
        }
    }

    #[test]
    fn batch_equals_streaming() {
        let prices: Vec<f64> = (0..60)
            .map(|i| 50.0 + (f64::from(i) * 0.3).sin() * 10.0)
            .collect();
        let mut a = LinRegChannel::new(14, 2.0).unwrap();
        let mut b = LinRegChannel::new(14, 2.0).unwrap();
        assert_eq!(
            a.batch(&prices),
            prices.iter().map(|p| b.update(*p)).collect::<Vec<_>>()
        );
    }

    #[test]
    fn reset_clears_state() {
        let data = [1.0, 2.0, 3.0, 4.0, 5.0];
        let mut lc = LinRegChannel::new(5, 2.0).unwrap();
        let first_run = lc.batch(&data);
        assert!(lc.is_ready());
        lc.reset();
        assert!(!lc.is_ready());
        assert_eq!(lc.update(1.0), None);

        // A reset indicator must behave like a fresh one: replaying the same
        // data reproduces the original output exactly.
        lc.reset();
        assert_eq!(lc.batch(&data), first_run);
    }

    /// Reference: period 3 over `[1, 2, 9]`. Fitted line `y = 0 + 4·x`,
    /// endpoint at `x = 2` is `8`. Residuals: `1 − 0 = 1`, `2 − 4 = −2`,
    /// `9 − 8 = 1`. Population variance = (1 + 4 + 1) / 3 = 2, sigma = sqrt(2).
    /// With multiplier 2.0, upper = 8 + 2·sqrt(2), lower = 8 − 2·sqrt(2).
    #[test]
    fn reference_values() {
        let mut lc = LinRegChannel::new(3, 2.0).unwrap();
        let out = lc.batch(&[1.0, 2.0, 9.0]);
        let v = out[2].unwrap();
        let s2 = f64::sqrt(2.0);
        assert_relative_eq!(v.middle, 8.0, epsilon = 1e-9);
        assert_relative_eq!(v.upper, 8.0 + 2.0 * s2, epsilon = 1e-9);
        assert_relative_eq!(v.lower, 8.0 - 2.0 * s2, epsilon = 1e-9);
    }
}