wickra_core/indicators/
standard_error.rs1use std::collections::VecDeque;
4
5use crate::error::{Error, Result};
6use crate::traits::Indicator;
7
/// Rolling standard error of a least-squares line fitted over a fixed-size
/// window of the most recent values.
///
/// Each value in the window is paired with its position `0, 1, ..., period - 1`
/// (oldest value at position 0). Because those positions never change, the
/// position-only sums are computed once at construction, while the
/// value-dependent sums are maintained incrementally as values enter and leave
/// the window.
#[derive(Debug, Clone)]
pub struct StandardError {
    /// Number of values the regression window holds.
    period: usize,
    /// The values currently in the window, oldest at the front.
    window: VecDeque<f64>,
    /// Sum of the positions `0..period`; fixed at construction.
    sum_x: f64,
    /// `period * (sum of squared positions) - sum_x^2`; fixed at construction
    /// and used as the divisor when computing the slope.
    denom: f64,
    /// Running sum of the values in the window.
    sum_y: f64,
    /// Running sum of `position * value` over the window.
    sum_xy: f64,
    /// Running sum of the squared values in the window.
    sum_y_sq: f64,
}
58
59impl StandardError {
60 pub fn new(period: usize) -> Result<Self> {
66 if period < 3 {
67 return Err(Error::InvalidPeriod {
68 message: "standard error needs period >= 3",
69 });
70 }
71 let n = period as f64;
72 let sum_x = n * (n - 1.0) / 2.0;
73 let sum_xx = (n - 1.0) * n * (2.0 * n - 1.0) / 6.0;
74 Ok(Self {
75 period,
76 window: VecDeque::with_capacity(period),
77 sum_x,
78 denom: n * sum_xx - sum_x * sum_x,
79 sum_y: 0.0,
80 sum_xy: 0.0,
81 sum_y_sq: 0.0,
82 })
83 }
84
85 pub const fn period(&self) -> usize {
87 self.period
88 }
89}
90
91impl Indicator for StandardError {
92 type Input = f64;
93 type Output = f64;
94
95 fn update(&mut self, value: f64) -> Option<f64> {
96 if !value.is_finite() {
97 return None;
98 }
99 if self.window.len() == self.period {
100 let y0 = self.window.pop_front().expect("non-empty");
102 self.sum_xy = self.sum_xy - self.sum_y + y0;
103 self.sum_y -= y0;
104 self.sum_y_sq -= y0 * y0;
105 }
106 let k = self.window.len() as f64;
107 self.window.push_back(value);
108 self.sum_y += value;
109 self.sum_xy += k * value;
110 self.sum_y_sq += value * value;
111
112 if self.window.len() < self.period {
113 return None;
114 }
115 let n = self.period as f64;
116 let slope = (n * self.sum_xy - self.sum_x * self.sum_y) / self.denom;
117 let mean_y = self.sum_y / n;
118 let ss_total = self.sum_y_sq - n * mean_y * mean_y;
119 let s_xx = self.denom / n;
121 let rss = (ss_total - slope * slope * s_xx).max(0.0);
122 Some((rss / (n - 2.0)).sqrt())
123 }
124
125 fn reset(&mut self) {
126 self.window.clear();
127 self.sum_y = 0.0;
128 self.sum_xy = 0.0;
129 self.sum_y_sq = 0.0;
130 }
131
132 fn warmup_period(&self) -> usize {
133 self.period
134 }
135
136 fn is_ready(&self) -> bool {
137 self.window.len() == self.period
138 }
139
140 fn name(&self) -> &'static str {
141 "StandardError"
142 }
143}
144
#[cfg(test)]
mod tests {
    use super::*;
    use crate::traits::BatchExt;
    use approx::assert_relative_eq;

    /// Periods 0 and 2 are refused; 3 is accepted.
    #[test]
    fn rejects_period_below_three() {
        for &too_small in &[0_usize, 2] {
            assert!(StandardError::new(too_small).is_err());
        }
        assert!(StandardError::new(3).is_ok());
    }

    /// The accessor and the trait metadata report the configured period and name.
    #[test]
    fn accessors_and_metadata() {
        let indicator = StandardError::new(14).unwrap();
        assert_eq!(indicator.period(), 14);
        assert_eq!(indicator.warmup_period(), 14);
        assert_eq!(indicator.name(), "StandardError");
    }

    /// Values lying exactly on a straight line leave no residual.
    #[test]
    fn perfect_line_has_zero_error() {
        let mut indicator = StandardError::new(10).unwrap();
        let line: Vec<f64> = (0..30).map(|step| 5.0 + 2.0 * f64::from(step)).collect();
        indicator
            .batch(&line)
            .into_iter()
            .flatten()
            .for_each(|error| {
                assert_relative_eq!(error, 0.0, epsilon = 1e-9);
            });
    }

    /// A flat series is a perfect (horizontal) line, so the error is zero.
    #[test]
    fn constant_series_yields_zero() {
        let flat = [42.0; 20];
        let mut indicator = StandardError::new(5).unwrap();
        indicator
            .batch(&flat)
            .into_iter()
            .flatten()
            .for_each(|error| {
                assert_relative_eq!(error, 0.0, epsilon = 1e-9);
            });
    }

    /// The incremental implementation agrees with a from-scratch regression
    /// computed independently on every full window.
    #[test]
    fn matches_naive_definition() {
        /// Reference implementation: fit the line explicitly, then sum the
        /// squared residuals point by point.
        fn naive(window: &[f64]) -> f64 {
            let n = window.len() as f64;
            let mean_y = window.iter().sum::<f64>() / n;
            let (sum_xy, sum_x, sum_xx) = window.iter().enumerate().fold(
                (0.0_f64, 0.0_f64, 0.0_f64),
                |(xy, x_total, xx), (i, &y)| {
                    let x = i as f64;
                    (xy + x * y, x_total + x, xx + x * x)
                },
            );
            let mean_x = sum_x / n;
            let s_xx = sum_xx - n * mean_x * mean_x;
            let slope = (sum_xy - n * mean_x * mean_y) / s_xx;
            let intercept = mean_y - slope * mean_x;
            let squared_residual = |(i, &y): (usize, &f64)| {
                let r = y - (intercept + slope * i as f64);
                r * r
            };
            let rss: f64 = window.iter().enumerate().map(squared_residual).sum();
            (rss / (n - 2.0)).sqrt()
        }

        let period = 14;
        let prices: Vec<f64> = (0..60)
            .map(|i| 100.0 + f64::from(i) * 0.5 + (f64::from(i) * 0.7).sin() * 3.0)
            .collect();
        let outputs = StandardError::new(period).unwrap().batch(&prices);

        let ready = outputs
            .iter()
            .enumerate()
            .filter_map(|(last, output)| output.map(|actual| (last, actual)));
        for (last, actual) in ready {
            // The window producing output `last` ends at index `last`.
            let expected = naive(&prices[last + 1 - period..=last]);
            assert_relative_eq!(actual, expected, epsilon = 1e-9);
        }
    }

    /// After `reset` the indicator must warm up again from scratch.
    #[test]
    fn reset_clears_state() {
        let mut indicator = StandardError::new(5).unwrap();
        indicator.batch(&[1.0, 2.0, 3.0, 4.0, 5.0]);
        assert!(indicator.is_ready());

        indicator.reset();
        assert!(!indicator.is_ready());
        assert_eq!(indicator.update(1.0), None);
    }

    /// Batch processing and value-by-value updates produce identical outputs.
    #[test]
    fn batch_equals_streaming() {
        let prices: Vec<f64> = (0..60)
            .map(|i| 100.0 + (f64::from(i) * 0.4).sin() * 10.0)
            .collect();

        let mut streaming = StandardError::new(14).unwrap();
        let streamed: Vec<_> = prices
            .iter()
            .map(|&price| streaming.update(price))
            .collect();
        let batch = StandardError::new(14).unwrap().batch(&prices);

        assert_eq!(batch, streamed);
    }
}