ux_indicators/indicators/
wma.rs

1#![allow(dead_code)]
2#![allow(unused_imports)]
3
4use std::fmt;
5
6use rust_decimal_macros::*;
7use rust_decimal::prelude::*;
8use std::collections::VecDeque;
9
10use crate::errors::{ErrorKind, Error};
11use crate::{Close, Next, Reset};
12
13use crate::{Factory};
14use crate::indicators::SimpleMovingAverage;
15
/// Field-less factory type, named for the weighted moving average.
///
/// What it actually builds is decided by its `Factory` implementation.
pub struct WmaFactory {}

impl WmaFactory {
    /// Returns a new factory; there is no state to initialise.
    pub fn new() -> Self {
        WmaFactory {}
    }
}
24
25impl Factory for WmaFactory {
26    fn create() -> Box<dyn Next<f64, Output = Box<[f64]>>> {
27        Box::new(SimpleMovingAverage::default())
28    }
29}
30
31/// Weighted moving average (SMA).
32///
33/// # Formula
34///
35/// ![WMA](https://wikimedia.org/api/rest_v1/media/math/render/svg/e2bf09dc6deaf86b3607040585fac6078f9c7c89)
36///
37/// Where:
38///
39/// * _WMA<sub>t</sub>_ - value of simple moving average at a point of time _t_
40/// * _period_ - number of periods (length)
41/// * _p<sub>t</sub>_ - input value at a point of time _t_
42///
43/// # Parameters
44///
45/// * _period_ - number of periods (integer greater than 0)
46///
47/// # Example
48///
49/// ```
50/// use core::indicators::WeightedMovingAverage;
51/// use core::Next;
52///
53/// let mut wma = WeightedMovingAverage::new(3).unwrap();
54/// assert_eq!(wma.next(10.0), f64::INFINITY);
55/// ```
56///
57/// # Links
58///
59/// * [Weighted Moving Average, Wikipedia](https://en.wikipedia.org/wiki/Moving_average#Simple_moving_average)
60///
61
62#[derive(Debug, Clone)]
63pub struct WeightedMovingAverage {
64    period: u32,
65    index: usize,
66    count: u32,
67    sum: Decimal, /* Flat sum of previous numbers. */
68    weight_sum: Decimal, /* Weighted sum of previous numbers. */
69    vec: VecDeque<f64>,
70}
71
72impl WeightedMovingAverage {
73    // pub fn new(period: u32) -> Result<WeightedMovingAverage, Error> {
74    pub fn new(period: u32) -> Result<WeightedMovingAverage, Error> {
75        match period {
76            // 0 => Err(Error::from_kind(ErrorKind::InvalidParameter)),
77            _ => {
78                let indicator = WeightedMovingAverage {
79                    period: period,
80                    index: 0,
81                    count: 0,
82                    sum: Decimal::zero(),
83                    weight_sum: Decimal::zero(),
84                    vec: VecDeque::with_capacity(period as usize),
85                };
86                Ok(indicator)
87            }
88        }
89    }
90}
91
92impl Next<f64> for WeightedMovingAverage {
93    type Output = f64;
94    
95    fn next(&mut self, input: f64) -> Self::Output {
96        // self.index = (self.index + 1) % (self.period as usize);
97
98        // let old_val = self.vec[self.index];
99        // self.vec[self.index] = input;
100
101        // fill windoe upto period
102        if self.count < self.period - 1 {
103            self.count += 1;
104            self.vec.push_back(input);
105            self.weight_sum += Decimal::from_f64(input).unwrap() * Decimal::from_u32(self.count).unwrap();
106            self.sum += Decimal::from_f64(input).unwrap();
107            return f64::INFINITY;
108        }
109
110        let weights: Decimal = Decimal::from_u32(self.period).unwrap() * (Decimal::from_u32(self.period).unwrap() + Decimal::from_u32(1).unwrap()) / Decimal::from_u32(2).unwrap();
111
112        // sliding window
113        self.weight_sum += Decimal::from_f64(input).unwrap() * Decimal::from_u32(self.period).unwrap();
114        self.sum += Decimal::from_f64(input).unwrap();
115
116        let output: Decimal = self.weight_sum / weights;
117
118        self.vec.push_back(input);
119        self.weight_sum -= self.sum;
120        self.sum -= Decimal::from_f64(self.vec.pop_front().unwrap()).unwrap();
121        
122        output.round_dp_with_strategy(3, RoundingStrategy::RoundHalfUp).to_f64().unwrap()
123    }
124}
125
126impl<'a, T: Close> Next<&'a T> for WeightedMovingAverage {
127    type Output = f64;
128
129    fn next(&mut self, input: &'a T) -> Self::Output {
130        self.next(input.close())
131    }
132}
133
134impl Reset for WeightedMovingAverage {
135    
136    fn reset(&mut self) {
137        self.index = 0;
138        self.count = 0;
139        self.sum = Decimal::from_f64(0.0).unwrap();
140        for idx in 0..(self.period as usize) {
141            self.vec[idx] = 0.0;
142        }
143    }
144}
145
146impl Default for WeightedMovingAverage {
147    fn default() -> Self {
148        Self::new(9).unwrap()
149    }
150}
151
152impl fmt::Display for WeightedMovingAverage {
153    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
154        write!(f, "WMA({})", self.period)
155    }
156}
157
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_helper::*;

    // test_indicator!(WeightedMovingAverage);

    /// Asserts two outputs agree far inside the 3-decimal rounding `next` applies.
    /// Exact `==` is avoided because the Decimal -> f64 conversion may be off by an ulp.
    fn assert_close(actual: f64, expected: f64) {
        assert!(
            (actual - expected).abs() < 1e-9,
            "expected {}, got {}",
            expected,
            actual
        );
    }

    #[test]
    fn test_new() {
        initialize();
        // assert!(WeightedMovingAverage::new(0).is_err());
        assert!(WeightedMovingAverage::new(1).is_ok());
    }

    /// Runs the shared fixture data and prints the outputs for inspection.
    #[test]
    fn test_next() {
        initialize();
        let _params = &TESTS["wma"];
        println!();
        for _test in _params.tests.iter() {
            let _period = _test.options[0];
            let _input = &_test.inputs[0];
            let _output = &_test.outputs[0];

            println!("WMA WITH PERIOD {}", _period);
            let mut indicator = WeightedMovingAverage::new(_period as u32).unwrap();
            for val in _input.iter() {
                let res = indicator.next(*val);
                println!("INPUT {} OUTPUT {}", val, res);
            }
        }
    }

    /// WMA(3) weights are 1, 2, 3 (oldest to newest), divided by 1 + 2 + 3 = 6.
    #[test]
    fn test_next_hand_computed() {
        let mut indicator = WeightedMovingAverage::new(3).unwrap();
        assert_eq!(indicator.next(1.0), f64::INFINITY);
        assert_eq!(indicator.next(2.0), f64::INFINITY);
        assert_close(indicator.next(3.0), 2.333); // (1 + 4 + 9) / 6
        assert_close(indicator.next(4.0), 3.333); // (2 + 6 + 12) / 6
        assert_close(indicator.next(5.0), 4.333); // (3 + 8 + 15) / 6
        assert_close(indicator.next(8.0), 6.333); // (4 + 10 + 24) / 6
    }

    /// A one-value window has a single weight of 1, so it echoes its input.
    #[test]
    fn test_next_period_one_is_identity() {
        let mut indicator = WeightedMovingAverage::new(1).unwrap();
        assert_close(indicator.next(4.0), 4.0);
        assert_close(indicator.next(7.5), 7.5);
        assert_close(indicator.next(-2.25), -2.25);
    }

    #[test]
    fn test_default() {
        WeightedMovingAverage::default();
    }

    #[test]
    fn test_display() {
        let indicator = WeightedMovingAverage::new(5).unwrap();
        assert_eq!(format!("{}", indicator), "WMA(5)");
    }
}