Skip to main content

wickra_core/indicators/
linreg_channel.rs

1//! Linear Regression Channel — OLS endpoint ± k · stddev of residuals.
2
3use std::collections::VecDeque;
4
5use crate::error::{Error, Result};
6use crate::traits::Indicator;
7
/// One bar of [`LinRegChannel`] output: the regression endpoint and the two
/// bands placed symmetrically around it.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct LinRegChannelOutput {
    /// Band above the line: `middle + multiplier · σ`, where `σ` is the
    /// population standard deviation of the residuals.
    pub upper: f64,
    /// The fitted least-squares line evaluated at the newest bar of the window.
    pub middle: f64,
    /// Band below the line: `middle − multiplier · σ`, where `σ` is the
    /// population standard deviation of the residuals.
    pub lower: f64,
}
20
/// Linear Regression Channel: rolling least-squares line with `±k·σ` bands
/// sized by the residuals about the fitted line.
///
/// ```text
/// x          = 0, 1, …, period − 1                 // oldest close at x = 0
/// fit y = a + b·x by OLS over the last `period` closes
/// residual_i = y_i − (a + b · x_i)
/// sigma      = sqrt( Σ residual_i² / period )      // population stddev
/// middle     = a + b · (period − 1)                // endpoint of the line
/// upper      = middle + multiplier · sigma
/// lower      = middle − multiplier · sigma
/// ```
///
/// Where [`BollingerBands`](crate::BollingerBands) measures dispersion about
/// the *mean*, the `LinReg` Channel measures it about the *trend*. The
/// residuals are taken after removing the fitted line, so a steady drift up
/// or down does not by itself widen the bands. Anything a straight line
/// cannot explain still lands in the residuals and widens the channel while
/// it remains inside the window. That includes noise, curvature, and a sudden
/// jump.
///
/// No output is produced until `period` finite closes have been accepted.
/// Non-finite inputs are skipped.
///
/// # Example
///
/// ```
/// use wickra_core::{Indicator, LinRegChannel};
///
/// let mut indicator = LinRegChannel::new(20, 2.0).unwrap();
/// let mut last = None;
/// for i in 0..40 {
///     last = indicator.update(100.0 + f64::from(i));
/// }
/// assert!(last.is_some());
/// ```
#[derive(Debug, Clone)]
pub struct LinRegChannel {
    period: usize,
    multiplier: f64,
    /// Most recent closes, oldest first; holds at most `period` values.
    window: VecDeque<f64>,
    /// Σx over x = 0..period, fixed at construction.
    sum_x: f64,
    /// Σx² over x = 0..period, fixed at construction.
    sum_xx: f64,
}
60
61impl LinRegChannel {
62    /// # Errors
63    /// Returns [`Error::InvalidPeriod`] if `period < 2` and
64    /// [`Error::NonPositiveMultiplier`] if `multiplier` is not strictly
65    /// positive and finite.
66    pub fn new(period: usize, multiplier: f64) -> Result<Self> {
67        if period < 2 {
68            return Err(Error::InvalidPeriod {
69                message: "linear regression channel needs period >= 2",
70            });
71        }
72        if !multiplier.is_finite() || multiplier <= 0.0 {
73            return Err(Error::NonPositiveMultiplier);
74        }
75        let n = period as f64;
76        Ok(Self {
77            period,
78            multiplier,
79            window: VecDeque::with_capacity(period),
80            sum_x: n * (n - 1.0) / 2.0,
81            sum_xx: (n - 1.0) * n * (2.0 * n - 1.0) / 6.0,
82        })
83    }
84
85    /// Configured period.
86    pub const fn period(&self) -> usize {
87        self.period
88    }
89
90    /// Configured multiplier.
91    pub const fn multiplier(&self) -> f64 {
92        self.multiplier
93    }
94}
95
96impl Indicator for LinRegChannel {
97    type Input = f64;
98    type Output = LinRegChannelOutput;
99
100    fn update(&mut self, value: f64) -> Option<LinRegChannelOutput> {
101        if !value.is_finite() {
102            return None;
103        }
104        if self.window.len() == self.period {
105            self.window.pop_front();
106        }
107        self.window.push_back(value);
108        if self.window.len() < self.period {
109            return None;
110        }
111        // Recompute over the live window every bar. The OLS endpoint *could*
112        // be maintained incrementally (see `LinearRegression`) but the
113        // residual-stddev cannot be slid in closed form without storing each
114        // residual; recomputing both keeps the code simple and is O(period)
115        // per update — entirely acceptable for the periods used in practice.
116        let n = self.period as f64;
117        let mut sum_y = 0.0;
118        let mut sum_xy = 0.0;
119        for (i, &y) in self.window.iter().enumerate() {
120            let x = i as f64;
121            sum_y += y;
122            sum_xy += x * y;
123        }
124        let denom = n * self.sum_xx - self.sum_x * self.sum_x;
125        let slope = (n * sum_xy - self.sum_x * sum_y) / denom;
126        let intercept = (sum_y - slope * self.sum_x) / n;
127
128        // Residuals about the fitted line.
129        let mut sum_sq = 0.0;
130        for (i, &y) in self.window.iter().enumerate() {
131            let fitted = intercept + slope * (i as f64);
132            let r = y - fitted;
133            sum_sq += r * r;
134        }
135        let sigma = (sum_sq / n).sqrt();
136        let middle = intercept + slope * (n - 1.0);
137        Some(LinRegChannelOutput {
138            upper: middle + self.multiplier * sigma,
139            middle,
140            lower: middle - self.multiplier * sigma,
141        })
142    }
143
144    fn reset(&mut self) {
145        self.window.clear();
146    }
147
148    fn warmup_period(&self) -> usize {
149        self.period
150    }
151
152    fn is_ready(&self) -> bool {
153        self.window.len() == self.period
154    }
155
156    fn name(&self) -> &'static str {
157        "LinRegChannel"
158    }
159}
160
#[cfg(test)]
mod tests {
    use super::*;
    use crate::traits::BatchExt;
    use approx::assert_relative_eq;

    /// `len` samples of a sine wave with amplitude 10 centred on `base`,
    /// stepping 0.3 radians per sample.
    fn sine_series(base: f64, len: i32) -> Vec<f64> {
        (0..len)
            .map(|k| base + 10.0 * (0.3 * f64::from(k)).sin())
            .collect()
    }

    #[test]
    fn rejects_period_below_two() {
        for too_short in [0, 1] {
            assert!(LinRegChannel::new(too_short, 2.0).is_err());
        }
        assert!(LinRegChannel::new(2, 2.0).is_ok());
    }

    #[test]
    fn rejects_non_positive_multiplier() {
        for bad in [0.0, -1.0, f64::NAN] {
            assert!(matches!(
                LinRegChannel::new(20, bad),
                Err(Error::NonPositiveMultiplier)
            ));
        }
    }

    #[test]
    fn accessors_and_metadata() {
        let channel = LinRegChannel::new(20, 2.0).unwrap();
        assert_eq!(channel.period(), 20);
        assert_relative_eq!(channel.multiplier(), 2.0, epsilon = 1e-12);
        assert_eq!(channel.warmup_period(), 20);
        assert_eq!(channel.name(), "LinRegChannel");
    }

    #[test]
    fn perfect_line_collapses_channel() {
        // Exactly linear input leaves no residuals, so all three lines coincide.
        let closes: Vec<f64> = (0..40).map(|k| 5.0 + 2.0 * f64::from(k)).collect();
        let mut channel = LinRegChannel::new(10, 2.0).unwrap();
        let outputs = channel.batch(&closes);
        for out in outputs.iter().flatten() {
            assert_relative_eq!(out.upper, out.middle, epsilon = 1e-9);
            assert_relative_eq!(out.middle, out.lower, epsilon = 1e-9);
        }
    }

    #[test]
    fn constant_series_collapses_channel() {
        let mut channel = LinRegChannel::new(8, 2.0).unwrap();
        let outputs = channel.batch(&[42.0; 20]);
        let newest = outputs
            .iter()
            .rev()
            .find_map(Option::as_ref)
            .expect("20 closes fill a period-8 window");
        for line in [newest.middle, newest.upper, newest.lower] {
            assert_relative_eq!(line, 42.0, epsilon = 1e-9);
        }
    }

    #[test]
    fn upper_above_middle_above_lower() {
        let closes = sine_series(100.0, 80);
        let mut channel = LinRegChannel::new(20, 2.0).unwrap();
        let outputs = channel.batch(&closes);
        assert!(outputs
            .iter()
            .flatten()
            .all(|o| o.upper >= o.middle && o.middle >= o.lower));
    }

    #[test]
    fn batch_equals_streaming() {
        let closes = sine_series(50.0, 60);
        let mut batched = LinRegChannel::new(14, 2.0).unwrap();
        let mut streamed = LinRegChannel::new(14, 2.0).unwrap();
        let expected: Vec<_> = closes.iter().map(|&c| streamed.update(c)).collect();
        assert_eq!(batched.batch(&closes), expected);
    }

    #[test]
    fn reset_clears_state() {
        let mut channel = LinRegChannel::new(5, 2.0).unwrap();
        let _ = channel.batch(&[1.0, 2.0, 3.0, 4.0, 5.0]);
        assert!(channel.is_ready());
        channel.reset();
        assert!(!channel.is_ready());
        assert_eq!(channel.update(1.0), None);
    }

    /// Hand-checked window: period 3 over `[1, 2, 9]` fits `y = 4·x`, so the
    /// endpoint at `x = 2` is `8`. The residuals are `1`, `−2`, `1`, giving a
    /// population variance of `6 / 3 = 2` and `σ = √2`. With multiplier 2 the
    /// bands are `8 ± 2·√2`.
    #[test]
    fn reference_values() {
        let mut channel = LinRegChannel::new(3, 2.0).unwrap();
        let outputs = channel.batch(&[1.0, 2.0, 9.0]);
        let v = outputs[2].expect("third close completes the window");
        let band = 2.0 * 2.0_f64.sqrt();
        assert_relative_eq!(v.middle, 8.0, epsilon = 1e-9);
        assert_relative_eq!(v.upper, 8.0 + band, epsilon = 1e-9);
        assert_relative_eq!(v.lower, 8.0 - band, epsilon = 1e-9);
    }
}
265        assert_relative_eq!(v.middle, 8.0, epsilon = 1e-9);
266        assert_relative_eq!(v.upper, 8.0 + 2.0 * s2, epsilon = 1e-9);
267        assert_relative_eq!(v.lower, 8.0 - 2.0 * s2, epsilon = 1e-9);
268    }
269}