wickra_core/indicators/
wma.rs1use std::collections::VecDeque;
4
5use crate::error::{Error, Result};
6use crate::traits::Indicator;
7
/// Streaming weighted moving average (WMA) over a fixed-length window.
///
/// Within a full window the oldest sample carries weight `1` and the newest
/// carries weight `period`; the weighted sum is divided by
/// `period * (period + 1) / 2`.
#[derive(Clone, Debug)]
pub struct Wma {
    /// Number of samples in a full window; `Wma::new` rejects zero.
    period: usize,
    /// Buffered finite samples, oldest at the front and newest at the back.
    window: VecDeque<f64>,
    /// Weighted sum of the window contents; only populated once the window
    /// has filled, then maintained incrementally on every slide.
    weight_sum: f64,
    /// Plain (unweighted) sum of the samples currently in `window`.
    value_sum: f64,
    /// Sum of the weights `1 + 2 + ... + period`, fixed at construction.
    weights_total: f64,
}
33
34impl Wma {
35 pub fn new(period: usize) -> Result<Self> {
41 if period == 0 {
42 return Err(Error::PeriodZero);
43 }
44 let n = period as f64;
45 let weights_total = n * (n + 1.0) / 2.0;
46 Ok(Self {
47 period,
48 window: VecDeque::with_capacity(period),
49 weight_sum: 0.0,
50 value_sum: 0.0,
51 weights_total,
52 })
53 }
54
55 pub const fn period(&self) -> usize {
57 self.period
58 }
59
60 pub fn value(&self) -> Option<f64> {
62 if self.window.len() == self.period {
63 Some(self.weight_sum / self.weights_total)
64 } else {
65 None
66 }
67 }
68}
69
70impl Indicator for Wma {
71 type Input = f64;
72 type Output = f64;
73
74 fn update(&mut self, input: f64) -> Option<f64> {
75 if !input.is_finite() {
76 return self.value();
77 }
78 if self.window.len() < self.period {
79 self.window.push_back(input);
82 self.value_sum += input;
83 if self.window.len() == self.period {
84 self.weight_sum = self
85 .window
86 .iter()
87 .enumerate()
88 .map(|(i, v)| (i as f64 + 1.0) * v)
89 .sum();
90 }
91 return self.value();
92 }
93 let oldest = self.window.pop_front().expect("window non-empty");
99 self.weight_sum = self.weight_sum - self.value_sum + self.period as f64 * input;
100 self.value_sum = self.value_sum - oldest + input;
101 self.window.push_back(input);
102 self.value()
103 }
104
105 fn reset(&mut self) {
106 self.window.clear();
107 self.weight_sum = 0.0;
108 self.value_sum = 0.0;
109 }
110
111 fn warmup_period(&self) -> usize {
112 self.period
113 }
114
115 fn is_ready(&self) -> bool {
116 self.window.len() == self.period
117 }
118
119 fn name(&self) -> &'static str {
120 "WMA"
121 }
122}
123
#[cfg(test)]
mod tests {
    use super::*;
    use crate::traits::BatchExt;
    use approx::assert_relative_eq;

    /// Reference implementation: recomputes the weighted sum from scratch for
    /// every position, yielding `None` until `period` prices are available.
    fn wma_naive(prices: &[f64], period: usize) -> Vec<Option<f64>> {
        let denominator = (period as f64) * (period as f64 + 1.0) / 2.0;
        let mut out = Vec::with_capacity(prices.len());
        for end in 0..prices.len() {
            if end + 1 < period {
                out.push(None);
                continue;
            }
            let start = end + 1 - period;
            let weighted: f64 = prices[start..=end]
                .iter()
                .enumerate()
                .map(|(j, p)| (j as f64 + 1.0) * p)
                .sum();
            out.push(Some(weighted / denominator));
        }
        out
    }

    /// A zero period must be rejected at construction.
    #[test]
    fn new_rejects_zero_period() {
        let outcome = Wma::new(0);
        assert!(matches!(outcome, Err(Error::PeriodZero)));
    }

    /// Accessors report the configured period and the fixed name.
    #[test]
    fn accessors_and_metadata() {
        let indicator = Wma::new(7).unwrap();
        assert_eq!(indicator.period(), 7);
        assert_eq!(indicator.warmup_period(), 7);
        assert_eq!(indicator.name(), "WMA");
    }

    /// No output until the third sample of a period-3 WMA.
    #[test]
    fn warmup_returns_none() {
        let mut indicator = Wma::new(3).unwrap();
        assert!(indicator.update(1.0).is_none());
        assert!(indicator.update(2.0).is_none());
        // (1*1 + 2*2 + 3*3) / (1 + 2 + 3)
        let first = indicator.update(3.0).unwrap();
        assert_relative_eq!(first, 14.0 / 6.0, epsilon = 1e-12);
    }

    /// Hand-computed value: (1 + 4 + 9 + 16) / 10 = 3.
    #[test]
    fn known_values_period_4() {
        let mut indicator = Wma::new(4).unwrap();
        let outputs = indicator.batch(&[1.0, 2.0, 3.0, 4.0]);
        assert_relative_eq!(outputs[3].unwrap(), 3.0, epsilon = 1e-12);
    }

    /// The incremental update agrees with the from-scratch reference.
    #[test]
    fn matches_naive_over_random_inputs() {
        let prices: Vec<f64> = (1..=30).map(|i| f64::from(i) * 1.7 - 5.0).collect();
        let want = wma_naive(&prices, 7);
        let mut indicator = Wma::new(7).unwrap();
        let got = indicator.batch(&prices);
        for (i, pair) in got.iter().zip(want.iter()).enumerate() {
            let (g, w) = pair;
            assert_eq!(g.is_some(), w.is_some(), "warmup mismatch at index {i}");
            match (g, w) {
                (Some(a), Some(b)) => assert_relative_eq!(*a, *b, epsilon = 1e-9),
                _ => {}
            }
        }
    }

    /// With a single-sample window the output equals the input.
    #[test]
    fn period_one_is_pass_through() {
        let mut indicator = Wma::new(1).unwrap();
        for sample in [5.5, 7.5] {
            assert_relative_eq!(indicator.update(sample).unwrap(), sample, epsilon = 1e-12);
        }
    }

    /// After `reset` the indicator is back in warm-up.
    #[test]
    fn reset_clears_state() {
        let mut indicator = Wma::new(4).unwrap();
        let _ = indicator.batch(&[1.0, 2.0, 3.0, 4.0, 5.0]);
        assert!(indicator.is_ready());
        indicator.reset();
        assert!(!indicator.is_ready());
        assert_eq!(indicator.update(10.0), None);
    }

    /// Batch processing and one-at-a-time updates produce the same sequence.
    #[test]
    fn batch_equals_streaming() {
        let prices: Vec<f64> = (1..=20).map(|i| f64::from(i) * 0.5).collect();
        let mut a = Wma::new(5).unwrap();
        let mut b = Wma::new(5).unwrap();
        let batched = a.batch(&prices);
        let streamed = prices.iter().map(|p| b.update(*p)).collect::<Vec<_>>();
        assert_eq!(batched, streamed);
    }

    /// NaN / infinity leave the window untouched and echo the last value.
    #[test]
    fn ignores_non_finite_input_but_keeps_state() {
        let mut indicator = Wma::new(3).unwrap();
        indicator.update(1.0);
        indicator.update(2.0);
        let ready = indicator
            .update(3.0)
            .expect("WMA(3) ready after three inputs");
        for bad in [f64::NAN, f64::INFINITY] {
            assert_eq!(indicator.update(bad), Some(ready));
        }
        // The window must still be [1, 2, 3], so pushing 4 slides it to [2, 3, 4].
        let next = indicator.update(4.0).unwrap();
        assert_relative_eq!(
            next,
            (2.0 * 1.0 + 3.0 * 2.0 + 4.0 * 3.0) / 6.0,
            epsilon = 1e-12
        );
    }

    proptest::proptest! {
        #![proptest_config(proptest::test_runner::Config::with_cases(48))]
        #[test]
        fn proptest_matches_naive(
            period in 1usize..15,
            prices in proptest::collection::vec(-500.0_f64..500.0, 0..120),
        ) {
            let want = wma_naive(&prices, period);
            let mut indicator = Wma::new(period).unwrap();
            let got = indicator.batch(&prices);
            proptest::prop_assert_eq!(got.len(), want.len());
            for (g, w) in got.iter().zip(want.iter()) {
                match (g, w) {
                    (Some(a), Some(b)) => proptest::prop_assert!(
                        (a - b).abs() < 1e-7,
                        "got={a} want={b}"
                    ),
                    (None, None) => {}
                    _ => proptest::prop_assert!(false, "warmup mismatch"),
                }
            }
        }
    }
}