wickra_core/indicators/
standard_error.rs1use std::collections::VecDeque;
4
5use crate::error::{Error, Result};
6use crate::traits::Indicator;
7
/// Rolling standard error of a least-squares line fitted to the last
/// `period` values.
///
/// Every full window is regressed against the abscissae
/// `0, 1, …, period - 1` (oldest sample at `x = 0`) and the indicator reports
/// `sqrt(RSS / (period - 2))`, where `RSS` is the residual sum of squares
/// around the fitted line.
///
/// The struct stores the running sums needed to obtain that value without
/// revisiting the whole window on every update.
#[derive(Debug, Clone)]
pub struct StandardError {
    /// Number of samples in the regression window.
    period: usize,
    /// Samples currently in the window, oldest first.
    window: VecDeque<f64>,
    /// `Σx` for `x = 0..period`; fixed once the period is known.
    sum_x: f64,
    /// `period · Σx² − (Σx)²`, the denominator of the slope; fixed per period.
    denom: f64,
    /// Running `Σy` over the window.
    sum_y: f64,
    /// Running `Σ(x · y)` over the window, with `x` the position in the window.
    sum_xy: f64,
    /// Running `Σy²` over the window.
    sum_y_sq: f64,
}
58
59impl StandardError {
60 pub fn new(period: usize) -> Result<Self> {
66 if period < 3 {
67 return Err(Error::InvalidPeriod {
68 message: "standard error needs period >= 3",
69 });
70 }
71 let n = period as f64;
72 let sum_x = n * (n - 1.0) / 2.0;
73 let sum_xx = (n - 1.0) * n * (2.0 * n - 1.0) / 6.0;
74 Ok(Self {
75 period,
76 window: VecDeque::with_capacity(period),
77 sum_x,
78 denom: n * sum_xx - sum_x * sum_x,
79 sum_y: 0.0,
80 sum_xy: 0.0,
81 sum_y_sq: 0.0,
82 })
83 }
84
85 pub const fn period(&self) -> usize {
87 self.period
88 }
89}
90
91impl Indicator for StandardError {
92 type Input = f64;
93 type Output = f64;
94
95 fn update(&mut self, value: f64) -> Option<f64> {
96 if self.window.len() == self.period {
97 let y0 = self.window.pop_front().expect("non-empty");
99 self.sum_xy = self.sum_xy - self.sum_y + y0;
100 self.sum_y -= y0;
101 self.sum_y_sq -= y0 * y0;
102 }
103 let k = self.window.len() as f64;
104 self.window.push_back(value);
105 self.sum_y += value;
106 self.sum_xy += k * value;
107 self.sum_y_sq += value * value;
108
109 if self.window.len() < self.period {
110 return None;
111 }
112 let n = self.period as f64;
113 let slope = (n * self.sum_xy - self.sum_x * self.sum_y) / self.denom;
114 let mean_y = self.sum_y / n;
115 let ss_total = self.sum_y_sq - n * mean_y * mean_y;
116 let s_xx = self.denom / n;
118 let rss = (ss_total - slope * slope * s_xx).max(0.0);
119 Some((rss / (n - 2.0)).sqrt())
120 }
121
122 fn reset(&mut self) {
123 self.window.clear();
124 self.sum_y = 0.0;
125 self.sum_xy = 0.0;
126 self.sum_y_sq = 0.0;
127 }
128
129 fn warmup_period(&self) -> usize {
130 self.period
131 }
132
133 fn is_ready(&self) -> bool {
134 self.window.len() == self.period
135 }
136
137 fn name(&self) -> &'static str {
138 "StandardError"
139 }
140}
141
#[cfg(test)]
mod tests {
    use super::*;
    use crate::traits::BatchExt;
    use approx::assert_relative_eq;

    #[test]
    fn rejects_period_below_three() {
        assert!(StandardError::new(0).is_err());
        assert!(StandardError::new(2).is_err());
        assert!(StandardError::new(3).is_ok());
    }

    #[test]
    fn accessors_and_metadata() {
        let se = StandardError::new(14).unwrap();
        assert_eq!(se.period(), 14);
        assert_eq!(se.warmup_period(), 14);
        assert_eq!(se.name(), "StandardError");
    }

    #[test]
    fn perfect_line_has_zero_error() {
        // Points that lie exactly on y = 2x + 5 leave no residual.
        let prices: Vec<f64> = (0..30).map(|i| 2.0 * f64::from(i) + 5.0).collect();
        let mut se = StandardError::new(10).unwrap();
        for v in se.batch(&prices).into_iter().flatten() {
            assert_relative_eq!(v, 0.0, epsilon = 1e-9);
        }
    }

    #[test]
    fn constant_series_yields_zero() {
        let mut se = StandardError::new(5).unwrap();
        for v in se.batch(&[42.0; 20]).into_iter().flatten() {
            assert_relative_eq!(v, 0.0, epsilon = 1e-9);
        }
    }

    #[test]
    fn matches_naive_definition() {
        /// Reference implementation: fit the line explicitly, then sum the
        /// squared residuals point by point.
        fn naive(window: &[f64]) -> f64 {
            let n = window.len() as f64;
            let mean_y = window.iter().sum::<f64>() / n;
            let mut sum_xy = 0.0;
            let mut sum_x = 0.0;
            let mut sum_xx = 0.0;
            for (i, &y) in window.iter().enumerate() {
                let x = i as f64;
                sum_xy += x * y;
                sum_x += x;
                sum_xx += x * x;
            }
            let mean_x = sum_x / n;
            let s_xx = sum_xx - n * mean_x * mean_x;
            let slope = (sum_xy - n * mean_x * mean_y) / s_xx;
            let intercept = mean_y - slope * mean_x;
            let rss: f64 = window
                .iter()
                .enumerate()
                .map(|(i, &y)| {
                    let r = y - (intercept + slope * i as f64);
                    r * r
                })
                .sum();
            (rss / (n - 2.0)).sqrt()
        }

        let prices: Vec<f64> = (0..60)
            .map(|i| 100.0 + f64::from(i) * 0.5 + (f64::from(i) * 0.7).sin() * 3.0)
            .collect();
        let period = 14;
        let got = StandardError::new(period).unwrap().batch(&prices);
        for (i, g) in got.iter().enumerate() {
            if let Some(v) = g {
                // Output `i` covers the `period` prices ending at index `i`.
                let expected = naive(&prices[i + 1 - period..=i]);
                assert_relative_eq!(*v, expected, epsilon = 1e-9);
            }
        }
    }

    #[test]
    fn reset_clears_state() {
        let mut se = StandardError::new(5).unwrap();
        se.batch(&[1.0, 2.0, 3.0, 4.0, 5.0]);
        assert!(se.is_ready());
        se.reset();
        assert!(!se.is_ready());
        assert_eq!(se.update(1.0), None);

        // Readiness only proves the window was emptied. Replaying a series
        // after a reset must reproduce a fresh instance exactly, which also
        // proves the running sums were zeroed.
        let data = [3.0, 1.0, 4.0, 1.0, 5.0, 9.0, 2.0, 6.0];
        se.reset();
        let replayed: Vec<_> = data.iter().map(|p| se.update(*p)).collect();
        let mut fresh = StandardError::new(5).unwrap();
        let expected: Vec<_> = data.iter().map(|p| fresh.update(*p)).collect();
        assert_eq!(replayed, expected);
    }

    #[test]
    fn batch_equals_streaming() {
        let prices: Vec<f64> = (0..60)
            .map(|i| 100.0 + (f64::from(i) * 0.4).sin() * 10.0)
            .collect();
        let batch = StandardError::new(14).unwrap().batch(&prices);
        let mut b = StandardError::new(14).unwrap();
        let streamed: Vec<_> = prices.iter().map(|p| b.update(*p)).collect();
        assert_eq!(batch, streamed);
    }
}