Skip to main content

wickra_core/indicators/
ema.rs

1//! Exponential Moving Average.
2
3use crate::error::{Error, Result};
4use crate::traits::Indicator;
5
/// Exponential Moving Average with smoothing factor `alpha = 2 / (period + 1)`.
///
/// The first value is seeded with the simple mean of the first `period` inputs
/// (the classical TA-Lib convention). From then on each new input contributes
/// `alpha * input + (1 - alpha) * previous`. Non-finite inputs are skipped by
/// `update` and do not count toward the warmup. [`Ema::with_alpha`] instead
/// takes `alpha` directly and seeds from the very first input.
///
/// # Example
///
/// ```
/// use wickra_core::{Indicator, Ema};
///
/// let mut indicator = Ema::new(3).unwrap();
/// let mut last = None;
/// for i in 0..80 {
///     last = indicator.update(100.0 + f64::from(i));
/// }
/// assert!(last.is_some());
/// ```
#[derive(Debug, Clone)]
pub struct Ema {
    /// Number of inputs averaged into the seed; also the warmup length.
    period: usize,
    /// Smoothing factor applied to each new input.
    alpha: f64,
    /// `1 - alpha`, precomputed once so the per-tick recurrence skips the
    /// subtraction. Both the streaming and the batch recurrence read this same
    /// cached value, which keeps their outputs bit-for-bit identical.
    one_minus_alpha: f64,
    /// Latest EMA value, valid only once `seeded` is true. Stored as a bare `f64`
    /// (plus the `seeded` flag) rather than `Option<f64>` so the steady-state
    /// recurrence reads and writes 8 bytes with no enum-tag handling per tick.
    current: f64,
    /// Whether `current` holds a real value yet (warmup complete).
    seeded: bool,
    /// Inputs collected before seeding. When it holds `period` entries, their
    /// mean becomes the seed. It is not cleared once seeded; only `reset` empties it.
    warmup_buf: Vec<f64>,
}
39
40impl Ema {
41    /// Construct an EMA with the given period.
42    ///
43    /// # Errors
44    ///
45    /// Returns [`Error::PeriodZero`] if `period == 0`.
46    pub fn new(period: usize) -> Result<Self> {
47        if period == 0 {
48            return Err(Error::PeriodZero);
49        }
50        let alpha = 2.0 / (period as f64 + 1.0);
51        Ok(Self {
52            period,
53            alpha,
54            one_minus_alpha: 1.0 - alpha,
55            current: 0.0,
56            seeded: false,
57            warmup_buf: Vec::with_capacity(period),
58        })
59    }
60
61    /// Construct an EMA with a custom smoothing factor `alpha in (0, 1]`.
62    ///
63    /// The reported `period` is derived from `alpha` via `2/alpha - 1` and rounded;
64    /// `warmup_period()` falls back to `1` because the implementation seeds from the
65    /// very first input.
66    ///
67    /// # Errors
68    ///
69    /// Returns [`Error::InvalidPeriod`] if `alpha` is not in `(0.0, 1.0]` or non-finite.
70    pub fn with_alpha(alpha: f64) -> Result<Self> {
71        if !alpha.is_finite() || alpha <= 0.0 || alpha > 1.0 {
72            return Err(Error::InvalidPeriod {
73                message: "alpha must be in (0.0, 1.0]",
74            });
75        }
76        Ok(Self {
77            period: 1,
78            alpha,
79            one_minus_alpha: 1.0 - alpha,
80            current: 0.0,
81            seeded: false,
82            warmup_buf: Vec::with_capacity(1),
83        })
84    }
85
86    /// Configured period.
87    pub const fn period(&self) -> usize {
88        self.period
89    }
90
91    /// Smoothing factor.
92    pub const fn alpha(&self) -> f64 {
93        self.alpha
94    }
95
96    /// Current value if available.
97    pub const fn value(&self) -> Option<f64> {
98        if self.seeded {
99            Some(self.current)
100        } else {
101            None
102        }
103    }
104
105    /// Whether the EMA has seen no input yet (neither seeded nor mid-warmup).
106    /// Lets composite indicators (e.g. MACD) decide if a fast batch path is safe.
107    pub(crate) fn is_fresh(&self) -> bool {
108        !self.seeded && self.warmup_buf.is_empty()
109    }
110
111    /// Force the EMA into its seeded steady state with `current` as the latest
112    /// value. Used by composite fused batch paths (MACD) to leave each sub-EMA
113    /// where a per-tick `update` replay would, so a later `update` continues
114    /// correctly. The post-seed recurrence never re-reads `warmup_buf`, so it is
115    /// left as-is.
116    pub(crate) fn seed_to(&mut self, current: f64) {
117        self.current = current;
118        self.seeded = true;
119    }
120
121    /// Vectorized batch returning one `f64` per input (`NaN` during warmup).
122    ///
123    /// Shadows the generic [`BatchNanExt::batch_nan`](crate::BatchNanExt) blanket
124    /// default via inherent-method resolution. For a fresh indicator over an
125    /// all-finite slice it runs the seed (mean of the first `period`) once and
126    /// then the bare `alpha * x + (1 - alpha) * prev` recurrence in a tight loop
127    /// with no per-element `is_finite`/`seeded` branch and no `Option` — yet uses
128    /// the identical `mul_add`, so the result is *bit-for-bit* equal to replaying
129    /// `update`. Any other state, or a non-finite element, defers to the exact
130    /// `update` replay.
131    pub fn batch_nan(&mut self, inputs: &[f64]) -> Vec<f64> {
132        let p = self.period;
133        if self.seeded || !self.warmup_buf.is_empty() || !inputs.iter().all(|x| x.is_finite()) {
134            return inputs
135                .iter()
136                .map(|&x| self.update(x).unwrap_or(f64::NAN))
137                .collect();
138        }
139
140        let n = inputs.len();
141        if n < p {
142            // Not enough to seed; mirror `update` stashing inputs for warmup.
143            self.warmup_buf.extend_from_slice(inputs);
144            return vec![f64::NAN; n];
145        }
146
147        // Warmup `[0, p-1)` is `NaN`; values from the seed on are pushed once each.
148        let mut out = vec![f64::NAN; p - 1];
149        out.reserve(n - (p - 1));
150        let seed = inputs[..p].iter().copied().sum::<f64>() / p as f64;
151        let mut cur = seed;
152        out.push(seed);
153        let (alpha, oma) = (self.alpha, self.one_minus_alpha);
154        for &x in &inputs[p..] {
155            cur = alpha.mul_add(x, oma * cur);
156            out.push(cur);
157        }
158
159        // Leave state exactly where `update` would: seeded on `current`, with the
160        // first `period` inputs retained in `warmup_buf` (never cleared post-seed).
161        self.current = cur;
162        self.seeded = true;
163        self.warmup_buf.extend_from_slice(&inputs[..p]);
164        out
165    }
166
167    /// Internal helper that feeds a value without finiteness validation. The caller
168    /// guarantees `input.is_finite()`. Used by MACD which has already validated.
169    pub(crate) fn step_unchecked(&mut self, input: f64) -> Option<f64> {
170        if self.seeded {
171            let new = self
172                .alpha
173                .mul_add(input, self.one_minus_alpha * self.current);
174            self.current = new;
175            return Some(new);
176        }
177        self.warmup_buf.push(input);
178        if self.warmup_buf.len() == self.period {
179            let seed = self.warmup_buf.iter().copied().sum::<f64>() / self.period as f64;
180            self.current = seed;
181            self.seeded = true;
182            return Some(seed);
183        }
184        None
185    }
186}
187
188impl Indicator for Ema {
189    type Input = f64;
190    type Output = f64;
191
192    fn update(&mut self, input: f64) -> Option<f64> {
193        if !input.is_finite() {
194            return self.value();
195        }
196        self.step_unchecked(input)
197    }
198
199    fn reset(&mut self) {
200        self.current = 0.0;
201        self.seeded = false;
202        self.warmup_buf.clear();
203    }
204
205    fn warmup_period(&self) -> usize {
206        self.period
207    }
208
209    fn is_ready(&self) -> bool {
210        self.seeded
211    }
212
213    fn name(&self) -> &'static str {
214        "EMA"
215    }
216}
217
#[cfg(test)]
mod tests {
    use super::*;
    use crate::traits::BatchExt;
    use approx::assert_relative_eq;

    /// Independent reference: SMA-seeded EMA computed straight from the definition.
    fn ema_naive(prices: &[f64], period: usize) -> Vec<Option<f64>> {
        let alpha = 2.0 / (period as f64 + 1.0);
        let mut out = Vec::with_capacity(prices.len());
        let mut state: Option<f64> = None;
        for (i, &p) in prices.iter().enumerate() {
            if let Some(prev) = state {
                let v = alpha * p + (1.0 - alpha) * prev;
                state = Some(v);
                out.push(Some(v));
            } else if i + 1 == period {
                let seed = prices[..period].iter().sum::<f64>() / period as f64;
                state = Some(seed);
                out.push(Some(seed));
            } else {
                out.push(None);
            }
        }
        out
    }

    #[test]
    fn new_rejects_zero_period() {
        assert!(matches!(Ema::new(0), Err(Error::PeriodZero)));
    }

    /// Covers the metadata methods no other test reaches: the inherent
    /// `period` accessor plus the `Indicator` impl's `warmup_period` and
    /// `name`. `alpha` and `value` are exercised by the other tests.
    #[test]
    fn accessors_and_metadata() {
        let ema = Ema::new(14).unwrap();
        assert_eq!(ema.period(), 14);
        assert_eq!(ema.warmup_period(), 14);
        assert_eq!(ema.name(), "EMA");
    }

    #[test]
    fn warmup_returns_none_until_seed() {
        let mut ema = Ema::new(3).unwrap();
        assert_eq!(ema.update(1.0), None);
        assert_eq!(ema.update(2.0), None);
        assert_eq!(ema.update(3.0), Some(2.0)); // seed = SMA([1,2,3]) = 2
    }

    #[test]
    fn first_value_equals_sma_seed() {
        let mut ema = Ema::new(5).unwrap();
        let inputs = [10.0, 20.0, 30.0, 40.0, 50.0];
        let mut last = None;
        for v in inputs {
            last = ema.update(v);
        }
        assert_relative_eq!(last.unwrap(), 30.0, epsilon = 1e-12);
    }

    #[test]
    fn alpha_matches_period_formula() {
        let ema = Ema::new(10).unwrap();
        assert_relative_eq!(ema.alpha(), 2.0 / 11.0, epsilon = 1e-15);
    }

    #[test]
    fn with_alpha_seeds_from_first_input() {
        let mut ema = Ema::with_alpha(0.25).unwrap();
        // A custom-alpha EMA reports period/warmup 1 (not derived from alpha).
        assert_eq!(ema.period(), 1);
        assert_eq!(ema.warmup_period(), 1);
        assert_relative_eq!(ema.alpha(), 0.25, epsilon = 1e-15);
        assert_eq!(ema.update(8.0), Some(8.0));
        // 0.25 * 4 + 0.75 * 8 = 7
        assert_relative_eq!(ema.update(4.0).unwrap(), 7.0, epsilon = 1e-12);
    }

    #[test]
    fn step_after_seed_uses_alpha_formula() {
        // period=3 => alpha = 0.5; seed = mean([1,2,3]) = 2; next input 10
        // expected = 0.5*10 + 0.5*2 = 6
        let mut ema = Ema::new(3).unwrap();
        ema.batch(&[1.0, 2.0, 3.0]);
        assert_relative_eq!(ema.update(10.0).unwrap(), 6.0, epsilon = 1e-12);
    }

    #[test]
    fn constant_series_converges_to_constant() {
        let mut ema = Ema::new(10).unwrap();
        let out = ema.batch(&[42.0_f64; 100]);
        for x in out.iter().skip(9) {
            assert_relative_eq!(x.unwrap(), 42.0, epsilon = 1e-9);
        }
    }

    #[test]
    fn with_alpha_validates_range() {
        assert!(Ema::with_alpha(0.5).is_ok());
        assert!(Ema::with_alpha(1.0).is_ok());
        assert!(matches!(
            Ema::with_alpha(0.0),
            Err(Error::InvalidPeriod { .. })
        ));
        assert!(matches!(
            Ema::with_alpha(-0.1),
            Err(Error::InvalidPeriod { .. })
        ));
        assert!(matches!(
            Ema::with_alpha(1.5),
            Err(Error::InvalidPeriod { .. })
        ));
        assert!(matches!(
            Ema::with_alpha(f64::NAN),
            Err(Error::InvalidPeriod { .. })
        ));
        assert!(matches!(
            Ema::with_alpha(f64::INFINITY),
            Err(Error::InvalidPeriod { .. })
        ));
    }

    #[test]
    fn reset_clears_state() {
        let mut ema = Ema::new(3).unwrap();
        ema.batch(&[1.0, 2.0, 3.0]);
        assert!(ema.is_ready());
        ema.reset();
        assert!(!ema.is_ready());
        assert_eq!(ema.update(1.0), None);
    }

    #[test]
    fn batch_equals_streaming() {
        let prices: Vec<f64> = (1..=30).map(f64::from).collect();
        let mut a = Ema::new(5).unwrap();
        let mut b = Ema::new(5).unwrap();
        assert_eq!(
            a.batch(&prices),
            prices.iter().map(|p| b.update(*p)).collect::<Vec<_>>()
        );
    }

    #[test]
    fn ignores_non_finite_input() {
        let mut ema = Ema::new(3).unwrap();
        ema.batch(&[1.0, 2.0, 3.0]);
        let before = ema.value();
        assert_eq!(ema.update(f64::NAN), before);
        assert_eq!(ema.update(f64::INFINITY), before);
    }

    #[test]
    fn non_finite_input_does_not_count_toward_warmup() {
        let mut ema = Ema::new(3).unwrap();
        assert_eq!(ema.update(1.0), None);
        assert_eq!(ema.update(f64::NAN), None);
        assert_eq!(ema.update(2.0), None);
        assert_eq!(ema.update(3.0), Some(2.0));
    }

    #[test]
    fn seed_to_continues_recurrence() {
        let mut ema = Ema::new(3).unwrap();
        assert!(ema.is_fresh());
        ema.seed_to(2.0);
        assert!(!ema.is_fresh());
        assert!(ema.is_ready());
        assert_eq!(ema.value(), Some(2.0));
        // alpha = 0.5: 0.5 * 10 + 0.5 * 2 = 6
        assert_relative_eq!(ema.update(10.0).unwrap(), 6.0, epsilon = 1e-12);
    }

    /// `true` when both slices have the same length and every pair is bitwise
    /// identical (so `0.0` and `-0.0` differ), treating any two NaNs as equal.
    fn bits_eq(a: &[f64], b: &[f64]) -> bool {
        a.len() == b.len()
            && a.iter()
                .zip(b)
                .all(|(x, y)| x.to_bits() == y.to_bits() || (x.is_nan() && y.is_nan()))
    }

    /// Per-tick `update` replay on a fresh EMA, with `NaN` for warmup outputs.
    fn ema_replay(period: usize, series: &[f64]) -> Vec<f64> {
        let mut e = Ema::new(period).unwrap();
        series
            .iter()
            .map(|&x| e.update(x).unwrap_or(f64::NAN))
            .collect()
    }

    #[test]
    fn batch_nan_fast_path_is_bit_identical() {
        let series: Vec<f64> = (0..300)
            .map(|i| (f64::from(i) * 0.25).cos() * 8.0 + 40.0)
            .collect();
        let mut ema = Ema::new(14).unwrap();
        let got = ema.batch_nan(&series);
        assert!(bits_eq(&got, &ema_replay(14, &series)));
        let mut ref_ema = Ema::new(14).unwrap();
        for &x in &series {
            ref_ema.update(x);
        }
        assert_eq!(ema.update(7.5), ref_ema.update(7.5));
    }

    #[test]
    fn batch_nan_falls_back_on_non_finite() {
        let series = [1.0, 2.0, 3.0, f64::INFINITY, 5.0, 6.0, 7.0];
        let mut ema = Ema::new(3).unwrap();
        assert!(bits_eq(&ema.batch_nan(&series), &ema_replay(3, &series)));
    }

    #[test]
    fn batch_nan_falls_back_when_warming() {
        let mut ema = Ema::new(3).unwrap();
        ema.update(10.0); // mid-warmup: warmup_buf non-empty, not seeded
        let series = [1.0, 2.0, 3.0, 4.0];
        let mut ref_ema = Ema::new(3).unwrap();
        ref_ema.update(10.0);
        let want: Vec<f64> = series
            .iter()
            .map(|&x| ref_ema.update(x).unwrap_or(f64::NAN))
            .collect();
        assert!(bits_eq(&ema.batch_nan(&series), &want));
    }

    #[test]
    fn batch_nan_sub_period_slice_stays_unseeded() {
        let series = [1.0, 2.0];
        let mut ema = Ema::new(5).unwrap();
        let got = ema.batch_nan(&series);
        assert!(got.iter().all(|x| x.is_nan()) && got.len() == 2);
        assert!(!ema.is_ready());
        // Warmup state was stashed: feeding the rest seeds exactly as a full stream.
        assert!(bits_eq(
            &[ema.update(3.0).unwrap_or(f64::NAN)],
            &[ema_replay(5, &[1.0, 2.0, 3.0])[2]]
        ));
    }

    proptest::proptest! {
        #![proptest_config(proptest::test_runner::Config::with_cases(48))]
        #[test]
        fn ema_matches_naive(
            period in 1usize..20,
            prices in proptest::collection::vec(-1000.0_f64..1000.0, 0..150),
        ) {
            let mut ema = Ema::new(period).unwrap();
            let got = ema.batch(&prices);
            let want = ema_naive(&prices, period);
            proptest::prop_assert_eq!(got.len(), want.len());
            for (g, w) in got.iter().zip(want.iter()) {
                match (g, w) {
                    (None, None) => {}
                    (Some(a), Some(b)) => proptest::prop_assert!(
                        (a - b).abs() <= 1e-9 * a.abs().max(1.0),
                        "got={a} want={b}"
                    ),
                    _ => proptest::prop_assert!(false, "warmup mismatch"),
                }
            }
        }
    }
}