Skip to main content

wickra_core/indicators/
sma.rs

1//! Simple Moving Average.
2
3use crate::error::{Error, Result};
4use crate::traits::Indicator;
5
/// Simple Moving Average over a fixed window.
///
/// Maintains a rolling sum so each update is O(1). Once the window is full the
/// output is the sum of the last `period` finite inputs divided by `period`;
/// before that it is `None`. Non-finite inputs are skipped and never enter the
/// window.
///
/// On long-running streams a single-subtract incremental sum can accumulate
/// rounding error (catastrophic cancellation when values of very different
/// magnitudes are alternately added and removed). To keep drift bounded, the
/// running sum is reseeded from the live window every `16 · period` updates —
/// O(1) amortised cost (`O(period)` work amortised over `O(period)` updates),
/// zero observable behaviour change on inputs that did not drift to begin
/// with, and a strict cap on accumulated rounding for streams that did.
///
/// # Example
///
/// ```
/// use wickra_core::{Indicator, Sma};
///
/// let mut indicator = Sma::new(3).unwrap();
/// let mut last = None;
/// for i in 0..80 {
///     last = indicator.update(100.0 + f64::from(i));
/// }
/// assert!(last.is_some());
/// ```
#[derive(Debug, Clone)]
pub struct Sma {
    /// Window length; `Sma::new` rejects zero, so this is always >= 1.
    period: usize,
    /// Fixed-capacity ring buffer of the last `period` finite inputs, stored
    /// as a flat `Box<[f64]>` with a manual write cursor (`head`) that is
    /// reset to zero with a single compare instead of a modulo.
    ///
    /// Slots that have not been written since construction or `reset` hold
    /// placeholder or stale values; `update` does not read a slot until it
    /// has been overwritten.
    buf: Box<[f64]>,
    /// Index of the next slot to write — also the oldest element once full.
    head: usize,
    /// Number of slots filled, saturating at `period`.
    count: usize,
    /// Running sum of the values currently in the window.
    sum: f64,
    /// Number of finite updates since the running `sum` was last reseeded from
    /// the live window. Caps accumulated floating-point drift on long streams.
    /// See [`RECOMPUTE_EVERY`] below.
    updates_since_recompute: usize,
}
48
/// Reseed-interval multiplier: the incremental sum is rebuilt from the live
/// window once every `RECOMPUTE_EVERY * period` finite updates.
///
/// A reseed walks the whole window (`O(period)`), so scaling the interval with
/// `period` keeps the amortised overhead per update constant for any window
/// length. Between two reseeds the running sum goes through at most
/// `2 · 16 · period` rounding operations; the original estimate is that this
/// bounds drift to roughly `16 · period · ULP · max(|x|)` — sub-picodollar on
/// real-world price scales.
const RECOMPUTE_EVERY: usize = 16;
55
56impl Sma {
57    /// Construct a new SMA with the given window length.
58    ///
59    /// # Errors
60    ///
61    /// Returns [`Error::PeriodZero`] if `period == 0`.
62    pub fn new(period: usize) -> Result<Self> {
63        if period == 0 {
64            return Err(Error::PeriodZero);
65        }
66        Ok(Self {
67            period,
68            buf: vec![0.0; period].into_boxed_slice(),
69            head: 0,
70            count: 0,
71            sum: 0.0,
72            updates_since_recompute: 0,
73        })
74    }
75
76    /// Configured window length.
77    pub const fn period(&self) -> usize {
78        self.period
79    }
80
81    /// Current value if available.
82    pub fn value(&self) -> Option<f64> {
83        if self.count == self.period {
84            Some(self.sum / self.period as f64)
85        } else {
86            None
87        }
88    }
89}
90
91impl Indicator for Sma {
92    type Input = f64;
93    type Output = f64;
94
95    fn update(&mut self, input: f64) -> Option<f64> {
96        if !input.is_finite() {
97            return self.value();
98        }
99        if self.count == self.period {
100            // Window full: overwrite the oldest slot (at `head`). Each step is a
101            // single f64 add/subtract — O(1) but introduces ~1 ULP of rounding
102            // noise. The periodic reseed below caps the accumulated drift.
103            self.sum -= self.buf[self.head];
104            self.buf[self.head] = input;
105            self.sum += input;
106        } else {
107            self.buf[self.head] = input;
108            self.sum += input;
109            self.count += 1;
110        }
111        // Branchless-ish wraparound, cheaper than `% period`.
112        self.head += 1;
113        if self.head == self.period {
114            self.head = 0;
115        }
116        self.updates_since_recompute += 1;
117        if self.updates_since_recompute >= RECOMPUTE_EVERY * self.period {
118            // Reseed in chronological order (oldest at `head`) so the running sum
119            // tracks a fresh from-scratch mean to the bit on stable inputs.
120            self.sum = self.buf[self.head..]
121                .iter()
122                .chain(&self.buf[..self.head])
123                .copied()
124                .sum();
125            self.updates_since_recompute = 0;
126        }
127        self.value()
128    }
129
130    fn reset(&mut self) {
131        self.head = 0;
132        self.count = 0;
133        self.sum = 0.0;
134        self.updates_since_recompute = 0;
135    }
136
137    fn warmup_period(&self) -> usize {
138        self.period
139    }
140
141    fn is_ready(&self) -> bool {
142        self.count == self.period
143    }
144
145    fn name(&self) -> &'static str {
146        "SMA"
147    }
148}
149
#[cfg(test)]
mod tests {
    use super::*;
    use crate::traits::BatchExt;
    use approx::assert_relative_eq;
    use std::collections::VecDeque;

    #[test]
    fn new_rejects_zero_period() {
        assert!(matches!(Sma::new(0), Err(Error::PeriodZero)));
    }

    /// Cover the const accessor `period` and the `Indicator` metadata methods
    /// `warmup_period` and `name`. The other tests inspect SMA output but
    /// never query the metadata.
    #[test]
    fn accessors_and_metadata() {
        let sma = Sma::new(20).unwrap();
        assert_eq!(sma.period(), 20);
        assert_eq!(sma.warmup_period(), 20);
        assert_eq!(sma.name(), "SMA");
    }

    #[test]
    fn warmup_returns_none() {
        let mut sma = Sma::new(3).unwrap();
        assert_eq!(sma.update(1.0), None);
        assert_eq!(sma.update(2.0), None);
        assert_eq!(sma.update(3.0), Some(2.0));
    }

    #[test]
    fn rolls_window_after_full() {
        let mut sma = Sma::new(3).unwrap();
        let out: Vec<_> = [1.0, 2.0, 3.0, 4.0, 5.0]
            .iter()
            .map(|p| sma.update(*p))
            .collect();
        assert_eq!(out, vec![None, None, Some(2.0), Some(3.0), Some(4.0)]);
    }

    #[test]
    fn period_one_is_pass_through() {
        let mut sma = Sma::new(1).unwrap();
        assert_eq!(sma.update(5.0), Some(5.0));
        assert_eq!(sma.update(10.0), Some(10.0));
    }

    #[test]
    fn ignores_non_finite_input_but_keeps_state() {
        let mut sma = Sma::new(3).unwrap();
        sma.update(1.0);
        sma.update(2.0);
        sma.update(3.0);
        assert_eq!(sma.update(f64::NAN), Some(2.0));
        assert_eq!(sma.update(f64::INFINITY), Some(2.0));
        // Non-finite inputs were not pushed; window still holds 1,2,3.
        assert_eq!(sma.update(6.0), Some((2.0 + 3.0 + 6.0) / 3.0));
    }

    #[test]
    fn reset_clears_state() {
        let mut sma = Sma::new(3).unwrap();
        sma.batch(&[1.0, 2.0, 3.0]);
        assert!(sma.is_ready());
        sma.reset();
        assert!(!sma.is_ready());
        assert_eq!(sma.update(10.0), None);
    }

    #[test]
    fn batch_equals_streaming() {
        let prices: Vec<f64> = (1..=20).map(f64::from).collect();
        let mut a = Sma::new(5).unwrap();
        let batch = a.batch(&prices);
        let mut b = Sma::new(5).unwrap();
        let streamed: Vec<_> = prices.iter().map(|p| b.update(*p)).collect();
        assert_eq!(batch, streamed);
    }

    #[test]
    fn known_reference_values() {
        // SMA(3) of [2, 4, 6, 8, 10] -> [_, _, 4, 6, 8]
        let mut sma = Sma::new(3).unwrap();
        let out = sma.batch(&[2.0, 4.0, 6.0, 8.0, 10.0]);
        assert_eq!(out[2], Some(4.0));
        assert_eq!(out[3], Some(6.0));
        assert_eq!(out[4], Some(8.0));
    }

    #[test]
    fn constant_series_yields_constant_sma() {
        let mut sma = Sma::new(5).unwrap();
        let v = sma.batch(&[7.0; 10]);
        for x in v.iter().skip(4) {
            assert_relative_eq!(x.unwrap(), 7.0, epsilon = 1e-12);
        }
    }

    proptest::proptest! {
        #![proptest_config(proptest::test_runner::Config::with_cases(64))]
        #[test]
        fn sma_matches_naive_definition(
            period in 1usize..20,
            prices in proptest::collection::vec(-1000.0_f64..1000.0, 0..200),
        ) {
            let mut sma = Sma::new(period).unwrap();
            let stream: Vec<_> = prices.iter().map(|p| sma.update(*p)).collect();
            for (i, got) in stream.iter().enumerate() {
                if i + 1 < period {
                    proptest::prop_assert!(got.is_none());
                } else {
                    let window = &prices[i + 1 - period..=i];
                    let expected = window.iter().sum::<f64>() / period as f64;
                    let actual = got.expect("ready");
                    proptest::prop_assert!(
                        (actual - expected).abs() < 1e-9,
                        "i={i} actual={actual} expected={expected}"
                    );
                }
            }
        }
    }

    /// Long-running stability check. Runs more updates than `RECOMPUTE_EVERY *
    /// period` so the periodic reseed must fire several times, then asserts
    /// that the reported SMA still equals a fresh from-scratch mean over the
    /// live window to within tight floating-point tolerance. Inputs swing
    /// between two magnitudes (`1e9` and `1.0`) — a pattern designed to
    /// expose catastrophic cancellation in a naive single-subtract sum.
    #[test]
    fn long_stream_drift_stays_bounded() {
        let period = 20;
        let mut sma = Sma::new(period).unwrap();
        let mut window: VecDeque<f64> = VecDeque::with_capacity(period);
        // Derive the stream length from the real constant so the test keeps
        // crossing the reseed boundary (5 times here) if the interval is
        // ever retuned.
        let n_updates = RECOMPUTE_EVERY * period * 5;
        for i in 0..n_updates {
            let v = if i % 2 == 0 { 1e9 } else { 1.0 };
            sma.update(v);
            if window.len() == period {
                window.pop_front();
            }
            window.push_back(v);
        }
        let from_scratch: f64 = window.iter().sum::<f64>() / period as f64;
        let got = sma.value().expect("warmed up");
        assert!(
            (got - from_scratch).abs() < 1e-6,
            "SMA drift exceeds 1e-6 over {n_updates} updates: got={got}, scratch={from_scratch}"
        );
    }
}