Skip to main content

wickra_core/indicators/
ema.rs

1//! Exponential Moving Average.
2
3use crate::error::{Error, Result};
4use crate::traits::Indicator;
5
/// Streaming Exponential Moving Average (EMA) over `f64` samples.
///
/// When built with [`Ema::new`] the smoothing factor is
/// `alpha = 2 / (period + 1)`. Nothing is emitted until `period` inputs have
/// been consumed; the first output is the simple mean of those inputs (the
/// classical TA-Lib seeding convention). Every subsequent input is blended in
/// as `alpha * input + (1 - alpha) * previous`.
///
/// # Example
///
/// ```
/// use wickra_core::{Indicator, Ema};
///
/// let mut indicator = Ema::new(3).unwrap();
/// let mut last = None;
/// for i in 0..80 {
///     last = indicator.update(100.0 + f64::from(i));
/// }
/// assert!(last.is_some());
/// ```
#[derive(Clone, Debug)]
pub struct Ema {
    /// Number of inputs that are averaged to produce the seed value; also the
    /// number of inputs consumed before the first output.
    period: usize,
    /// Weight applied to each new input in the recurrence.
    alpha: f64,
    /// Cached `1 - alpha`, computed once at construction so the per-tick
    /// recurrence does not repeat the subtraction. It is the same value the
    /// inline expression would yield, so outputs are bit-for-bit identical.
    one_minus_alpha: f64,
    /// Most recent EMA value. Only meaningful while `seeded` is `true`. Kept as
    /// a plain `f64` next to the `seeded` flag (instead of an `Option<f64>`) so
    /// the steady-state path reads and writes 8 bytes without touching an enum
    /// tag on every tick.
    current: f64,
    /// `true` once warm-up has finished and `current` holds a real value.
    seeded: bool,
    /// Inputs collected during warm-up; their mean becomes the seed value.
    warmup_buf: Vec<f64>,
}
39
40impl Ema {
41    /// Construct an EMA with the given period.
42    ///
43    /// # Errors
44    ///
45    /// Returns [`Error::PeriodZero`] if `period == 0`.
46    pub fn new(period: usize) -> Result<Self> {
47        if period == 0 {
48            return Err(Error::PeriodZero);
49        }
50        let alpha = 2.0 / (period as f64 + 1.0);
51        Ok(Self {
52            period,
53            alpha,
54            one_minus_alpha: 1.0 - alpha,
55            current: 0.0,
56            seeded: false,
57            warmup_buf: Vec::with_capacity(period),
58        })
59    }
60
61    /// Construct an EMA with a custom smoothing factor `alpha in (0, 1]`.
62    ///
63    /// The reported `period` is derived from `alpha` via `2/alpha - 1` and rounded;
64    /// `warmup_period()` falls back to `1` because the implementation seeds from the
65    /// very first input.
66    ///
67    /// # Errors
68    ///
69    /// Returns [`Error::InvalidPeriod`] if `alpha` is not in `(0.0, 1.0]` or non-finite.
70    pub fn with_alpha(alpha: f64) -> Result<Self> {
71        if !alpha.is_finite() || alpha <= 0.0 || alpha > 1.0 {
72            return Err(Error::InvalidPeriod {
73                message: "alpha must be in (0.0, 1.0]",
74            });
75        }
76        Ok(Self {
77            period: 1,
78            alpha,
79            one_minus_alpha: 1.0 - alpha,
80            current: 0.0,
81            seeded: false,
82            warmup_buf: Vec::with_capacity(1),
83        })
84    }
85
86    /// Configured period.
87    pub const fn period(&self) -> usize {
88        self.period
89    }
90
91    /// Smoothing factor.
92    pub const fn alpha(&self) -> f64 {
93        self.alpha
94    }
95
96    /// Current value if available.
97    pub const fn value(&self) -> Option<f64> {
98        if self.seeded {
99            Some(self.current)
100        } else {
101            None
102        }
103    }
104
105    /// Internal helper that feeds a value without finiteness validation. The caller
106    /// guarantees `input.is_finite()`. Used by MACD which has already validated.
107    pub(crate) fn step_unchecked(&mut self, input: f64) -> Option<f64> {
108        if self.seeded {
109            let new = self
110                .alpha
111                .mul_add(input, self.one_minus_alpha * self.current);
112            self.current = new;
113            return Some(new);
114        }
115        self.warmup_buf.push(input);
116        if self.warmup_buf.len() == self.period {
117            let seed = self.warmup_buf.iter().copied().sum::<f64>() / self.period as f64;
118            self.current = seed;
119            self.seeded = true;
120            return Some(seed);
121        }
122        None
123    }
124}
125
126impl Indicator for Ema {
127    type Input = f64;
128    type Output = f64;
129
130    fn update(&mut self, input: f64) -> Option<f64> {
131        if !input.is_finite() {
132            return self.value();
133        }
134        self.step_unchecked(input)
135    }
136
137    fn reset(&mut self) {
138        self.current = 0.0;
139        self.seeded = false;
140        self.warmup_buf.clear();
141    }
142
143    fn warmup_period(&self) -> usize {
144        self.period
145    }
146
147    fn is_ready(&self) -> bool {
148        self.seeded
149    }
150
151    fn name(&self) -> &'static str {
152        "EMA"
153    }
154}
155
#[cfg(test)]
mod tests {
    use super::*;
    use crate::traits::BatchExt;
    use approx::assert_relative_eq;

    /// Independent reference: SMA-seeded EMA computed straight from the definition.
    ///
    /// Yields `None` for the first `period - 1` prices, the simple mean of the
    /// first `period` prices at index `period - 1`, and the plain (unfused)
    /// recurrence `alpha * p + (1 - alpha) * prev` afterwards.
    fn ema_naive(prices: &[f64], period: usize) -> Vec<Option<f64>> {
        let alpha = 2.0 / (period as f64 + 1.0);
        let mut out = Vec::with_capacity(prices.len());
        let mut state: Option<f64> = None;
        for (i, &p) in prices.iter().enumerate() {
            if let Some(prev) = state {
                let v = alpha * p + (1.0 - alpha) * prev;
                state = Some(v);
                out.push(Some(v));
            } else if i + 1 == period {
                let seed = prices[..period].iter().sum::<f64>() / period as f64;
                state = Some(seed);
                out.push(Some(seed));
            } else {
                out.push(None);
            }
        }
        out
    }

    #[test]
    fn new_rejects_zero_period() {
        assert!(matches!(Ema::new(0), Err(Error::PeriodZero)));
    }

    /// Cover the const accessor `period` and the `Indicator` metadata methods
    /// `warmup_period` and `name`. `alpha` and `value` are exercised by other
    /// tests and downstream consumers; only these three metadata methods were
    /// otherwise uncovered.
    #[test]
    fn accessors_and_metadata() {
        let ema = Ema::new(14).unwrap();
        assert_eq!(ema.period(), 14);
        assert_eq!(ema.warmup_period(), 14);
        assert_eq!(ema.name(), "EMA");
    }

    #[test]
    fn warmup_returns_none_until_seed() {
        let mut ema = Ema::new(3).unwrap();
        assert_eq!(ema.update(1.0), None);
        assert_eq!(ema.update(2.0), None);
        assert_eq!(ema.update(3.0), Some(2.0)); // seed = SMA([1,2,3]) = 2
    }

    #[test]
    fn first_value_equals_sma_seed() {
        let mut ema = Ema::new(5).unwrap();
        let inputs = [10.0, 20.0, 30.0, 40.0, 50.0];
        let mut last = None;
        for v in inputs {
            last = ema.update(v);
        }
        assert_relative_eq!(last.unwrap(), 30.0, epsilon = 1e-12);
    }

    #[test]
    fn alpha_matches_period_formula() {
        let ema = Ema::new(10).unwrap();
        assert_relative_eq!(ema.alpha(), 2.0 / 11.0, epsilon = 1e-15);
    }

    #[test]
    fn step_after_seed_uses_alpha_formula() {
        // period=3 => alpha = 0.5; seed = mean([1,2,3]) = 2; next input 10
        // expected = 0.5*10 + 0.5*2 = 6
        let mut ema = Ema::new(3).unwrap();
        ema.batch(&[1.0, 2.0, 3.0]);
        assert_relative_eq!(ema.update(10.0).unwrap(), 6.0, epsilon = 1e-12);
    }

    #[test]
    fn constant_series_converges_to_constant() {
        let mut ema = Ema::new(10).unwrap();
        let out = ema.batch(&[42.0_f64; 100]);
        for x in out.iter().skip(9) {
            assert_relative_eq!(x.unwrap(), 42.0, epsilon = 1e-9);
        }
    }

    #[test]
    fn with_alpha_validates_range() {
        assert!(Ema::with_alpha(0.5).is_ok());
        assert!(Ema::with_alpha(1.0).is_ok());
        assert!(matches!(
            Ema::with_alpha(0.0),
            Err(Error::InvalidPeriod { .. })
        ));
        assert!(matches!(
            Ema::with_alpha(-0.1),
            Err(Error::InvalidPeriod { .. })
        ));
        assert!(matches!(
            Ema::with_alpha(1.5),
            Err(Error::InvalidPeriod { .. })
        ));
        assert!(matches!(
            Ema::with_alpha(f64::NAN),
            Err(Error::InvalidPeriod { .. })
        ));
        assert!(matches!(
            Ema::with_alpha(f64::INFINITY),
            Err(Error::InvalidPeriod { .. })
        ));
    }

    #[test]
    fn with_alpha_seeds_from_first_input() {
        // alpha = 0.25: seed = first input (8); next = 0.25*16 + 0.75*8 = 10
        let mut ema = Ema::with_alpha(0.25).unwrap();
        assert_eq!(ema.period(), 1);
        assert_eq!(ema.warmup_period(), 1);
        assert_relative_eq!(ema.alpha(), 0.25, epsilon = 1e-15);
        assert_eq!(ema.update(8.0), Some(8.0));
        assert_relative_eq!(ema.update(16.0).unwrap(), 10.0, epsilon = 1e-12);
    }

    #[test]
    fn reset_clears_state() {
        let mut ema = Ema::new(3).unwrap();
        ema.batch(&[1.0, 2.0, 3.0]);
        assert!(ema.is_ready());
        ema.reset();
        assert!(!ema.is_ready());
        assert_eq!(ema.value(), None);
        assert_eq!(ema.update(1.0), None);
    }

    #[test]
    fn batch_equals_streaming() {
        let prices: Vec<f64> = (1..=30).map(f64::from).collect();
        let mut a = Ema::new(5).unwrap();
        let mut b = Ema::new(5).unwrap();
        assert_eq!(
            a.batch(&prices),
            prices.iter().map(|p| b.update(*p)).collect::<Vec<_>>()
        );
    }

    #[test]
    fn ignores_non_finite_input() {
        let mut ema = Ema::new(3).unwrap();
        ema.batch(&[1.0, 2.0, 3.0]);
        let before = ema.value();
        assert_eq!(ema.update(f64::NAN), before);
        assert_eq!(ema.update(f64::INFINITY), before);
    }

    #[test]
    fn non_finite_input_during_warmup_is_skipped() {
        // Non-finite samples must not count towards the warm-up window.
        let mut ema = Ema::new(2).unwrap();
        assert_eq!(ema.update(1.0), None);
        assert_eq!(ema.update(f64::NAN), None);
        assert_eq!(ema.update(f64::NEG_INFINITY), None);
        assert!(!ema.is_ready());
        assert_eq!(ema.update(3.0), Some(2.0)); // seed = SMA([1,3]) = 2
    }

    proptest::proptest! {
        #![proptest_config(proptest::test_runner::Config::with_cases(48))]
        #[test]
        fn ema_matches_naive(
            period in 1usize..20,
            prices in proptest::collection::vec(-1000.0_f64..1000.0, 0..150),
        ) {
            let mut ema = Ema::new(period).unwrap();
            let got = ema.batch(&prices);
            let want = ema_naive(&prices, period);
            proptest::prop_assert_eq!(got.len(), want.len());
            for (g, w) in got.iter().zip(want.iter()) {
                match (g, w) {
                    (None, None) => {}
                    (Some(a), Some(b)) => proptest::prop_assert!(
                        (a - b).abs() <= 1e-9 * a.abs().max(1.0),
                        "got={a} want={b}"
                    ),
                    _ => proptest::prop_assert!(false, "warmup mismatch"),
                }
            }
        }
    }
}