wickra_core/indicators/
td_moving_average.rs1#![allow(clippy::doc_markdown)]
2
3use crate::error::{Error, Result};
6use crate::indicators::sma::Sma;
7use crate::ohlcv::Candle;
8use crate::traits::Indicator;
9
/// One emission of [`TdMovingAverage`]: the fast and slow averages of the
/// candles' median price, both evaluated at the same bar.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct TdMovingAverageOutput {
    /// Fast average (ST1), taken over the shorter `period_st1` window.
    pub st1: f64,
    /// Slow average (ST2), taken over the longer `period_st2` window.
    pub st2: f64,
}
19
20#[derive(Debug, Clone)]
53pub struct TdMovingAverage {
54 st1: Sma,
55 st2: Sma,
56 period_st1: usize,
57 period_st2: usize,
58 last: Option<TdMovingAverageOutput>,
59}
60
61impl TdMovingAverage {
62 pub fn new(period_st1: usize, period_st2: usize) -> Result<Self> {
69 if period_st1 == 0 || period_st2 == 0 {
70 return Err(Error::PeriodZero);
71 }
72 if period_st1 >= period_st2 {
73 return Err(Error::InvalidPeriod {
74 message: "TD moving average ST1 period must be strictly less than ST2",
75 });
76 }
77 Ok(Self {
78 st1: Sma::new(period_st1)?,
79 st2: Sma::new(period_st2)?,
80 period_st1,
81 period_st2,
82 last: None,
83 })
84 }
85
86 pub const fn periods(&self) -> (usize, usize) {
88 (self.period_st1, self.period_st2)
89 }
90
91 pub const fn value(&self) -> Option<TdMovingAverageOutput> {
93 self.last
94 }
95}
96
97impl Indicator for TdMovingAverage {
98 type Input = Candle;
99 type Output = TdMovingAverageOutput;
100
101 fn update(&mut self, candle: Candle) -> Option<TdMovingAverageOutput> {
102 let price = candle.median_price();
103 let fast = self.st1.update(price);
104 let slow = self.st2.update(price);
105 if let (Some(st1), Some(st2)) = (fast, slow) {
106 let out = TdMovingAverageOutput { st1, st2 };
107 self.last = Some(out);
108 return Some(out);
109 }
110 None
111 }
112
113 fn reset(&mut self) {
114 self.st1.reset();
115 self.st2.reset();
116 self.last = None;
117 }
118
119 fn warmup_period(&self) -> usize {
120 self.period_st2
121 }
122
123 fn is_ready(&self) -> bool {
124 self.last.is_some()
125 }
126
127 fn name(&self) -> &'static str {
128 "TDMovingAverage"
129 }
130}
131
#[cfg(test)]
mod tests {
    use super::*;
    use crate::traits::BatchExt;
    use approx::assert_relative_eq;

    /// Candle built entirely from `median`, with high/low at `median ± 1`.
    /// Its median price is therefore `median`.
    fn c(median: f64) -> Candle {
        Candle::new_unchecked(median, median + 1.0, median - 1.0, median, 1_000.0, 0)
    }

    #[test]
    fn rejects_invalid_periods() {
        assert!(matches!(
            TdMovingAverage::new(0, 13),
            Err(Error::PeriodZero)
        ));
        assert!(matches!(
            TdMovingAverage::new(13, 5),
            Err(Error::InvalidPeriod { .. })
        ));
        assert!(matches!(
            TdMovingAverage::new(5, 5),
            Err(Error::InvalidPeriod { .. })
        ));
    }

    #[test]
    fn accessors_and_metadata() {
        let td = TdMovingAverage::new(5, 13).unwrap();
        assert_eq!(td.periods(), (5, 13));
        assert_eq!(td.warmup_period(), 13);
        assert_eq!(td.name(), "TDMovingAverage");
        assert!(!td.is_ready());
        assert_eq!(td.value(), None);
    }

    #[test]
    fn first_emission_at_warmup_period() {
        let mut td = TdMovingAverage::new(2, 4).unwrap();
        let candles: Vec<Candle> = (0..8).map(|i| c(100.0 + f64::from(i))).collect();
        let out = td.batch(&candles);
        for v in out.iter().take(3) {
            assert!(v.is_none());
        }
        assert!(out[3].is_some());
    }

    #[test]
    fn values_are_sma_of_median_price() {
        // High/low are skewed around the close, so a bar's median price need
        // not equal its close. Expected values come from `median_price()`
        // itself, so this pins what the indicator averages without restating
        // the formula.
        let candles: Vec<Candle> = (0..6)
            .map(|i| {
                let close = 100.0 + f64::from(i);
                Candle::new_unchecked(close, close + 4.0, close - 2.0, close, 1_000.0, 0)
            })
            .collect();
        let medians: Vec<f64> = candles.iter().map(|k| k.median_price()).collect();

        let mut td = TdMovingAverage::new(2, 4).unwrap();
        let out = td.batch(&candles);
        let last = out[5].expect("warm-up of 4 bars is complete by the 6th bar");

        assert_relative_eq!(last.st1, (medians[4] + medians[5]) / 2.0, epsilon = 1e-9);
        assert_relative_eq!(
            last.st2,
            medians[2..].iter().sum::<f64>() / 4.0,
            epsilon = 1e-9
        );
        // The cached value mirrors the latest emission.
        assert_eq!(td.value(), Some(last));
    }

    #[test]
    fn fast_leads_slow_in_uptrend() {
        let mut td = TdMovingAverage::new(3, 7).unwrap();
        let candles: Vec<Candle> = (0..40).map(|i| c(100.0 + f64::from(i))).collect();
        let out = td.batch(&candles).into_iter().flatten().last().unwrap();
        assert!(out.st1 > out.st2, "fast MA should lead in an uptrend");
    }

    #[test]
    fn fast_below_slow_in_downtrend() {
        let mut td = TdMovingAverage::new(3, 7).unwrap();
        let candles: Vec<Candle> = (0..40).map(|i| c(200.0 - f64::from(i))).collect();
        let out = td.batch(&candles).into_iter().flatten().last().unwrap();
        assert!(out.st1 < out.st2, "fast MA should trail in a downtrend");
    }

    #[test]
    fn flat_series_equal_lines() {
        let mut td = TdMovingAverage::new(2, 4).unwrap();
        let out = td
            .batch(&[c(50.0); 10])
            .into_iter()
            .flatten()
            .last()
            .unwrap();
        assert_relative_eq!(out.st1, 50.0, epsilon = 1e-9);
        assert_relative_eq!(out.st2, 50.0, epsilon = 1e-9);
    }

    #[test]
    fn reset_clears_state() {
        let mut td = TdMovingAverage::new(2, 4).unwrap();
        td.batch(&(0..10).map(|i| c(100.0 + f64::from(i))).collect::<Vec<_>>());
        assert!(td.is_ready());
        td.reset();
        assert!(!td.is_ready());
        assert_eq!(td.value(), None);
        assert_eq!(td.update(c(100.0)), None);
    }

    #[test]
    fn batch_equals_streaming() {
        let candles: Vec<Candle> = (0..80)
            .map(|i| c(100.0 + (f64::from(i) * 0.25).sin() * 9.0))
            .collect();
        let batch = TdMovingAverage::new(5, 13).unwrap().batch(&candles);
        let mut b = TdMovingAverage::new(5, 13).unwrap();
        let streamed: Vec<_> = candles.iter().map(|x| b.update(*x)).collect();
        assert_eq!(batch, streamed);
    }
}