Skip to main content

wickra_core/indicators/
wma.rs

1//! Weighted Moving Average (linear weights).
2
3use std::collections::VecDeque;
4
5use crate::error::{Error, Result};
6use crate::traits::Indicator;
7
8/// Weighted Moving Average with linear weights `1, 2, ..., period`.
9///
10/// Output is `sum(weight_i * price_i) / sum(weights)`. Maintained incrementally in
11/// O(1) by keeping the rolling sum of values and the rolling weighted sum.
12///
13/// # Example
14///
15/// ```
16/// use wickra_core::{Indicator, Wma};
17///
18/// let mut indicator = Wma::new(3).unwrap();
19/// let mut last = None;
20/// for i in 0..80 {
21///     last = indicator.update(100.0 + f64::from(i));
22/// }
23/// assert!(last.is_some());
24/// ```
25#[derive(Debug, Clone)]
26pub struct Wma {
27    period: usize,
28    window: VecDeque<f64>,
29    weight_sum: f64, // sum_i (weight_i * value_i)
30    value_sum: f64,  // sum_i (value_i)
31    weights_total: f64,
32}
33
34impl Wma {
35    /// Construct a new WMA with the given window length.
36    ///
37    /// # Errors
38    ///
39    /// Returns [`Error::PeriodZero`] if `period == 0`.
40    pub fn new(period: usize) -> Result<Self> {
41        if period == 0 {
42            return Err(Error::PeriodZero);
43        }
44        let n = period as f64;
45        let weights_total = n * (n + 1.0) / 2.0;
46        Ok(Self {
47            period,
48            window: VecDeque::with_capacity(period),
49            weight_sum: 0.0,
50            value_sum: 0.0,
51            weights_total,
52        })
53    }
54
55    /// Configured period.
56    pub const fn period(&self) -> usize {
57        self.period
58    }
59
60    /// Current value if available.
61    pub fn value(&self) -> Option<f64> {
62        if self.window.len() == self.period {
63            Some(self.weight_sum / self.weights_total)
64        } else {
65            None
66        }
67    }
68}
69
70impl Indicator for Wma {
71    type Input = f64;
72    type Output = f64;
73
74    fn update(&mut self, input: f64) -> Option<f64> {
75        if !input.is_finite() {
76            return self.value();
77        }
78        if self.window.len() < self.period {
79            // Warmup. Just accumulate; compute weight_sum once when the window first
80            // becomes full to avoid having to track changing weights during warmup.
81            self.window.push_back(input);
82            self.value_sum += input;
83            if self.window.len() == self.period {
84                self.weight_sum = self
85                    .window
86                    .iter()
87                    .enumerate()
88                    .map(|(i, v)| (i as f64 + 1.0) * v)
89                    .sum();
90            }
91            return self.value();
92        }
93        // Steady state: slide the window. With weights [1, 2, ..., period],
94        //   new_weight_sum = old_weight_sum - old_value_sum + period * new_input
95        // because every retained element's weight drops by one and the newcomer
96        // enters at weight = period. Order matters: subtract `value_sum` BEFORE
97        // updating it.
98        let oldest = self.window.pop_front().expect("window non-empty");
99        self.weight_sum = self.weight_sum - self.value_sum + self.period as f64 * input;
100        self.value_sum = self.value_sum - oldest + input;
101        self.window.push_back(input);
102        self.value()
103    }
104
105    fn reset(&mut self) {
106        self.window.clear();
107        self.weight_sum = 0.0;
108        self.value_sum = 0.0;
109    }
110
111    fn warmup_period(&self) -> usize {
112        self.period
113    }
114
115    fn is_ready(&self) -> bool {
116        self.window.len() == self.period
117    }
118
119    fn name(&self) -> &'static str {
120        "WMA"
121    }
122}
123
#[cfg(test)]
mod tests {
    use super::*;
    use crate::traits::BatchExt;
    use approx::assert_relative_eq;

    /// Reference implementation: explicit weighted average over each full window.
    fn wma_naive(prices: &[f64], period: usize) -> Vec<Option<f64>> {
        let weights_total = (period as f64) * (period as f64 + 1.0) / 2.0;
        prices
            .iter()
            .enumerate()
            .map(|(i, _)| {
                if i + 1 < period {
                    None
                } else {
                    let window = &prices[i + 1 - period..=i];
                    let s: f64 = window
                        .iter()
                        .enumerate()
                        .map(|(j, p)| (j as f64 + 1.0) * p)
                        .sum();
                    Some(s / weights_total)
                }
            })
            .collect()
    }

    #[test]
    fn new_rejects_zero_period() {
        assert!(matches!(Wma::new(0), Err(Error::PeriodZero)));
    }

    /// The inherent `period` accessor and the `warmup_period` / `name` metadata
    /// from the `Indicator` impl report the configured values.
    #[test]
    fn accessors_and_metadata() {
        let wma = Wma::new(7).unwrap();
        assert_eq!(wma.period(), 7);
        assert_eq!(wma.warmup_period(), 7);
        assert_eq!(wma.name(), "WMA");
    }

    #[test]
    fn warmup_returns_none() {
        let mut wma = Wma::new(3).unwrap();
        assert_eq!(wma.update(1.0), None);
        assert_eq!(wma.update(2.0), None);
        // WMA(3) of [1,2,3]: oldest = 1 (weight 1), middle = 2 (weight 2), newest = 3 (weight 3)
        // -> (1*1 + 2*2 + 3*3) / (1+2+3) = 14/6
        assert_relative_eq!(wma.update(3.0).unwrap(), 14.0 / 6.0, epsilon = 1e-12);
    }

    #[test]
    fn known_values_period_4() {
        // WMA(4) weights 1,2,3,4 (total 10); inputs [1,2,3,4]:
        // (1*1 + 2*2 + 3*3 + 4*4) / 10 = (1+4+9+16)/10 = 30/10 = 3.0
        let mut wma = Wma::new(4).unwrap();
        let v = wma.batch(&[1.0, 2.0, 3.0, 4.0]);
        assert_relative_eq!(v[3].unwrap(), 3.0, epsilon = 1e-12);
    }

    #[test]
    fn matches_naive_over_irregular_inputs() {
        // Deterministic, non-monotonic series, so the comparison is not limited to
        // evenly spaced (linear-ramp) windows.
        let prices: Vec<f64> = (1..=30)
            .map(|i| {
                let x = f64::from(i);
                100.0 + 13.0 * (1.3 * x).sin() - 0.4 * x
            })
            .collect();
        let mut wma = Wma::new(7).unwrap();
        let got = wma.batch(&prices);
        let want = wma_naive(&prices, 7);
        // `zip` below would silently truncate a length mismatch.
        assert_eq!(got.len(), want.len());
        for (i, (g, w)) in got.iter().zip(want.iter()).enumerate() {
            // Same warmup — emission shape must agree at every index.
            assert_eq!(g.is_some(), w.is_some(), "warmup mismatch at index {i}");
            if let (Some(a), Some(b)) = (g, w) {
                assert_relative_eq!(*a, *b, epsilon = 1e-9);
            }
        }
    }

    #[test]
    fn period_one_is_pass_through() {
        let mut wma = Wma::new(1).unwrap();
        assert_relative_eq!(wma.update(5.5).unwrap(), 5.5, epsilon = 1e-12);
        assert_relative_eq!(wma.update(7.5).unwrap(), 7.5, epsilon = 1e-12);
    }

    #[test]
    fn reset_clears_state() {
        let mut wma = Wma::new(4).unwrap();
        wma.batch(&[1.0, 2.0, 3.0, 4.0, 5.0]);
        assert!(wma.is_ready());
        wma.reset();
        assert!(!wma.is_ready());
        assert_eq!(wma.update(10.0), None);
    }

    /// After `reset`, an indicator must be indistinguishable from a freshly
    /// constructed one: same warmup and bit-identical outputs.
    #[test]
    fn reset_then_refill_matches_fresh_instance() {
        let prices = [3.0, 1.0, 4.0, 1.0, 5.0, 9.0, 2.0, 6.0];
        let mut reused = Wma::new(3).unwrap();
        reused.batch(&[100.0, 200.0, 300.0, 400.0]);
        reused.reset();
        let mut fresh = Wma::new(3).unwrap();
        assert_eq!(reused.batch(&prices), fresh.batch(&prices));
    }

    #[test]
    fn batch_equals_streaming() {
        let prices: Vec<f64> = (1..=20).map(|i| f64::from(i) * 0.5).collect();
        let mut a = Wma::new(5).unwrap();
        let mut b = Wma::new(5).unwrap();
        assert_eq!(
            a.batch(&prices),
            prices.iter().map(|p| b.update(*p)).collect::<Vec<_>>()
        );
    }

    #[test]
    fn ignores_non_finite_input_but_keeps_state() {
        let mut wma = Wma::new(3).unwrap();
        wma.update(1.0);
        wma.update(2.0);
        let ready = wma.update(3.0).expect("WMA(3) ready after three inputs");
        // Non-finite inputs return the last value without mutating the window.
        assert_eq!(wma.update(f64::NAN), Some(ready));
        assert_eq!(wma.update(f64::INFINITY), Some(ready));
        // The window still holds 1, 2, 3 -> next real input slides it to 2, 3, 4.
        assert_relative_eq!(
            wma.update(4.0).unwrap(),
            (2.0 * 1.0 + 3.0 * 2.0 + 4.0 * 3.0) / 6.0,
            epsilon = 1e-12
        );
    }

    proptest::proptest! {
        #![proptest_config(proptest::test_runner::Config::with_cases(48))]
        #[test]
        fn proptest_matches_naive(
            period in 1usize..15,
            prices in proptest::collection::vec(-500.0_f64..500.0, 0..120),
        ) {
            let mut wma = Wma::new(period).unwrap();
            let got = wma.batch(&prices);
            let want = wma_naive(&prices, period);
            proptest::prop_assert_eq!(got.len(), want.len());
            for (g, w) in got.iter().zip(want.iter()) {
                match (g, w) {
                    (None, None) => {}
                    (Some(a), Some(b)) => proptest::prop_assert!(
                        (a - b).abs() < 1e-7,
                        "got={a} want={b}"
                    ),
                    _ => proptest::prop_assert!(false, "warmup mismatch"),
                }
            }
        }
    }
}