Skip to main content

wickra_core/indicators/
sma.rs

1//! Simple Moving Average.
2
3use crate::error::{Error, Result};
4use crate::traits::Indicator;
5
/// Simple Moving Average over a fixed window.
///
/// Maintains a rolling sum so each update is O(1). Once the window is full the
/// output is `sum(last period prices) / period`; before that it is `None`.
///
/// On long-running streams a single-subtract incremental sum can accumulate
/// rounding error: when values of very different magnitudes are alternately
/// added and removed, each step rounds against the large running sum. To keep
/// that drift bounded, the running sum is reseeded from the live window every
/// `16 · period` finite updates. The reseed does `O(period)` work once per
/// `16 · period` updates, so the amortised cost per update stays O(1). On inputs
/// whose running sum is computed exactly (e.g. small integers) the reseed
/// produces the same sum, so results are unaffected.
///
/// # Example
///
/// ```
/// use wickra_core::{Indicator, Sma};
///
/// let mut indicator = Sma::new(3).unwrap();
/// let mut last = None;
/// for i in 0..80 {
///     last = indicator.update(100.0 + f64::from(i));
/// }
/// // Mean of the last three inputs: (177 + 178 + 179) / 3.
/// assert_eq!(last, Some(178.0));
/// ```
#[derive(Debug, Clone)]
pub struct Sma {
    /// Window length; never zero (rejected by [`Sma::new`]).
    period: usize,
    /// Fixed-capacity ring buffer of the last `period` finite inputs. A flat
    /// `Box<[f64]>` with a manual write cursor avoids `VecDeque`'s bookkeeping
    /// on this hot path: contiguous storage and a compare-and-reset wraparound.
    buf: Box<[f64]>,
    /// Index of the next slot to write — also the oldest element once full.
    head: usize,
    /// Number of slots filled, saturating at `period`.
    count: usize,
    /// Running sum of the filled slots of the window.
    sum: f64,
    /// Number of finite updates since the running `sum` was last reseeded from
    /// the live window. Caps accumulated floating-point drift on long streams.
    /// See [`RECOMPUTE_EVERY`].
    updates_since_recompute: usize,
}
48
/// Reseed interval for the running sum, in multiples of `period`.
///
/// The incremental sum is rebuilt from the live window after
/// `RECOMPUTE_EVERY · period` finite updates. Any constant multiplier keeps the
/// amortised reseed cost O(1) per update (`O(period)` work spread over
/// `RECOMPUTE_EVERY · period` updates). The value trades reseed frequency
/// against how many rounding steps can pile up between reseeds.
///
/// Each add/subtract rounds by at most half an ULP of a sum whose magnitude is
/// at most `period · max(|x|)`. That gives a first-order worst-case drift of
/// the reported mean of about `RECOMPUTE_EVERY · period · f64::EPSILON ·
/// max(|x|)`.
const RECOMPUTE_EVERY: usize = 16;
55
56impl Sma {
57    /// Construct a new SMA with the given window length.
58    ///
59    /// # Errors
60    ///
61    /// Returns [`Error::PeriodZero`] if `period == 0`.
62    pub fn new(period: usize) -> Result<Self> {
63        if period == 0 {
64            return Err(Error::PeriodZero);
65        }
66        Ok(Self {
67            period,
68            buf: vec![0.0; period].into_boxed_slice(),
69            head: 0,
70            count: 0,
71            sum: 0.0,
72            updates_since_recompute: 0,
73        })
74    }
75
76    /// Configured window length.
77    pub const fn period(&self) -> usize {
78        self.period
79    }
80
81    /// Current value if available.
82    pub fn value(&self) -> Option<f64> {
83        if self.count == self.period {
84            Some(self.sum / self.period as f64)
85        } else {
86            None
87        }
88    }
89
90    /// Vectorized batch returning one `f64` per input (`NaN` during warmup).
91    ///
92    /// Shadows the generic [`BatchNanExt::batch_nan`](crate::BatchNanExt) blanket
93    /// default via inherent-method resolution. For a fresh, all-finite slice it
94    /// inlines `update`'s rolling sum and drift-reseed, writing the mean as a bare
95    /// `f64` (warmup → `NaN`) instead of allocating an `Option<f64>` per element
96    /// and walking the result a second time. Same add/subtract order, same reseed
97    /// cadence, same `sum / period` division — so it is *bit-for-bit* equal to
98    /// replaying `update`, including the long-stream drift bound. Any other state,
99    /// or a non-finite element, defers to the exact `update` replay.
100    pub fn batch_nan(&mut self, inputs: &[f64]) -> Vec<f64> {
101        let p = self.period;
102        if self.count != 0
103            || self.updates_since_recompute != 0
104            || !inputs.iter().all(|x| x.is_finite())
105        {
106            return inputs
107                .iter()
108                .map(|&x| self.update(x).unwrap_or(f64::NAN))
109                .collect();
110        }
111
112        let p_f64 = p as f64;
113        let mut out = Vec::with_capacity(inputs.len());
114        for &x in inputs {
115            if self.count == p {
116                self.sum -= self.buf[self.head];
117                self.buf[self.head] = x;
118                self.sum += x;
119            } else {
120                self.buf[self.head] = x;
121                self.sum += x;
122                self.count += 1;
123            }
124            self.head += 1;
125            if self.head == p {
126                self.head = 0;
127            }
128            self.updates_since_recompute += 1;
129            if self.updates_since_recompute >= RECOMPUTE_EVERY * p {
130                self.sum = self.buf[self.head..]
131                    .iter()
132                    .chain(&self.buf[..self.head])
133                    .copied()
134                    .sum();
135                self.updates_since_recompute = 0;
136            }
137            out.push(if self.count == p {
138                self.sum / p_f64
139            } else {
140                f64::NAN
141            });
142        }
143        out
144    }
145}
146
147impl Indicator for Sma {
148    type Input = f64;
149    type Output = f64;
150
151    fn update(&mut self, input: f64) -> Option<f64> {
152        if !input.is_finite() {
153            return self.value();
154        }
155        if self.count == self.period {
156            // Window full: overwrite the oldest slot (at `head`). Each step is a
157            // single f64 add/subtract — O(1) but introduces ~1 ULP of rounding
158            // noise. The periodic reseed below caps the accumulated drift.
159            self.sum -= self.buf[self.head];
160            self.buf[self.head] = input;
161            self.sum += input;
162        } else {
163            self.buf[self.head] = input;
164            self.sum += input;
165            self.count += 1;
166        }
167        // Branchless-ish wraparound, cheaper than `% period`.
168        self.head += 1;
169        if self.head == self.period {
170            self.head = 0;
171        }
172        self.updates_since_recompute += 1;
173        if self.updates_since_recompute >= RECOMPUTE_EVERY * self.period {
174            // Reseed in chronological order (oldest at `head`) so the running sum
175            // tracks a fresh from-scratch mean to the bit on stable inputs.
176            self.sum = self.buf[self.head..]
177                .iter()
178                .chain(&self.buf[..self.head])
179                .copied()
180                .sum();
181            self.updates_since_recompute = 0;
182        }
183        self.value()
184    }
185
186    fn reset(&mut self) {
187        self.head = 0;
188        self.count = 0;
189        self.sum = 0.0;
190        self.updates_since_recompute = 0;
191    }
192
193    fn warmup_period(&self) -> usize {
194        self.period
195    }
196
197    fn is_ready(&self) -> bool {
198        self.count == self.period
199    }
200
201    fn name(&self) -> &'static str {
202        "SMA"
203    }
204}
205
#[cfg(test)]
mod tests {
    use super::*;
    use crate::traits::BatchExt;
    use approx::assert_relative_eq;
    use std::collections::VecDeque;

    #[test]
    fn new_rejects_zero_period() {
        assert!(matches!(Sma::new(0), Err(Error::PeriodZero)));
    }

    /// Cover the inherent `period` accessor and the `Indicator` metadata
    /// methods `warmup_period` and `name`. The other tests inspect SMA output
    /// but never query the metadata.
    #[test]
    fn accessors_and_metadata() {
        let sma = Sma::new(20).unwrap();
        assert_eq!(sma.period(), 20);
        assert_eq!(sma.warmup_period(), 20);
        assert_eq!(sma.name(), "SMA");
    }

    #[test]
    fn warmup_returns_none() {
        let mut sma = Sma::new(3).unwrap();
        assert_eq!(sma.update(1.0), None);
        assert_eq!(sma.update(2.0), None);
        assert_eq!(sma.update(3.0), Some(2.0));
    }

    #[test]
    fn rolls_window_after_full() {
        let mut sma = Sma::new(3).unwrap();
        let out: Vec<_> = [1.0, 2.0, 3.0, 4.0, 5.0]
            .iter()
            .map(|p| sma.update(*p))
            .collect();
        assert_eq!(out, vec![None, None, Some(2.0), Some(3.0), Some(4.0)]);
    }

    #[test]
    fn period_one_is_pass_through() {
        let mut sma = Sma::new(1).unwrap();
        assert_eq!(sma.update(5.0), Some(5.0));
        assert_eq!(sma.update(10.0), Some(10.0));
    }

    #[test]
    fn ignores_non_finite_input_but_keeps_state() {
        let mut sma = Sma::new(3).unwrap();
        sma.update(1.0);
        sma.update(2.0);
        sma.update(3.0);
        assert_eq!(sma.update(f64::NAN), Some(2.0));
        assert_eq!(sma.update(f64::INFINITY), Some(2.0));
        // Non-finite inputs were not pushed; window still holds 1,2,3.
        assert_eq!(sma.update(6.0), Some((2.0 + 3.0 + 6.0) / 3.0));
    }

    #[test]
    fn reset_clears_state() {
        let mut sma = Sma::new(3).unwrap();
        sma.batch(&[1.0, 2.0, 3.0]);
        assert!(sma.is_ready());
        sma.reset();
        assert!(!sma.is_ready());
        assert_eq!(sma.update(10.0), None);
    }

    #[test]
    fn batch_equals_streaming() {
        let prices: Vec<f64> = (1..=20).map(f64::from).collect();
        let mut a = Sma::new(5).unwrap();
        let batch = a.batch(&prices);
        let mut b = Sma::new(5).unwrap();
        let streamed: Vec<_> = prices.iter().map(|p| b.update(*p)).collect();
        assert_eq!(batch, streamed);
    }

    #[test]
    fn known_reference_values() {
        // SMA(3) of [2, 4, 6, 8, 10] -> [_, _, 4, 6, 8]
        let mut sma = Sma::new(3).unwrap();
        let out = sma.batch(&[2.0, 4.0, 6.0, 8.0, 10.0]);
        assert_eq!(out[2], Some(4.0));
        assert_eq!(out[3], Some(6.0));
        assert_eq!(out[4], Some(8.0));
    }

    #[test]
    fn constant_series_yields_constant_sma() {
        let mut sma = Sma::new(5).unwrap();
        let v = sma.batch(&[7.0; 10]);
        for x in v.iter().skip(4) {
            assert_relative_eq!(x.unwrap(), 7.0, epsilon = 1e-12);
        }
    }

    /// NaN-aware bit-equality for the `f64`-with-NaN-warmup batch outputs.
    fn bits_eq(a: &[f64], b: &[f64]) -> bool {
        a.len() == b.len()
            && a.iter()
                .zip(b)
                .all(|(x, y)| x == y || (x.is_nan() && y.is_nan()))
    }

    fn sma_replay(period: usize, series: &[f64]) -> Vec<f64> {
        let mut s = Sma::new(period).unwrap();
        series
            .iter()
            .map(|&x| s.update(x).unwrap_or(f64::NAN))
            .collect()
    }

    #[test]
    fn batch_nan_fast_path_is_bit_identical_with_reseed() {
        // > RECOMPUTE_EVERY * period inputs so the drift-reseed branch fires
        // inside batch_nan.
        let series: Vec<f64> = (0..500)
            .map(|i| (f64::from(i) * 0.2).sin() * 10.0 + 50.0)
            .collect();
        let mut sma = Sma::new(14).unwrap();
        let got = sma.batch_nan(&series);
        assert!(bits_eq(&got, &sma_replay(14, &series)));
        // State left where the replay would: continued updates agree.
        let mut ref_sma = Sma::new(14).unwrap();
        for &x in &series {
            ref_sma.update(x);
        }
        assert_eq!(sma.update(42.0), ref_sma.update(42.0));
    }

    #[test]
    fn batch_nan_falls_back_on_non_finite() {
        let series = [1.0, 2.0, f64::NAN, 4.0, 5.0, 6.0];
        let mut sma = Sma::new(3).unwrap();
        assert!(bits_eq(&sma.batch_nan(&series), &sma_replay(3, &series)));
    }

    #[test]
    fn batch_nan_falls_back_when_not_fresh() {
        let mut sma = Sma::new(3).unwrap();
        sma.update(99.0);
        let series = [1.0, 2.0, 3.0, 4.0];
        let mut ref_sma = Sma::new(3).unwrap();
        ref_sma.update(99.0);
        let want: Vec<f64> = series
            .iter()
            .map(|&x| ref_sma.update(x).unwrap_or(f64::NAN))
            .collect();
        assert!(bits_eq(&sma.batch_nan(&series), &want));
    }

    #[test]
    fn batch_nan_sub_period_slice_is_all_nan() {
        let series = [1.0, 2.0, 3.0];
        let mut sma = Sma::new(10).unwrap();
        let got = sma.batch_nan(&series);
        assert!(bits_eq(&got, &sma_replay(10, &series)));
        assert!(got.iter().all(|x| x.is_nan()));
    }

    proptest::proptest! {
        #![proptest_config(proptest::test_runner::Config::with_cases(64))]
        #[test]
        fn sma_matches_naive_definition(
            period in 1usize..20,
            prices in proptest::collection::vec(-1000.0_f64..1000.0, 0..200),
        ) {
            let mut sma = Sma::new(period).unwrap();
            let stream: Vec<_> = prices.iter().map(|p| sma.update(*p)).collect();
            for (i, got) in stream.iter().enumerate() {
                if i + 1 < period {
                    proptest::prop_assert!(got.is_none());
                } else {
                    let window = &prices[i + 1 - period..=i];
                    let expected = window.iter().sum::<f64>() / period as f64;
                    let actual = got.expect("ready");
                    proptest::prop_assert!(
                        (actual - expected).abs() < 1e-9,
                        "i={i} actual={actual} expected={expected}"
                    );
                }
            }
        }
    }

    /// Long-running stability check.
    ///
    /// Runs five full reseed cycles plus half of another. The periodic reseed
    /// therefore fires repeatedly, and the final reading comes from half a
    /// cycle of purely incremental updates rather than a freshly reseeded sum.
    /// Inputs alternate between `1e9` and `0.1`. These differ by ten orders of
    /// magnitude, so every add/subtract of `0.1` against the ~`1e10` running
    /// sum rounds. The reported mean must stay within the documented
    /// first-order drift bound of a from-scratch mean over the live window.
    #[test]
    fn long_stream_drift_stays_bounded() {
        let period = 20;
        let mut sma = Sma::new(period).unwrap();
        let mut window: VecDeque<f64> = VecDeque::with_capacity(period);
        let cycle = RECOMPUTE_EVERY * period;
        let n_updates = cycle * 5 + cycle / 2;
        for i in 0..n_updates {
            let v = if i % 2 == 0 { 1e9 } else { 0.1 };
            sma.update(v);
            if window.len() == period {
                window.pop_front();
            }
            window.push_back(v);
        }
        let from_scratch: f64 = window.iter().sum::<f64>() / period as f64;
        let got = sma.value().expect("warmed up");
        // Worst-case drift of the mean between reseeds:
        // RECOMPUTE_EVERY · period · ε · max(|x|).
        let bound = cycle as f64 * f64::EPSILON * 1e9;
        assert!(
            (got - from_scratch).abs() <= bound,
            "SMA drift exceeds {bound} over {n_updates} updates: got={got}, scratch={from_scratch}"
        );
    }
}