Skip to main content

wickra_core/indicators/
ema.rs

1//! Exponential Moving Average.
2
3use crate::error::{Error, Result};
4use crate::traits::Indicator;
5
/// Exponential Moving Average (EMA).
///
/// After seeding, every new sample is blended into the running value as
/// `alpha * input + (1 - alpha) * previous`. For an EMA built with
/// [`Ema::new`], `alpha = 2 / (period + 1)`.
///
/// Seeding follows the classical TA-Lib convention: no value is produced until
/// `period` samples have been collected, and the first output is their simple
/// mean. An EMA built with [`Ema::with_alpha`] has a period of `1`, so it seeds
/// directly from its first sample.
///
/// # Example
///
/// ```
/// use wickra_core::{Indicator, Ema};
///
/// let mut indicator = Ema::new(3).unwrap();
/// let mut last = None;
/// for i in 0..80 {
///     last = indicator.update(100.0 + f64::from(i));
/// }
/// assert!(last.is_some());
/// ```
#[derive(Debug, Clone)]
pub struct Ema {
    /// Number of samples averaged to form the seed; also the reported period.
    period: usize,
    /// Smoothing factor applied to every sample after the seed.
    alpha: f64,
    /// Most recent EMA value, or `None` while still warming up.
    state: Option<f64>,
    /// Samples gathered during warm-up; the seed is their mean.
    warmup_buf: Vec<f64>,
}
31
32impl Ema {
33    /// Construct an EMA with the given period.
34    ///
35    /// # Errors
36    ///
37    /// Returns [`Error::PeriodZero`] if `period == 0`.
38    pub fn new(period: usize) -> Result<Self> {
39        if period == 0 {
40            return Err(Error::PeriodZero);
41        }
42        let alpha = 2.0 / (period as f64 + 1.0);
43        Ok(Self {
44            period,
45            alpha,
46            state: None,
47            warmup_buf: Vec::with_capacity(period),
48        })
49    }
50
51    /// Construct an EMA with a custom smoothing factor `alpha in (0, 1]`.
52    ///
53    /// The reported `period` is derived from `alpha` via `2/alpha - 1` and rounded;
54    /// `warmup_period()` falls back to `1` because the implementation seeds from the
55    /// very first input.
56    ///
57    /// # Errors
58    ///
59    /// Returns [`Error::InvalidPeriod`] if `alpha` is not in `(0.0, 1.0]` or non-finite.
60    pub fn with_alpha(alpha: f64) -> Result<Self> {
61        if !alpha.is_finite() || alpha <= 0.0 || alpha > 1.0 {
62            return Err(Error::InvalidPeriod {
63                message: "alpha must be in (0.0, 1.0]",
64            });
65        }
66        Ok(Self {
67            period: 1,
68            alpha,
69            state: None,
70            warmup_buf: Vec::with_capacity(1),
71        })
72    }
73
74    /// Configured period.
75    pub const fn period(&self) -> usize {
76        self.period
77    }
78
79    /// Smoothing factor.
80    pub const fn alpha(&self) -> f64 {
81        self.alpha
82    }
83
84    /// Current value if available.
85    pub const fn value(&self) -> Option<f64> {
86        self.state
87    }
88
89    /// Internal helper that feeds a value without finiteness validation. The caller
90    /// guarantees `input.is_finite()`. Used by MACD which has already validated.
91    pub(crate) fn step_unchecked(&mut self, input: f64) -> Option<f64> {
92        if let Some(prev) = self.state {
93            let new = self.alpha.mul_add(input, (1.0 - self.alpha) * prev);
94            self.state = Some(new);
95            return Some(new);
96        }
97        self.warmup_buf.push(input);
98        if self.warmup_buf.len() == self.period {
99            let seed = self.warmup_buf.iter().copied().sum::<f64>() / self.period as f64;
100            self.state = Some(seed);
101            return Some(seed);
102        }
103        None
104    }
105}
106
107impl Indicator for Ema {
108    type Input = f64;
109    type Output = f64;
110
111    fn update(&mut self, input: f64) -> Option<f64> {
112        if !input.is_finite() {
113            return self.state;
114        }
115        self.step_unchecked(input)
116    }
117
118    fn reset(&mut self) {
119        self.state = None;
120        self.warmup_buf.clear();
121    }
122
123    fn warmup_period(&self) -> usize {
124        self.period
125    }
126
127    fn is_ready(&self) -> bool {
128        self.state.is_some()
129    }
130
131    fn name(&self) -> &'static str {
132        "EMA"
133    }
134}
135
#[cfg(test)]
mod tests {
    use super::*;
    use crate::traits::BatchExt;
    use approx::assert_relative_eq;

    /// Independent reference: SMA-seeded EMA computed straight from the definition.
    fn ema_naive(prices: &[f64], period: usize) -> Vec<Option<f64>> {
        let alpha = 2.0 / (period as f64 + 1.0);
        let mut out = Vec::with_capacity(prices.len());
        let mut state: Option<f64> = None;
        for (i, &p) in prices.iter().enumerate() {
            if let Some(prev) = state {
                let v = alpha * p + (1.0 - alpha) * prev;
                state = Some(v);
                out.push(Some(v));
            } else if i + 1 == period {
                let seed = prices[..period].iter().sum::<f64>() / period as f64;
                state = Some(seed);
                out.push(Some(seed));
            } else {
                out.push(None);
            }
        }
        out
    }

    #[test]
    fn new_rejects_zero_period() {
        assert!(matches!(Ema::new(0), Err(Error::PeriodZero)));
    }

    /// Pin the simple metadata accessors: the inherent `period()` and the
    /// `Indicator` trait's `warmup_period()` and `name()`.
    #[test]
    fn accessors_and_metadata() {
        let ema = Ema::new(14).unwrap();
        assert_eq!(ema.period(), 14);
        assert_eq!(ema.warmup_period(), 14);
        assert_eq!(ema.name(), "EMA");
    }

    #[test]
    fn warmup_returns_none_until_seed() {
        let mut ema = Ema::new(3).unwrap();
        assert_eq!(ema.update(1.0), None);
        assert_eq!(ema.update(2.0), None);
        assert_eq!(ema.update(3.0), Some(2.0)); // seed = SMA([1,2,3]) = 2
    }

    /// Non-finite samples must not count toward the warm-up window.
    #[test]
    fn non_finite_during_warmup_is_not_counted() {
        let mut ema = Ema::new(3).unwrap();
        assert_eq!(ema.update(1.0), None);
        assert_eq!(ema.update(f64::NAN), None);
        assert_eq!(ema.update(2.0), None);
        assert_eq!(ema.update(f64::NEG_INFINITY), None);
        assert_eq!(ema.update(3.0), Some(2.0));
    }

    #[test]
    fn first_value_equals_sma_seed() {
        let mut ema = Ema::new(5).unwrap();
        let inputs = [10.0, 20.0, 30.0, 40.0, 50.0];
        let mut last = None;
        for v in inputs {
            last = ema.update(v);
        }
        assert_relative_eq!(last.unwrap(), 30.0, epsilon = 1e-12);
    }

    #[test]
    fn alpha_matches_period_formula() {
        let ema = Ema::new(10).unwrap();
        assert_relative_eq!(ema.alpha(), 2.0 / 11.0, epsilon = 1e-15);
    }

    #[test]
    fn step_after_seed_uses_alpha_formula() {
        // period=3 => alpha = 0.5; seed = mean([1,2,3]) = 2; next input 10
        // expected = 0.5*10 + 0.5*2 = 6
        let mut ema = Ema::new(3).unwrap();
        ema.batch(&[1.0, 2.0, 3.0]);
        assert_relative_eq!(ema.update(10.0).unwrap(), 6.0, epsilon = 1e-12);
    }

    #[test]
    fn constant_series_converges_to_constant() {
        let mut ema = Ema::new(10).unwrap();
        let out = ema.batch(&[42.0_f64; 100]);
        for x in out.iter().skip(9) {
            assert_relative_eq!(x.unwrap(), 42.0, epsilon = 1e-9);
        }
    }

    #[test]
    fn with_alpha_validates_range() {
        assert!(Ema::with_alpha(0.5).is_ok());
        assert!(Ema::with_alpha(1.0).is_ok());
        assert!(matches!(
            Ema::with_alpha(0.0),
            Err(Error::InvalidPeriod { .. })
        ));
        assert!(matches!(
            Ema::with_alpha(1.5),
            Err(Error::InvalidPeriod { .. })
        ));
        assert!(matches!(
            Ema::with_alpha(f64::NAN),
            Err(Error::InvalidPeriod { .. })
        ));
    }

    /// An alpha-configured EMA has a unit period/warm-up and seeds from the
    /// first input verbatim.
    #[test]
    fn with_alpha_seeds_from_first_input() {
        let mut ema = Ema::with_alpha(0.5).unwrap();
        assert_eq!(ema.period(), 1);
        assert_eq!(ema.warmup_period(), 1);
        assert_eq!(ema.update(4.0), Some(4.0));
        // 0.5 * 8 + 0.5 * 4 = 6
        assert_relative_eq!(ema.update(8.0).unwrap(), 6.0, epsilon = 1e-12);
    }

    #[test]
    fn reset_clears_state() {
        let mut ema = Ema::new(3).unwrap();
        ema.batch(&[1.0, 2.0, 3.0]);
        assert!(ema.is_ready());
        ema.reset();
        assert!(!ema.is_ready());
        assert_eq!(ema.update(1.0), None);
    }

    #[test]
    fn batch_equals_streaming() {
        let prices: Vec<f64> = (1..=30).map(f64::from).collect();
        let mut a = Ema::new(5).unwrap();
        let mut b = Ema::new(5).unwrap();
        assert_eq!(
            a.batch(&prices),
            prices.iter().map(|p| b.update(*p)).collect::<Vec<_>>()
        );
    }

    #[test]
    fn ignores_non_finite_input() {
        let mut ema = Ema::new(3).unwrap();
        ema.batch(&[1.0, 2.0, 3.0]);
        let before = ema.value();
        assert_eq!(ema.update(f64::NAN), before);
        assert_eq!(ema.update(f64::INFINITY), before);
    }

    proptest::proptest! {
        #![proptest_config(proptest::test_runner::Config::with_cases(48))]
        #[test]
        fn ema_matches_naive(
            period in 1usize..20,
            prices in proptest::collection::vec(-1000.0_f64..1000.0, 0..150),
        ) {
            let mut ema = Ema::new(period).unwrap();
            let got = ema.batch(&prices);
            let want = ema_naive(&prices, period);
            proptest::prop_assert_eq!(got.len(), want.len());
            for (g, w) in got.iter().zip(want.iter()) {
                match (g, w) {
                    (None, None) => {}
                    (Some(a), Some(b)) => proptest::prop_assert!(
                        (a - b).abs() <= 1e-9 * a.abs().max(1.0),
                        "got={a} want={b}"
                    ),
                    _ => proptest::prop_assert!(false, "warmup mismatch"),
                }
            }
        }
    }
}