1use crate::error::{Error, Result};
4use crate::traits::Indicator;
5
/// Exponential moving average (EMA) state.
///
/// Bundles the smoothing coefficients with the running average and the
/// samples collected before the first value can be produced.
#[derive(Clone, Debug)]
pub struct Ema {
    /// Number of samples that must be collected before the average is seeded.
    period: usize,
    /// Weight applied to each new sample.
    alpha: f64,
    /// Cached `1.0 - alpha`: the weight applied to the previous average.
    one_minus_alpha: f64,
    /// Latest average; only meaningful while `seeded` is `true`.
    current: f64,
    /// `true` once `current` holds a valid average.
    seeded: bool,
    /// Samples gathered during warm-up; their arithmetic mean is the seed.
    warmup_buf: Vec<f64>,
}
39
40impl Ema {
41 pub fn new(period: usize) -> Result<Self> {
47 if period == 0 {
48 return Err(Error::PeriodZero);
49 }
50 let alpha = 2.0 / (period as f64 + 1.0);
51 Ok(Self {
52 period,
53 alpha,
54 one_minus_alpha: 1.0 - alpha,
55 current: 0.0,
56 seeded: false,
57 warmup_buf: Vec::with_capacity(period),
58 })
59 }
60
61 pub fn with_alpha(alpha: f64) -> Result<Self> {
71 if !alpha.is_finite() || alpha <= 0.0 || alpha > 1.0 {
72 return Err(Error::InvalidPeriod {
73 message: "alpha must be in (0.0, 1.0]",
74 });
75 }
76 Ok(Self {
77 period: 1,
78 alpha,
79 one_minus_alpha: 1.0 - alpha,
80 current: 0.0,
81 seeded: false,
82 warmup_buf: Vec::with_capacity(1),
83 })
84 }
85
86 pub const fn period(&self) -> usize {
88 self.period
89 }
90
91 pub const fn alpha(&self) -> f64 {
93 self.alpha
94 }
95
96 pub const fn value(&self) -> Option<f64> {
98 if self.seeded {
99 Some(self.current)
100 } else {
101 None
102 }
103 }
104
105 pub(crate) fn is_fresh(&self) -> bool {
108 !self.seeded && self.warmup_buf.is_empty()
109 }
110
111 pub(crate) fn seed_to(&mut self, current: f64) {
117 self.current = current;
118 self.seeded = true;
119 }
120
121 pub fn batch_nan(&mut self, inputs: &[f64]) -> Vec<f64> {
132 let p = self.period;
133 if self.seeded || !self.warmup_buf.is_empty() || !inputs.iter().all(|x| x.is_finite()) {
134 return inputs
135 .iter()
136 .map(|&x| self.update(x).unwrap_or(f64::NAN))
137 .collect();
138 }
139
140 let n = inputs.len();
141 if n < p {
142 self.warmup_buf.extend_from_slice(inputs);
144 return vec![f64::NAN; n];
145 }
146
147 let mut out = vec![f64::NAN; p - 1];
149 out.reserve(n - (p - 1));
150 let seed = inputs[..p].iter().copied().sum::<f64>() / p as f64;
151 let mut cur = seed;
152 out.push(seed);
153 let (alpha, oma) = (self.alpha, self.one_minus_alpha);
154 for &x in &inputs[p..] {
155 cur = alpha.mul_add(x, oma * cur);
156 out.push(cur);
157 }
158
159 self.current = cur;
162 self.seeded = true;
163 self.warmup_buf.extend_from_slice(&inputs[..p]);
164 out
165 }
166
167 pub(crate) fn step_unchecked(&mut self, input: f64) -> Option<f64> {
170 if self.seeded {
171 let new = self
172 .alpha
173 .mul_add(input, self.one_minus_alpha * self.current);
174 self.current = new;
175 return Some(new);
176 }
177 self.warmup_buf.push(input);
178 if self.warmup_buf.len() == self.period {
179 let seed = self.warmup_buf.iter().copied().sum::<f64>() / self.period as f64;
180 self.current = seed;
181 self.seeded = true;
182 return Some(seed);
183 }
184 None
185 }
186}
187
188impl Indicator for Ema {
189 type Input = f64;
190 type Output = f64;
191
192 fn update(&mut self, input: f64) -> Option<f64> {
193 if !input.is_finite() {
194 return self.value();
195 }
196 self.step_unchecked(input)
197 }
198
199 fn reset(&mut self) {
200 self.current = 0.0;
201 self.seeded = false;
202 self.warmup_buf.clear();
203 }
204
205 fn warmup_period(&self) -> usize {
206 self.period
207 }
208
209 fn is_ready(&self) -> bool {
210 self.seeded
211 }
212
213 fn name(&self) -> &'static str {
214 "EMA"
215 }
216}
217
#[cfg(test)]
mod tests {
    use super::*;
    use crate::traits::BatchExt;
    use approx::assert_relative_eq;

    /// Reference EMA written without `mul_add`: `None` until `period` prices
    /// have been seen, then the mean of the first `period` prices, then the
    /// usual recurrence.
    fn ema_naive(prices: &[f64], period: usize) -> Vec<Option<f64>> {
        let alpha = 2.0 / (period as f64 + 1.0);
        let mut state: Option<f64> = None;
        prices
            .iter()
            .enumerate()
            .map(|(idx, &price)| {
                state = match state {
                    Some(prev) => Some(alpha * price + (1.0 - alpha) * prev),
                    None if idx + 1 == period => {
                        Some(prices[..period].iter().sum::<f64>() / period as f64)
                    }
                    None => None,
                };
                state
            })
            .collect()
    }

    #[test]
    fn new_rejects_zero_period() {
        let built = Ema::new(0);
        assert!(matches!(built, Err(Error::PeriodZero)));
    }

    #[test]
    fn accessors_and_metadata() {
        let ema = Ema::new(14).unwrap();
        assert_eq!(ema.period(), 14);
        assert_eq!(ema.warmup_period(), 14);
        assert_eq!(ema.name(), "EMA");
    }

    #[test]
    fn warmup_returns_none_until_seed() {
        let mut ema = Ema::new(3).unwrap();
        for sample in [1.0, 2.0] {
            assert_eq!(ema.update(sample), None);
        }
        // The third sample completes the window: (1 + 2 + 3) / 3.
        assert_eq!(ema.update(3.0), Some(2.0));
    }

    #[test]
    fn first_value_equals_sma_seed() {
        let mut ema = Ema::new(5).unwrap();
        let inputs = [10.0, 20.0, 30.0, 40.0, 50.0];
        let last = inputs.iter().map(|&v| ema.update(v)).last().flatten();
        assert_relative_eq!(last.unwrap(), 30.0, epsilon = 1e-12);
    }

    #[test]
    fn alpha_matches_period_formula() {
        let ema = Ema::new(10).unwrap();
        assert_relative_eq!(ema.alpha(), 2.0 / 11.0, epsilon = 1e-15);
    }

    #[test]
    fn step_after_seed_uses_alpha_formula() {
        let mut ema = Ema::new(3).unwrap();
        // Seeds at 2.0 with alpha = 0.5, so the next step is 0.5 * 10 + 0.5 * 2.
        ema.batch(&[1.0, 2.0, 3.0]);
        let stepped = ema.update(10.0).unwrap();
        assert_relative_eq!(stepped, 6.0, epsilon = 1e-12);
    }

    #[test]
    fn constant_series_converges_to_constant() {
        let mut ema = Ema::new(10).unwrap();
        let out = ema.batch(&[42.0_f64; 100]);
        for ready in out.iter().skip(9) {
            assert_relative_eq!(ready.unwrap(), 42.0, epsilon = 1e-9);
        }
    }

    #[test]
    fn with_alpha_validates_range() {
        for accepted in [0.5, 1.0] {
            assert!(Ema::with_alpha(accepted).is_ok());
        }
        for rejected in [0.0, 1.5, f64::NAN] {
            assert!(matches!(
                Ema::with_alpha(rejected),
                Err(Error::InvalidPeriod { .. })
            ));
        }
    }

    #[test]
    fn reset_clears_state() {
        let mut ema = Ema::new(3).unwrap();
        ema.batch(&[1.0, 2.0, 3.0]);
        assert!(ema.is_ready());
        ema.reset();
        assert!(!ema.is_ready());
        assert_eq!(ema.update(1.0), None);
    }

    #[test]
    fn batch_equals_streaming() {
        let prices: Vec<f64> = (1..=30).map(f64::from).collect();
        let mut batched = Ema::new(5).unwrap();
        let mut streamed = Ema::new(5).unwrap();
        let from_batch = batched.batch(&prices);
        let from_stream = prices
            .iter()
            .map(|p| streamed.update(*p))
            .collect::<Vec<_>>();
        assert_eq!(from_batch, from_stream);
    }

    #[test]
    fn ignores_non_finite_input() {
        let mut ema = Ema::new(3).unwrap();
        ema.batch(&[1.0, 2.0, 3.0]);
        let before = ema.value();
        for junk in [f64::NAN, f64::INFINITY] {
            assert_eq!(ema.update(junk), before);
        }
    }

    /// Element-wise equality that also treats NaN as equal to NaN.
    fn bits_eq(a: &[f64], b: &[f64]) -> bool {
        if a.len() != b.len() {
            return false;
        }
        a.iter()
            .zip(b)
            .all(|(x, y)| x == y || (x.is_nan() && y.is_nan()))
    }

    /// Streams `series` through a new EMA one sample at a time, mapping
    /// `None` to NaN.
    fn ema_replay(period: usize, series: &[f64]) -> Vec<f64> {
        let mut ema = Ema::new(period).unwrap();
        let mut out = Vec::with_capacity(series.len());
        for &sample in series {
            out.push(ema.update(sample).unwrap_or(f64::NAN));
        }
        out
    }

    #[test]
    fn batch_nan_fast_path_is_bit_identical() {
        let series: Vec<f64> = (0..300)
            .map(|i| (f64::from(i) * 0.25).cos() * 8.0 + 40.0)
            .collect();
        let mut ema = Ema::new(14).unwrap();
        let got = ema.batch_nan(&series);
        assert!(bits_eq(&got, &ema_replay(14, &series)));

        // The state left behind must also match the streaming state.
        let mut ref_ema = Ema::new(14).unwrap();
        for &sample in &series {
            ref_ema.update(sample);
        }
        assert_eq!(ema.update(7.5), ref_ema.update(7.5));
    }

    #[test]
    fn batch_nan_falls_back_on_non_finite() {
        let series = [1.0, 2.0, 3.0, f64::INFINITY, 5.0, 6.0, 7.0];
        let mut ema = Ema::new(3).unwrap();
        let got = ema.batch_nan(&series);
        assert!(bits_eq(&got, &ema_replay(3, &series)));
    }

    #[test]
    fn batch_nan_falls_back_when_warming() {
        let mut ema = Ema::new(3).unwrap();
        // One buffered sample means the instance is no longer fresh.
        ema.update(10.0);
        let series = [1.0, 2.0, 3.0, 4.0];

        let mut ref_ema = Ema::new(3).unwrap();
        ref_ema.update(10.0);
        let want: Vec<f64> = series
            .iter()
            .map(|&sample| ref_ema.update(sample).unwrap_or(f64::NAN))
            .collect();

        assert!(bits_eq(&ema.batch_nan(&series), &want));
    }

    #[test]
    fn batch_nan_sub_period_slice_stays_unseeded() {
        let series = [1.0, 2.0];
        let mut ema = Ema::new(5).unwrap();
        let got = ema.batch_nan(&series);
        assert!(got.iter().all(|x| x.is_nan()) && got.len() == 2);
        assert!(!ema.is_ready());

        // The buffered samples must carry over into later streaming updates.
        let next = ema.update(3.0).unwrap_or(f64::NAN);
        let expected = ema_replay(5, &[1.0, 2.0, 3.0])[2];
        assert!(bits_eq(&[next], &[expected]));
    }

    proptest::proptest! {
        #![proptest_config(proptest::test_runner::Config::with_cases(48))]
        #[test]
        fn ema_matches_naive(
            period in 1usize..20,
            prices in proptest::collection::vec(-1000.0_f64..1000.0, 0..150),
        ) {
            let mut ema = Ema::new(period).unwrap();
            let got = ema.batch(&prices);
            let want = ema_naive(&prices, period);
            proptest::prop_assert_eq!(got.len(), want.len());
            for (g, w) in got.iter().zip(want.iter()) {
                match (g, w) {
                    (None, None) => {}
                    (Some(a), Some(b)) => proptest::prop_assert!(
                        (a - b).abs() <= 1e-9 * a.abs().max(1.0),
                        "got={a} want={b}"
                    ),
                    _ => proptest::prop_assert!(false, "warmup mismatch"),
                }
            }
        }
    }
}