wickra_core/indicators/
sma.rs1use crate::error::{Error, Result};
4use crate::traits::Indicator;
5
/// Simple moving average over the most recent `period` finite samples.
///
/// Samples live in a fixed-size ring buffer alongside a running sum, so an
/// update does constant work apart from a periodic O(`period`) rebuild of the
/// sum from the buffer (every `RECOMPUTE_EVERY * period` accepted updates),
/// which keeps incremental floating-point rounding from accumulating.
#[derive(Debug, Clone)]
pub struct Sma {
    /// Window length; never zero (rejected by `Sma::new`).
    period: usize,
    /// Ring buffer holding the last `period` accepted samples (length `period`).
    buf: Box<[f64]>,
    /// Slot the next sample is written to. Once the window is full this slot
    /// also holds the oldest sample, which is the one evicted next.
    head: usize,
    /// Number of valid samples in `buf`; saturates at `period`.
    count: usize,
    /// Running sum of the valid samples in `buf`.
    sum: f64,
    /// Accepted updates since `sum` was last rebuilt from `buf`.
    updates_since_recompute: usize,
}
48
/// Reseed cadence, in multiples of the window length: the running sum is
/// rebuilt from the ring buffer after every `RECOMPUTE_EVERY * period`
/// accepted updates.
const RECOMPUTE_EVERY: usize = 16;
55
56impl Sma {
57 pub fn new(period: usize) -> Result<Self> {
63 if period == 0 {
64 return Err(Error::PeriodZero);
65 }
66 Ok(Self {
67 period,
68 buf: vec![0.0; period].into_boxed_slice(),
69 head: 0,
70 count: 0,
71 sum: 0.0,
72 updates_since_recompute: 0,
73 })
74 }
75
76 pub const fn period(&self) -> usize {
78 self.period
79 }
80
81 pub fn value(&self) -> Option<f64> {
83 if self.count == self.period {
84 Some(self.sum / self.period as f64)
85 } else {
86 None
87 }
88 }
89
90 pub fn batch_nan(&mut self, inputs: &[f64]) -> Vec<f64> {
101 let p = self.period;
102 if self.count != 0
103 || self.updates_since_recompute != 0
104 || !inputs.iter().all(|x| x.is_finite())
105 {
106 return inputs
107 .iter()
108 .map(|&x| self.update(x).unwrap_or(f64::NAN))
109 .collect();
110 }
111
112 let p_f64 = p as f64;
113 let mut out = Vec::with_capacity(inputs.len());
114 for &x in inputs {
115 if self.count == p {
116 self.sum -= self.buf[self.head];
117 self.buf[self.head] = x;
118 self.sum += x;
119 } else {
120 self.buf[self.head] = x;
121 self.sum += x;
122 self.count += 1;
123 }
124 self.head += 1;
125 if self.head == p {
126 self.head = 0;
127 }
128 self.updates_since_recompute += 1;
129 if self.updates_since_recompute >= RECOMPUTE_EVERY * p {
130 self.sum = self.buf[self.head..]
131 .iter()
132 .chain(&self.buf[..self.head])
133 .copied()
134 .sum();
135 self.updates_since_recompute = 0;
136 }
137 out.push(if self.count == p {
138 self.sum / p_f64
139 } else {
140 f64::NAN
141 });
142 }
143 out
144 }
145}
146
147impl Indicator for Sma {
148 type Input = f64;
149 type Output = f64;
150
151 fn update(&mut self, input: f64) -> Option<f64> {
152 if !input.is_finite() {
153 return self.value();
154 }
155 if self.count == self.period {
156 self.sum -= self.buf[self.head];
160 self.buf[self.head] = input;
161 self.sum += input;
162 } else {
163 self.buf[self.head] = input;
164 self.sum += input;
165 self.count += 1;
166 }
167 self.head += 1;
169 if self.head == self.period {
170 self.head = 0;
171 }
172 self.updates_since_recompute += 1;
173 if self.updates_since_recompute >= RECOMPUTE_EVERY * self.period {
174 self.sum = self.buf[self.head..]
177 .iter()
178 .chain(&self.buf[..self.head])
179 .copied()
180 .sum();
181 self.updates_since_recompute = 0;
182 }
183 self.value()
184 }
185
186 fn reset(&mut self) {
187 self.head = 0;
188 self.count = 0;
189 self.sum = 0.0;
190 self.updates_since_recompute = 0;
191 }
192
193 fn warmup_period(&self) -> usize {
194 self.period
195 }
196
197 fn is_ready(&self) -> bool {
198 self.count == self.period
199 }
200
201 fn name(&self) -> &'static str {
202 "SMA"
203 }
204}
205
#[cfg(test)]
mod tests {
    use super::*;
    use crate::traits::BatchExt;
    use approx::assert_relative_eq;
    use std::collections::VecDeque;

    #[test]
    fn new_rejects_zero_period() {
        assert!(matches!(Sma::new(0), Err(Error::PeriodZero)));
    }

    #[test]
    fn accessors_and_metadata() {
        let sma = Sma::new(20).unwrap();
        assert_eq!(sma.period(), 20);
        assert_eq!(sma.warmup_period(), 20);
        assert_eq!(sma.name(), "SMA");
    }

    #[test]
    fn warmup_returns_none() {
        let mut sma = Sma::new(3).unwrap();
        assert_eq!(sma.update(1.0), None);
        assert_eq!(sma.update(2.0), None);
        assert_eq!(sma.update(3.0), Some(2.0));
    }

    #[test]
    fn rolls_window_after_full() {
        let mut sma = Sma::new(3).unwrap();
        let out: Vec<_> = [1.0, 2.0, 3.0, 4.0, 5.0]
            .iter()
            .map(|p| sma.update(*p))
            .collect();
        assert_eq!(out, vec![None, None, Some(2.0), Some(3.0), Some(4.0)]);
    }

    #[test]
    fn period_one_is_pass_through() {
        let mut sma = Sma::new(1).unwrap();
        assert_eq!(sma.update(5.0), Some(5.0));
        assert_eq!(sma.update(10.0), Some(10.0));
    }

    #[test]
    fn ignores_non_finite_input_but_keeps_state() {
        let mut sma = Sma::new(3).unwrap();
        sma.update(1.0);
        sma.update(2.0);
        sma.update(3.0);
        assert_eq!(sma.update(f64::NAN), Some(2.0));
        assert_eq!(sma.update(f64::INFINITY), Some(2.0));
        assert_eq!(sma.update(6.0), Some((2.0 + 3.0 + 6.0) / 3.0));
    }

    #[test]
    fn reset_clears_state() {
        let mut sma = Sma::new(3).unwrap();
        sma.batch(&[1.0, 2.0, 3.0]);
        assert!(sma.is_ready());
        sma.reset();
        assert!(!sma.is_ready());
        assert_eq!(sma.update(10.0), None);
    }

    #[test]
    fn batch_equals_streaming() {
        let prices: Vec<f64> = (1..=20).map(f64::from).collect();
        let mut a = Sma::new(5).unwrap();
        let batch = a.batch(&prices);
        let mut b = Sma::new(5).unwrap();
        let streamed: Vec<_> = prices.iter().map(|p| b.update(*p)).collect();
        assert_eq!(batch, streamed);
    }

    #[test]
    fn known_reference_values() {
        let mut sma = Sma::new(3).unwrap();
        let out = sma.batch(&[2.0, 4.0, 6.0, 8.0, 10.0]);
        assert_eq!(out[2], Some(4.0));
        assert_eq!(out[3], Some(6.0));
        assert_eq!(out[4], Some(8.0));
    }

    #[test]
    fn constant_series_yields_constant_sma() {
        let mut sma = Sma::new(5).unwrap();
        let v = sma.batch(&[7.0; 10]);
        for x in v.iter().skip(4) {
            assert_relative_eq!(x.unwrap(), 7.0, epsilon = 1e-12);
        }
    }

    /// Exact bit-pattern equality of two slices, treating any two NaNs as equal.
    ///
    /// Plain `==` would also accept `0.0` vs `-0.0`, which are not
    /// bit-identical, so the bit patterns are compared directly.
    fn bits_eq(a: &[f64], b: &[f64]) -> bool {
        a.len() == b.len()
            && a.iter()
                .zip(b)
                .all(|(x, y)| x.to_bits() == y.to_bits() || (x.is_nan() && y.is_nan()))
    }

    /// Reference output: a fresh `Sma` fed one `update` at a time, with `None`
    /// mapped to `NaN` (the `batch_nan` convention).
    fn sma_replay(period: usize, series: &[f64]) -> Vec<f64> {
        let mut s = Sma::new(period).unwrap();
        series
            .iter()
            .map(|&x| s.update(x).unwrap_or(f64::NAN))
            .collect()
    }

    #[test]
    fn batch_nan_fast_path_is_bit_identical_with_reseed() {
        let series: Vec<f64> = (0..500)
            .map(|i| (f64::from(i) * 0.2).sin() * 10.0 + 50.0)
            .collect();
        let mut sma = Sma::new(14).unwrap();
        let got = sma.batch_nan(&series);
        assert!(bits_eq(&got, &sma_replay(14, &series)));
        let mut ref_sma = Sma::new(14).unwrap();
        for &x in &series {
            ref_sma.update(x);
        }
        assert_eq!(sma.update(42.0), ref_sma.update(42.0));
    }

    #[test]
    fn batch_nan_falls_back_on_non_finite() {
        let series = [1.0, 2.0, f64::NAN, 4.0, 5.0, 6.0];
        let mut sma = Sma::new(3).unwrap();
        assert!(bits_eq(&sma.batch_nan(&series), &sma_replay(3, &series)));
    }

    #[test]
    fn batch_nan_falls_back_when_not_fresh() {
        let mut sma = Sma::new(3).unwrap();
        sma.update(99.0);
        let series = [1.0, 2.0, 3.0, 4.0];
        let mut ref_sma = Sma::new(3).unwrap();
        ref_sma.update(99.0);
        let want: Vec<f64> = series
            .iter()
            .map(|&x| ref_sma.update(x).unwrap_or(f64::NAN))
            .collect();
        assert!(bits_eq(&sma.batch_nan(&series), &want));
    }

    #[test]
    fn batch_nan_sub_period_slice_is_all_nan() {
        let series = [1.0, 2.0, 3.0];
        let mut sma = Sma::new(10).unwrap();
        let got = sma.batch_nan(&series);
        assert!(bits_eq(&got, &sma_replay(10, &series)));
        assert!(got.iter().all(|x| x.is_nan()));
    }

    proptest::proptest! {
        #![proptest_config(proptest::test_runner::Config::with_cases(64))]
        #[test]
        fn sma_matches_naive_definition(
            period in 1usize..20,
            prices in proptest::collection::vec(-1000.0_f64..1000.0, 0..200),
        ) {
            let mut sma = Sma::new(period).unwrap();
            let stream: Vec<_> = prices.iter().map(|p| sma.update(*p)).collect();
            for (i, got) in stream.iter().enumerate() {
                if i + 1 < period {
                    proptest::prop_assert!(got.is_none());
                } else {
                    let window = &prices[i + 1 - period..=i];
                    let expected = window.iter().sum::<f64>() / period as f64;
                    let actual = got.expect("ready");
                    proptest::prop_assert!(
                        (actual - expected).abs() < 1e-9,
                        "i={i} actual={actual} expected={expected}"
                    );
                }
            }
        }
    }

    /// Drift in the running sum stays bounded over a long stream, and a reseed
    /// restores the exact from-scratch average.
    ///
    /// Inputs are deliberately non-integer: with integer-valued inputs every
    /// intermediate sum at this scale is exactly representable in `f64`, so no
    /// drift could occur and the check would be vacuous.
    #[test]
    fn long_stream_drift_stays_bounded() {
        /// Feeds `v` to both the indicator and the reference window.
        fn push(sma: &mut Sma, window: &mut VecDeque<f64>, period: usize, v: f64) {
            sma.update(v);
            if window.len() == period {
                window.pop_front();
            }
            window.push_back(v);
        }

        let period = 20;
        let reseed_every = RECOMPUTE_EVERY * period;
        let sample = |i: usize| if i % 2 == 0 { 1e9 + 0.1 } else { 0.3 };
        let scratch_mean = |w: &VecDeque<f64>| w.iter().sum::<f64>() / period as f64;

        let mut sma = Sma::new(period).unwrap();
        let mut window: VecDeque<f64> = VecDeque::with_capacity(period);

        // Stop one update short of a reseed: this is the most incremental updates
        // since the last rebuild, i.e. the worst point for accumulated drift.
        let n_updates = reseed_every * 5 - 1;
        for i in 0..n_updates {
            push(&mut sma, &mut window, period, sample(i));
        }
        let scratch = scratch_mean(&window);
        let got = sma.value().expect("warmed up");
        // At a sum magnitude of ~1e10 each add/subtract rounds by at most ~1e-6, so
        // fewer than 2 * reseed_every steps bound the error near 6e-4 in the sum
        // (~3e-5 in the mean); 1e-12 relative (~5e-4 here) leaves ample margin.
        assert!(
            (got - scratch).abs() <= 1e-12 * scratch.abs(),
            "SMA drift too large after {n_updates} updates: got={got}, scratch={scratch}"
        );

        // The next update completes a reseed interval: the sum is rebuilt
        // oldest-first from the ring buffer, the same order the reference window
        // is summed in, so the average must match bit-for-bit.
        push(&mut sma, &mut window, period, sample(n_updates));
        let got = sma.value().expect("warmed up");
        let scratch = scratch_mean(&window);
        assert_eq!(
            got.to_bits(),
            scratch.to_bits(),
            "reseed mismatch: got={got}, scratch={scratch}"
        );
    }
}