wickra_core/indicators/
ema.rs1use crate::error::{Error, Result};
4use crate::traits::Indicator;
5
6#[derive(Debug, Clone)]
25pub struct Ema {
26 period: usize,
27 alpha: f64,
28 one_minus_alpha: f64,
31 current: f64,
35 seeded: bool,
37 warmup_buf: Vec<f64>,
38}
39
40impl Ema {
41 pub fn new(period: usize) -> Result<Self> {
47 if period == 0 {
48 return Err(Error::PeriodZero);
49 }
50 let alpha = 2.0 / (period as f64 + 1.0);
51 Ok(Self {
52 period,
53 alpha,
54 one_minus_alpha: 1.0 - alpha,
55 current: 0.0,
56 seeded: false,
57 warmup_buf: Vec::with_capacity(period),
58 })
59 }
60
61 pub fn with_alpha(alpha: f64) -> Result<Self> {
71 if !alpha.is_finite() || alpha <= 0.0 || alpha > 1.0 {
72 return Err(Error::InvalidPeriod {
73 message: "alpha must be in (0.0, 1.0]",
74 });
75 }
76 Ok(Self {
77 period: 1,
78 alpha,
79 one_minus_alpha: 1.0 - alpha,
80 current: 0.0,
81 seeded: false,
82 warmup_buf: Vec::with_capacity(1),
83 })
84 }
85
86 pub const fn period(&self) -> usize {
88 self.period
89 }
90
91 pub const fn alpha(&self) -> f64 {
93 self.alpha
94 }
95
96 pub const fn value(&self) -> Option<f64> {
98 if self.seeded {
99 Some(self.current)
100 } else {
101 None
102 }
103 }
104
105 pub(crate) fn step_unchecked(&mut self, input: f64) -> Option<f64> {
108 if self.seeded {
109 let new = self
110 .alpha
111 .mul_add(input, self.one_minus_alpha * self.current);
112 self.current = new;
113 return Some(new);
114 }
115 self.warmup_buf.push(input);
116 if self.warmup_buf.len() == self.period {
117 let seed = self.warmup_buf.iter().copied().sum::<f64>() / self.period as f64;
118 self.current = seed;
119 self.seeded = true;
120 return Some(seed);
121 }
122 None
123 }
124}
125
126impl Indicator for Ema {
127 type Input = f64;
128 type Output = f64;
129
130 fn update(&mut self, input: f64) -> Option<f64> {
131 if !input.is_finite() {
132 return self.value();
133 }
134 self.step_unchecked(input)
135 }
136
137 fn reset(&mut self) {
138 self.current = 0.0;
139 self.seeded = false;
140 self.warmup_buf.clear();
141 }
142
143 fn warmup_period(&self) -> usize {
144 self.period
145 }
146
147 fn is_ready(&self) -> bool {
148 self.seeded
149 }
150
151 fn name(&self) -> &'static str {
152 "EMA"
153 }
154}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::traits::BatchExt;
    use approx::assert_relative_eq;

    /// Reference implementation: SMA seed over the first `period` prices,
    /// then the textbook recurrence `alpha * p + (1 - alpha) * prev`.
    fn ema_naive(prices: &[f64], period: usize) -> Vec<Option<f64>> {
        let alpha = 2.0 / (period as f64 + 1.0);
        let mut out = Vec::with_capacity(prices.len());
        let mut state: Option<f64> = None;
        for (i, &p) in prices.iter().enumerate() {
            if let Some(prev) = state {
                let v = alpha * p + (1.0 - alpha) * prev;
                state = Some(v);
                out.push(Some(v));
            } else if i + 1 == period {
                let seed = prices[..period].iter().sum::<f64>() / period as f64;
                state = Some(seed);
                out.push(Some(seed));
            } else {
                out.push(None);
            }
        }
        out
    }

    #[test]
    fn new_rejects_zero_period() {
        assert!(matches!(Ema::new(0), Err(Error::PeriodZero)));
    }

    #[test]
    fn accessors_and_metadata() {
        let ema = Ema::new(14).unwrap();
        assert_eq!(ema.period(), 14);
        assert_eq!(ema.warmup_period(), 14);
        assert_eq!(ema.name(), "EMA");
        assert!(!ema.is_ready());
        assert_eq!(ema.value(), None);
    }

    #[test]
    fn warmup_returns_none_until_seed() {
        let mut ema = Ema::new(3).unwrap();
        assert_eq!(ema.update(1.0), None);
        assert_eq!(ema.update(2.0), None);
        assert_eq!(ema.update(3.0), Some(2.0));
    }

    #[test]
    fn first_value_equals_sma_seed() {
        let mut ema = Ema::new(5).unwrap();
        let inputs = [10.0, 20.0, 30.0, 40.0, 50.0];
        let mut last = None;
        for v in inputs {
            last = ema.update(v);
        }
        assert_relative_eq!(last.unwrap(), 30.0, epsilon = 1e-12);
    }

    #[test]
    fn alpha_matches_period_formula() {
        let ema = Ema::new(10).unwrap();
        assert_relative_eq!(ema.alpha(), 2.0 / 11.0, epsilon = 1e-15);
    }

    #[test]
    fn step_after_seed_uses_alpha_formula() {
        // period 3 => alpha 0.5; seed is 2.0, so the next value is 0.5*10 + 0.5*2.
        let mut ema = Ema::new(3).unwrap();
        ema.batch(&[1.0, 2.0, 3.0]);
        assert_relative_eq!(ema.update(10.0).unwrap(), 6.0, epsilon = 1e-12);
    }

    #[test]
    fn constant_series_converges_to_constant() {
        let mut ema = Ema::new(10).unwrap();
        let out = ema.batch(&[42.0_f64; 100]);
        for x in out.iter().skip(9) {
            assert_relative_eq!(x.unwrap(), 42.0, epsilon = 1e-9);
        }
    }

    #[test]
    fn with_alpha_validates_range() {
        assert!(Ema::with_alpha(0.5).is_ok());
        assert!(Ema::with_alpha(1.0).is_ok());
        assert!(matches!(
            Ema::with_alpha(0.0),
            Err(Error::InvalidPeriod { .. })
        ));
        assert!(matches!(
            Ema::with_alpha(1.5),
            Err(Error::InvalidPeriod { .. })
        ));
        assert!(matches!(
            Ema::with_alpha(f64::NAN),
            Err(Error::InvalidPeriod { .. })
        ));
    }

    #[test]
    fn with_alpha_seeds_on_first_input() {
        let mut ema = Ema::with_alpha(0.5).unwrap();
        assert_eq!(ema.period(), 1);
        assert_eq!(ema.warmup_period(), 1);
        assert_relative_eq!(ema.alpha(), 0.5);
        assert_eq!(ema.update(4.0), Some(4.0));
        assert_relative_eq!(ema.update(8.0).unwrap(), 6.0, epsilon = 1e-12);
    }

    #[test]
    fn reset_clears_state() {
        let mut ema = Ema::new(3).unwrap();
        ema.batch(&[1.0, 2.0, 3.0]);
        assert!(ema.is_ready());
        ema.reset();
        assert!(!ema.is_ready());
        assert_eq!(ema.value(), None);
        // A full, fresh warmup is required: nothing accumulated before the
        // reset may leak into the new seed.
        assert_eq!(ema.update(4.0), None);
        assert_eq!(ema.update(5.0), None);
        assert_eq!(ema.update(6.0), Some(5.0));
    }

    #[test]
    fn batch_equals_streaming() {
        let prices: Vec<f64> = (1..=30).map(f64::from).collect();
        let mut a = Ema::new(5).unwrap();
        let mut b = Ema::new(5).unwrap();
        assert_eq!(
            a.batch(&prices),
            prices.iter().map(|p| b.update(*p)).collect::<Vec<_>>()
        );
    }

    #[test]
    fn ignores_non_finite_input() {
        let mut ema = Ema::new(3).unwrap();
        ema.batch(&[1.0, 2.0, 3.0]);
        let before = ema.value();
        assert_eq!(ema.update(f64::NAN), before);
        assert_eq!(ema.update(f64::INFINITY), before);
    }

    #[test]
    fn non_finite_input_during_warmup_is_skipped() {
        // Non-finite values must not count towards the warmup period.
        let mut ema = Ema::new(3).unwrap();
        assert_eq!(ema.update(1.0), None);
        assert_eq!(ema.update(f64::NAN), None);
        assert_eq!(ema.update(2.0), None);
        assert_eq!(ema.update(f64::NEG_INFINITY), None);
        assert_eq!(ema.update(3.0), Some(2.0));
    }

    proptest::proptest! {
        #![proptest_config(proptest::test_runner::Config::with_cases(48))]
        #[test]
        fn ema_matches_naive(
            period in 1usize..20,
            prices in proptest::collection::vec(-1000.0_f64..1000.0, 0..150),
        ) {
            let mut ema = Ema::new(period).unwrap();
            let got = ema.batch(&prices);
            let want = ema_naive(&prices, period);
            proptest::prop_assert_eq!(got.len(), want.len());
            for (g, w) in got.iter().zip(want.iter()) {
                match (g, w) {
                    (None, None) => {}
                    (Some(a), Some(b)) => proptest::prop_assert!(
                        (a - b).abs() <= 1e-9 * a.abs().max(1.0),
                        "got={a} want={b}"
                    ),
                    _ => proptest::prop_assert!(false, "warmup mismatch"),
                }
            }
        }
    }
}