// wickra_core/indicators/ema.rs
use crate::error::{Error, Result};
use crate::traits::Indicator;

/// Exponential moving average over a stream of `f64` values.
///
/// The average is seeded with the arithmetic mean of the first `period`
/// inputs; each later input `x` then updates it as
/// `ema = alpha * x + (1 - alpha) * ema`.
#[derive(Clone, Debug)]
pub struct Ema {
    /// Number of inputs averaged together to form the seed value.
    period: usize,
    /// Smoothing factor applied to each new input, in `(0.0, 1.0]`.
    alpha: f64,
    /// Current EMA value; `None` until the seed has been computed.
    state: Option<f64>,
    /// Inputs collected during warmup, before the seed exists.
    warmup_buf: Vec<f64>,
}
31
32impl Ema {
33 pub fn new(period: usize) -> Result<Self> {
39 if period == 0 {
40 return Err(Error::PeriodZero);
41 }
42 let alpha = 2.0 / (period as f64 + 1.0);
43 Ok(Self {
44 period,
45 alpha,
46 state: None,
47 warmup_buf: Vec::with_capacity(period),
48 })
49 }
50
51 pub fn with_alpha(alpha: f64) -> Result<Self> {
61 if !alpha.is_finite() || alpha <= 0.0 || alpha > 1.0 {
62 return Err(Error::InvalidPeriod {
63 message: "alpha must be in (0.0, 1.0]",
64 });
65 }
66 Ok(Self {
67 period: 1,
68 alpha,
69 state: None,
70 warmup_buf: Vec::with_capacity(1),
71 })
72 }
73
74 pub const fn period(&self) -> usize {
76 self.period
77 }
78
79 pub const fn alpha(&self) -> f64 {
81 self.alpha
82 }
83
84 pub const fn value(&self) -> Option<f64> {
86 self.state
87 }
88
89 pub(crate) fn step_unchecked(&mut self, input: f64) -> Option<f64> {
92 if let Some(prev) = self.state {
93 let new = self.alpha.mul_add(input, (1.0 - self.alpha) * prev);
94 self.state = Some(new);
95 return Some(new);
96 }
97 self.warmup_buf.push(input);
98 if self.warmup_buf.len() == self.period {
99 let seed = self.warmup_buf.iter().copied().sum::<f64>() / self.period as f64;
100 self.state = Some(seed);
101 return Some(seed);
102 }
103 None
104 }
105}
106
107impl Indicator for Ema {
108 type Input = f64;
109 type Output = f64;
110
111 fn update(&mut self, input: f64) -> Option<f64> {
112 if !input.is_finite() {
113 return self.state;
114 }
115 self.step_unchecked(input)
116 }
117
118 fn reset(&mut self) {
119 self.state = None;
120 self.warmup_buf.clear();
121 }
122
123 fn warmup_period(&self) -> usize {
124 self.period
125 }
126
127 fn is_ready(&self) -> bool {
128 self.state.is_some()
129 }
130
131 fn name(&self) -> &'static str {
132 "EMA"
133 }
134}
135
#[cfg(test)]
mod tests {
    use super::*;
    use crate::traits::BatchExt;
    use approx::assert_relative_eq;

    /// Reference EMA written as plainly as possible: SMA seed over the first
    /// `period` prices, then `alpha * p + (1 - alpha) * prev` afterwards.
    fn ema_naive(prices: &[f64], period: usize) -> Vec<Option<f64>> {
        let alpha = 2.0 / (period as f64 + 1.0);
        let mut out = Vec::with_capacity(prices.len());
        let mut state: Option<f64> = None;
        for (i, &p) in prices.iter().enumerate() {
            if let Some(prev) = state {
                let v = alpha * p + (1.0 - alpha) * prev;
                state = Some(v);
                out.push(Some(v));
            } else if i + 1 == period {
                let seed = prices[..period].iter().sum::<f64>() / period as f64;
                state = Some(seed);
                out.push(Some(seed));
            } else {
                out.push(None);
            }
        }
        out
    }

    #[test]
    fn new_rejects_zero_period() {
        assert!(matches!(Ema::new(0), Err(Error::PeriodZero)));
    }

    #[test]
    fn accessors_and_metadata() {
        let ema = Ema::new(14).unwrap();
        assert_eq!(ema.period(), 14);
        assert_eq!(ema.warmup_period(), 14);
        assert_eq!(ema.name(), "EMA");
        assert_eq!(ema.value(), None);
        assert!(!ema.is_ready());
    }

    #[test]
    fn warmup_returns_none_until_seed() {
        let mut ema = Ema::new(3).unwrap();
        assert_eq!(ema.update(1.0), None);
        assert_eq!(ema.update(2.0), None);
        assert_eq!(ema.update(3.0), Some(2.0));
    }

    #[test]
    fn non_finite_input_during_warmup_does_not_count() {
        let mut ema = Ema::new(3).unwrap();
        assert_eq!(ema.update(1.0), None);
        assert_eq!(ema.update(f64::NAN), None);
        assert_eq!(ema.update(2.0), None);
        assert_eq!(ema.update(f64::NEG_INFINITY), None);
        // Only the three finite inputs form the seed.
        assert_eq!(ema.update(3.0), Some(2.0));
    }

    #[test]
    fn first_value_equals_sma_seed() {
        let mut ema = Ema::new(5).unwrap();
        let inputs = [10.0, 20.0, 30.0, 40.0, 50.0];
        let mut last = None;
        for v in inputs {
            last = ema.update(v);
        }
        assert_relative_eq!(last.unwrap(), 30.0, epsilon = 1e-12);
    }

    #[test]
    fn alpha_matches_period_formula() {
        let ema = Ema::new(10).unwrap();
        assert_relative_eq!(ema.alpha(), 2.0 / 11.0, epsilon = 1e-15);
    }

    #[test]
    fn step_after_seed_uses_alpha_formula() {
        // Period 3 => alpha = 0.5 and seed = 2.0, so the next step is
        // 0.5 * 10 + 0.5 * 2 = 6.
        let mut ema = Ema::new(3).unwrap();
        ema.batch(&[1.0, 2.0, 3.0]);
        assert_relative_eq!(ema.update(10.0).unwrap(), 6.0, epsilon = 1e-12);
    }

    #[test]
    fn constant_series_converges_to_constant() {
        let mut ema = Ema::new(10).unwrap();
        let out = ema.batch(&[42.0_f64; 100]);
        for x in out.iter().skip(9) {
            assert_relative_eq!(x.unwrap(), 42.0, epsilon = 1e-9);
        }
    }

    #[test]
    fn with_alpha_validates_range() {
        assert!(Ema::with_alpha(0.5).is_ok());
        assert!(Ema::with_alpha(1.0).is_ok());
        assert!(matches!(
            Ema::with_alpha(0.0),
            Err(Error::InvalidPeriod { .. })
        ));
        assert!(matches!(
            Ema::with_alpha(-0.1),
            Err(Error::InvalidPeriod { .. })
        ));
        assert!(matches!(
            Ema::with_alpha(1.5),
            Err(Error::InvalidPeriod { .. })
        ));
        assert!(matches!(
            Ema::with_alpha(f64::NAN),
            Err(Error::InvalidPeriod { .. })
        ));
        assert!(matches!(
            Ema::with_alpha(f64::INFINITY),
            Err(Error::InvalidPeriod { .. })
        ));
    }

    #[test]
    fn with_alpha_seeds_on_first_input() {
        let mut ema = Ema::with_alpha(0.5).unwrap();
        assert_eq!(ema.period(), 1);
        assert_eq!(ema.warmup_period(), 1);
        assert_eq!(ema.update(4.0), Some(4.0));
        assert_relative_eq!(ema.update(8.0).unwrap(), 6.0, epsilon = 1e-12);
    }

    #[test]
    fn reset_clears_state() {
        let mut ema = Ema::new(3).unwrap();
        ema.batch(&[1.0, 2.0, 3.0]);
        assert!(ema.is_ready());
        ema.reset();
        assert!(!ema.is_ready());
        // A full warmup is required again after reset.
        assert_eq!(ema.update(1.0), None);
        assert_eq!(ema.update(2.0), None);
        assert_eq!(ema.update(3.0), Some(2.0));
    }

    #[test]
    fn batch_equals_streaming() {
        let prices: Vec<f64> = (1..=30).map(f64::from).collect();
        let mut a = Ema::new(5).unwrap();
        let mut b = Ema::new(5).unwrap();
        assert_eq!(
            a.batch(&prices),
            prices.iter().map(|p| b.update(*p)).collect::<Vec<_>>()
        );
    }

    #[test]
    fn ignores_non_finite_input() {
        let mut ema = Ema::new(3).unwrap();
        ema.batch(&[1.0, 2.0, 3.0]);
        let before = ema.value();
        assert_eq!(ema.update(f64::NAN), before);
        assert_eq!(ema.update(f64::INFINITY), before);
        assert_eq!(ema.update(f64::NEG_INFINITY), before);
        // The rejected inputs must not have leaked into the state.
        assert_eq!(ema.value(), before);
    }

    proptest::proptest! {
        #![proptest_config(proptest::test_runner::Config::with_cases(48))]
        #[test]
        fn ema_matches_naive(
            period in 1usize..20,
            prices in proptest::collection::vec(-1000.0_f64..1000.0, 0..150),
        ) {
            let mut ema = Ema::new(period).unwrap();
            let got = ema.batch(&prices);
            let want = ema_naive(&prices, period);
            proptest::prop_assert_eq!(got.len(), want.len());
            for (g, w) in got.iter().zip(want.iter()) {
                match (g, w) {
                    (None, None) => {}
                    (Some(a), Some(b)) => proptest::prop_assert!(
                        (a - b).abs() <= 1e-9 * a.abs().max(1.0),
                        "got={a} want={b}"
                    ),
                    _ => proptest::prop_assert!(false, "warmup mismatch"),
                }
            }
        }
    }
}