Skip to main content

wickra_core/indicators/
standard_error.rs

1//! Standard Error of the rolling least-squares regression.
2
3use std::collections::VecDeque;
4
5use crate::error::{Error, Result};
6use crate::traits::Indicator;
7
8/// Standard Error of the regression line fit over the last `period` inputs.
9///
10/// Over the trailing window indexed `x = 0, 1, …, period − 1` the OLS line
11/// `y = a + b·x` is fitted, then:
12///
13/// ```text
14/// slope     = (n·Σxy − Σx·Σy) / (n·Σxx − (Σx)²)
15/// SS_total  = Σy² − n·ȳ²                            // total sum of squares
16/// RSS       = SS_total − slope² · S_xx              // residual sum of squares
17/// StdErr    = √( RSS / (n − 2) )                    // n − 2 residual d.o.f.
18/// ```
19///
20/// where `S_xx = (n·Σxx − (Σx)²) / n` is the centred sum of squares of the
21/// design.
22///
23/// This is the textbook **standard error of estimate** of OLS: it measures
24/// the typical distance between the observed prices and the fitted line,
25/// using the residual degrees of freedom `n − 2`. It is the spread that
26/// drives [`crate::BollingerBands`]-style bands around a regression instead of
27/// around an SMA — when the price hugs its trend, `StdErr` is small.
28///
29/// Each `update` is O(1): the `Σx` and `Σxx` terms depend only on `period`
30/// and are precomputed once, while `Σy`, `Σxy`, and `Σy²` are maintained
31/// incrementally as the window slides. Tiny floating-point cancellation
32/// noise that could drive the residual sum of squares slightly negative is
33/// clamped to zero before the square root.
34///
35/// # Example
36///
37/// ```
38/// use wickra_core::{Indicator, StandardError};
39///
40/// let mut indicator = StandardError::new(14).unwrap();
41/// let mut last = None;
42/// for i in 0..40 {
43///     last = indicator.update(100.0 + f64::from(i) + (f64::from(i) * 0.5).sin());
44/// }
45/// assert!(last.is_some());
46/// ```
47#[derive(Debug, Clone)]
48pub struct StandardError {
49    period: usize,
50    window: VecDeque<f64>,
51    sum_x: f64,
52    /// `n·Σxx − (Σx)²` — OLS denominator, constant in `period`.
53    denom: f64,
54    sum_y: f64,
55    sum_xy: f64,
56    sum_y_sq: f64,
57}
58
59impl StandardError {
60    /// Construct a new rolling standard error of regression.
61    ///
62    /// # Errors
63    /// Returns [`Error::InvalidPeriod`] if `period < 3` — the residual
64    /// degrees of freedom `n − 2` would be non-positive.
65    pub fn new(period: usize) -> Result<Self> {
66        if period < 3 {
67            return Err(Error::InvalidPeriod {
68                message: "standard error needs period >= 3",
69            });
70        }
71        let n = period as f64;
72        let sum_x = n * (n - 1.0) / 2.0;
73        let sum_xx = (n - 1.0) * n * (2.0 * n - 1.0) / 6.0;
74        Ok(Self {
75            period,
76            window: VecDeque::with_capacity(period),
77            sum_x,
78            denom: n * sum_xx - sum_x * sum_x,
79            sum_y: 0.0,
80            sum_xy: 0.0,
81            sum_y_sq: 0.0,
82        })
83    }
84
85    /// Configured period.
86    pub const fn period(&self) -> usize {
87        self.period
88    }
89}
90
91impl Indicator for StandardError {
92    type Input = f64;
93    type Output = f64;
94
95    fn update(&mut self, value: f64) -> Option<f64> {
96        if self.window.len() == self.period {
97            // Slide: pop oldest, shift indices, then push the new value at index n − 1.
98            let y0 = self.window.pop_front().expect("non-empty");
99            self.sum_xy = self.sum_xy - self.sum_y + y0;
100            self.sum_y -= y0;
101            self.sum_y_sq -= y0 * y0;
102        }
103        let k = self.window.len() as f64;
104        self.window.push_back(value);
105        self.sum_y += value;
106        self.sum_xy += k * value;
107        self.sum_y_sq += value * value;
108
109        if self.window.len() < self.period {
110            return None;
111        }
112        let n = self.period as f64;
113        let slope = (n * self.sum_xy - self.sum_x * self.sum_y) / self.denom;
114        let mean_y = self.sum_y / n;
115        let ss_total = self.sum_y_sq - n * mean_y * mean_y;
116        // S_xx = denom / n
117        let s_xx = self.denom / n;
118        let rss = (ss_total - slope * slope * s_xx).max(0.0);
119        Some((rss / (n - 2.0)).sqrt())
120    }
121
122    fn reset(&mut self) {
123        self.window.clear();
124        self.sum_y = 0.0;
125        self.sum_xy = 0.0;
126        self.sum_y_sq = 0.0;
127    }
128
129    fn warmup_period(&self) -> usize {
130        self.period
131    }
132
133    fn is_ready(&self) -> bool {
134        self.window.len() == self.period
135    }
136
137    fn name(&self) -> &'static str {
138        "StandardError"
139    }
140}
141
#[cfg(test)]
mod tests {
    use super::*;
    use crate::traits::BatchExt;
    use approx::assert_relative_eq;

    /// Brute-force reference: refit OLS over `window` (x = 0, 1, …) from
    /// scratch and return the standard error of estimate with `n − 2` d.o.f.
    fn refit_std_err(window: &[f64]) -> f64 {
        let n = window.len() as f64;
        let mean_y = window.iter().sum::<f64>() / n;
        let (sum_x, sum_xx, sum_xy) = window.iter().enumerate().fold(
            (0.0, 0.0, 0.0),
            |(sx, sxx, sxy), (i, &y)| {
                let x = i as f64;
                (sx + x, sxx + x * x, sxy + x * y)
            },
        );
        let mean_x = sum_x / n;
        let s_xx = sum_xx - n * mean_x * mean_x;
        let slope = (sum_xy - n * mean_x * mean_y) / s_xx;
        let intercept = mean_y - slope * mean_x;
        let rss = window
            .iter()
            .enumerate()
            .map(|(i, &y)| y - (intercept + slope * i as f64))
            .map(|residual| residual * residual)
            .sum::<f64>();
        (rss / (n - 2.0)).sqrt()
    }

    #[test]
    fn rejects_period_below_three() {
        for too_short in [0, 2] {
            assert!(StandardError::new(too_short).is_err());
        }
        assert!(StandardError::new(3).is_ok());
    }

    #[test]
    fn accessors_and_metadata() {
        let indicator = StandardError::new(14).unwrap();
        assert_eq!(indicator.period(), 14);
        assert_eq!(indicator.warmup_period(), 14);
        assert_eq!(indicator.name(), "StandardError");
    }

    #[test]
    fn perfect_line_has_zero_error() {
        // Every point lies exactly on y = 2x + 5, so the fit leaves no residual.
        let line: Vec<f64> = (0..30).map(|i| 2.0 * f64::from(i) + 5.0).collect();
        let outputs = StandardError::new(10).unwrap().batch(&line);
        for v in outputs.into_iter().flatten() {
            assert_relative_eq!(v, 0.0, epsilon = 1e-9);
        }
    }

    #[test]
    fn constant_series_yields_zero() {
        let flat = [42.0; 20];
        let outputs = StandardError::new(5).unwrap().batch(&flat);
        for v in outputs.into_iter().flatten() {
            assert_relative_eq!(v, 0.0, epsilon = 1e-9);
        }
    }

    #[test]
    fn matches_naive_definition() {
        // The O(1) sliding update must agree with a from-scratch OLS refit on
        // every bar that produces a value.
        let period = 14;
        let prices: Vec<f64> = (0..60)
            .map(|i| 100.0 + f64::from(i) * 0.5 + (f64::from(i) * 0.7).sin() * 3.0)
            .collect();
        let outputs = StandardError::new(period).unwrap().batch(&prices);
        let produced = outputs
            .iter()
            .enumerate()
            .filter_map(|(end, out)| (*out).map(|v| (end, v)));
        for (end, v) in produced {
            let expected = refit_std_err(&prices[end + 1 - period..=end]);
            assert_relative_eq!(v, expected, epsilon = 1e-9);
        }
    }

    #[test]
    fn reset_clears_state() {
        let mut indicator = StandardError::new(5).unwrap();
        indicator.batch(&[1.0, 2.0, 3.0, 4.0, 5.0]);
        assert!(indicator.is_ready());

        indicator.reset();
        assert!(!indicator.is_ready());
        // One sample after a reset is nowhere near a full window.
        assert_eq!(indicator.update(1.0), None);
    }

    #[test]
    fn batch_equals_streaming() {
        let prices: Vec<f64> = (0..60)
            .map(|i| 100.0 + (f64::from(i) * 0.4).sin() * 10.0)
            .collect();
        let batched = StandardError::new(14).unwrap().batch(&prices);
        let mut streaming = StandardError::new(14).unwrap();
        let streamed: Vec<_> = prices
            .iter()
            .copied()
            .map(|p| streaming.update(p))
            .collect();
        assert_eq!(batched, streamed);
    }
}