ux_indicators/indicators/
sma.rs

1#![allow(dead_code)]
2#![allow(unused_imports)]
3
4use std::fmt;
5
6use crate::errors::*;
7
8use std::collections::HashMap;
9use std::rc::Rc;
10use std::cell::RefCell;
11
12use crate::errors::{ErrorKind, Error};
13use crate::{Close, Next, Reset};
14use crate::{Slot, SlotPtr, SlotType, Frame};
15
16use crate::{Factory};
17
/// Factory that produces boxed [`SimpleMovingAverage`] indicators through the
/// [`Factory`] trait.
///
/// The factory carries no state, so `SmaFactory::new()` and
/// `SmaFactory::default()` are equivalent.
#[derive(Debug, Default)]
pub struct SmaFactory {}

impl SmaFactory {
    /// Creates a new (stateless) SMA factory.
    pub fn new() -> Self {
        Self {}
    }
}
26
27impl Factory for SmaFactory {
28    fn create() -> Box<dyn Next<f64, Output = Box<[f64]>>> {
29        Box::new(SimpleMovingAverage::default())
30    }
31}
32
33/// Simple moving average (SMA).
34///
35/// # Formula
36///
37/// ![SMA](https://wikimedia.org/api/rest_v1/media/math/render/svg/e2bf09dc6deaf86b3607040585fac6078f9c7c89)
38///
39/// Where:
40///
41/// * _SMA<sub>t</sub>_ - value of simple moving average at a point of time _t_
42/// * _n_ - number of periods (length)
43/// * _p<sub>t</sub>_ - input value at a point of time _t_
44///
45/// # Parameters
46///
47/// * _n_ - number of periods (integer greater than 0)
48///
49/// # Example
50///
51/// ```
52/// // use core::indicators::SimpleMovingAverage;
53/// // use core::Next;
54///
55/// // let mut sma = SimpleMovingAverage::new(3).unwrap();
56/// // assert_eq!(sma.next(10.0), 10.0);
57/// // assert_eq!(sma.next(11.0), 10.5);
58/// // assert_eq!(sma.next(12.0), 11.0);
59/// // assert_eq!(sma.next(13.0), 12.0);
60/// ```
61///
62/// # Links
63///
64/// * [Simple Moving Average, Wikipedia](https://en.wikipedia.org/wiki/Moving_average#Simple_moving_average)
65///
66
67#[derive(Debug, Clone)]
68pub struct SimpleMovingAverage<'a> {
69    period: u32,
70    index: usize,
71    count: u32,
72    sum: f64,
73    vec: Vec<f64>,
74    pub inputs: HashMap<&'a str, Rc<RefCell<Slot>>>,
75    pub outputs: HashMap<&'a str, Rc<RefCell<Slot>>>,
76}
77 
78impl<'a> SimpleMovingAverage<'a> {
79    pub fn new(period: u32) -> Result<Self> {
80        match period {
81            // 0 => Err(Error::from_kind(ErrorKind::InvalidParameter)),
82            _ => {
83                let indicator = Self {
84                    period: period,
85                    index: 0,
86                    count: 0,
87                    sum: 0.0,
88                    vec: vec![0.0; period as usize],
89                    inputs: [("input", Rc::new(RefCell::new(Slot::new(SlotType::Input))))].iter().cloned().collect(),
90                    outputs: [("output", Rc::new(RefCell::new(Slot::new(SlotType::Output))))].iter().cloned().collect(),
91                };
92
93                Ok(indicator)
94            }
95        }
96    }
97    
98    // fn slot(&mut self, name: &str) -> Option<&mut Rc<RefCell<Slot>>> {
99    //     self.inputs.get_mut(name)
100    // }
101
102    fn slot(&mut self, name: &str) -> Option<&Rc<RefCell<Slot>>> {
103        self.inputs.get(name)
104    }
105
106    fn process(&mut self) {
107        match self.inputs.get("input") {
108            None => println!("INPUTS OOPS {:?}", self),
109            Some(slot) => {
110                // println!("HERE {:?}", slot.borrow_mut());
111
112                let input = slot.borrow_mut().get();
113                
114                self.index = (self.index + 1) % (self.period as usize);
115                let old_val = self.vec[self.index];
116                
117                self.vec[self.index] = input;
118
119                // fill counter upto period
120                if self.count < self.period {
121                    self.count += 1;
122                }
123
124                // sliding window
125                self.sum = self.sum - old_val + input;
126                let output = self.sum / (self.count as f64);
127                match self.outputs.get("output") {
128                    None => println!("OUTPUT OOPS {:?}", self),
129                    Some(slot) => {
130                        slot.borrow_mut().put(output);
131                    }
132                }
133            }
134        }
135    }
136}
137
138impl<'a> Next<f64> for SimpleMovingAverage<'a> {
139    type Output = Box<[f64]>;
140    
141    fn next(&mut self, input: f64) -> Self::Output {
142        self.index = (self.index + 1) % (self.period as usize);
143
144        let old_val = self.vec[self.index];
145        self.vec[self.index] = input;
146
147        // fill counter upto period
148        if self.count < self.period {
149            self.count += 1;
150        }
151
152        // sliding window
153        self.sum = self.sum - old_val + input;
154        Box::new([self.sum / (self.count as f64)])
155    }
156}
157
158// impl<'a, T: Close> Next<&'a T> for SimpleMovingAverage {
159//     type Output = f64;
160
161//     fn next(&mut self, input: &'a T) -> Self::Output {
162//         self.next(input.close())
163//     }
164// }
165
166impl<'a> Reset for SimpleMovingAverage<'a> {
167    
168    fn reset(&mut self) {
169        self.index = 0;
170        self.count = 0;
171        self.sum = 0.0;
172        for idx in 0..(self.period as usize) {
173            self.vec[idx] = 0.0;
174        }
175    }
176}
177
178impl<'a> Default for SimpleMovingAverage<'a> {
179    fn default() -> Self {
180        Self::new(9).unwrap()
181    }
182}
183
184impl<'a> fmt::Display for SimpleMovingAverage<'a> {
185    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
186        write!(f, "SMA({})", self.period)
187    }
188}
189
190// fn run_test<T>(test: T) -> ()
191//     where T: FnOnce() -> () + panic::UnwindSafe
192// {
193//     setup();
194
195//     let result = panic::catch_unwind(|| {
196//         test()
197//     });
198
199//     teardown();
200
201//     assert!(result.is_ok())
202// }
203
204// #[test]
205// fn test() {
206//     run_test(|| {
207//         let ret_value = function_under_test();
208//         assert!(ret_value);
209//     })
210// }
211
212// #[test]
213// fn test_something_interesting() {
214//     run_test(|| {
215//         let true_or_false = do_the_test();
216
217//         assert!(true_or_false);
218//     })
219// }
220// fn run_test<T>(test: T) -> ()
221//     where T: FnOnce() -> () + panic::UnwindSafe
222// {
223//     setup();
224
225//     let result = panic::catch_unwind(|| {
226//         test()
227//     });
228
229//     teardown();
230
231//     assert!(result.is_ok())
232// }
233
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_helper::*;
    use crate::{Slot, SlotPtr, SlotType, Frame};

    /// Drives the indicator through its slot interface: writes `val` to the
    /// `"input"` slot, runs `process`, and reads back the `"output"` slot.
    fn step(sma: &mut SimpleMovingAverage, val: f64) -> f64 {
        let input = sma.slot("input").unwrap();
        input.borrow_mut().put(val);
        sma.process();
        let output = sma.outputs.get("output").unwrap();
        output.borrow_mut().get()
    }

    #[test]
    fn test_new() {
        initialize();
        assert!(SimpleMovingAverage::new(1).is_ok());
    }

    #[test]
    fn test_process() {
        initialize();

        let mut sma = SimpleMovingAverage::new(4).unwrap();
        assert_eq!(step(&mut sma, 4.0), 4.0);
        assert_eq!(step(&mut sma, 5.0), 4.5);
        assert_eq!(step(&mut sma, 6.0), 5.0);
        assert_eq!(step(&mut sma, 6.0), 5.25);
        assert_eq!(step(&mut sma, 6.0), 5.75);
        assert_eq!(step(&mut sma, 6.0), 6.0);
        assert_eq!(step(&mut sma, 2.0), 5.0);
    }

    #[test]
    fn test_next() {
        initialize();

        let mut sma = SimpleMovingAverage::new(4).unwrap();
        assert_eq!(sma.next(4.0)[0], 4.0);
        assert_eq!(sma.next(5.0)[0], 4.5);
        assert_eq!(sma.next(6.0)[0], 5.0);
        assert_eq!(sma.next(6.0)[0], 5.25);
        assert_eq!(sma.next(6.0)[0], 5.75);
        assert_eq!(sma.next(6.0)[0], 6.0);
        assert_eq!(sma.next(2.0)[0], 5.0);
    }

    #[test]
    fn test_next_returns_single_value() {
        initialize();

        let mut sma = SimpleMovingAverage::new(3).unwrap();
        assert_eq!(sma.next(1.0).len(), 1);
    }

    #[test]
    fn test_next_period_one() {
        initialize();

        let mut sma = SimpleMovingAverage::new(1).unwrap();
        assert_eq!(sma.next(3.0)[0], 3.0);
        assert_eq!(sma.next(7.0)[0], 7.0);
        assert_eq!(sma.next(-2.0)[0], -2.0);
    }

    #[test]
    fn test_next_matches_process() {
        initialize();

        let mut via_next = SimpleMovingAverage::new(3).unwrap();
        let mut via_slots = SimpleMovingAverage::new(3).unwrap();
        for &val in &[1.0, 2.5, -3.0, 4.0, 10.0, 0.5] {
            assert_eq!(via_next.next(val)[0], step(&mut via_slots, val));
        }
    }

    #[test]
    fn test_reset() {
        initialize();

        let mut sma = SimpleMovingAverage::new(4).unwrap();
        assert_eq!(sma.next(4.0)[0], 4.0);
        assert_eq!(sma.next(5.0)[0], 4.5);
        assert_eq!(sma.next(6.0)[0], 5.0);

        sma.reset();
        assert_eq!(sma.next(99.0)[0], 99.0);
    }

    #[test]
    fn test_slot_lookup_is_input_only() {
        initialize();

        let sma = SimpleMovingAverage::new(2).unwrap();
        assert!(sma.slot("input").is_some());
        assert!(sma.slot("output").is_none());
        assert!(sma.slot("missing").is_none());
    }

    #[test]
    fn test_default() {
        initialize();

        let sma = SimpleMovingAverage::default();
        assert_eq!(format!("{}", sma), "SMA(9)");
    }

    #[test]
    fn test_display() {
        initialize();

        let sma = SimpleMovingAverage::new(5).unwrap();
        assert_eq!(format!("{}", sma), "SMA(5)");
    }
}