Skip to main content

wickra_core/indicators/
td_moving_average.rs

1#![allow(clippy::doc_markdown)]
2
3//! Tom DeMark TD Moving Averages — the ST1 (fast) and ST2 (slow) trend ribbon.
4
5use crate::error::{Error, Result};
6use crate::indicators::sma::Sma;
7use crate::ohlcv::Candle;
8use crate::traits::Indicator;
9
/// One reading of the [`TdMovingAverage`] ribbon: the fast line `st1` paired
/// with the slow line `st2`, both produced from the same input bar.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct TdMovingAverageOutput {
    /// ST1: the short-period (fast) simple moving average of the median price.
    pub st1: f64,
    /// ST2: the long-period (slow) simple moving average of the median price.
    pub st2: f64,
}
19
20/// Tom DeMark **TD Moving Averages** — a two-line trend ribbon (ST1 fast, ST2
21/// slow) computed on the median price, whose relationship defines the trend.
22///
23/// ```text
24/// price = (high + low) / 2          (median price)
25/// st1   = SMA(price, period_st1)    (fast / "Sequential Trend 1")
26/// st2   = SMA(price, period_st2)    (slow / "Sequential Trend 2")
27/// ```
28///
29/// DeMark's moving-average pair frames the trend objectively: when `st1` is above
30/// `st2` the trend is up, below it down, and the cross marks the change. Using the
31/// **median price** rather than the close de-emphasises closing noise. This is a
32/// streaming dual-SMA implementation of the ST1/ST2 ribbon; read the lines and
33/// their crossover exactly as a fast/slow moving-average system.
34///
35/// `period_st1` must be strictly smaller than `period_st2`. The first value lands
36/// once the slow average is seeded (`period_st2` inputs). Each `update` is O(1).
37///
38/// # Example
39///
40/// ```
41/// use wickra_core::{Candle, Indicator, TdMovingAverage};
42///
43/// let mut indicator = TdMovingAverage::new(5, 13).unwrap();
44/// let mut last = None;
45/// for i in 0..40 {
46///     let base = 100.0 + f64::from(i);
47///     let c = Candle::new(base, base + 1.0, base - 1.0, base, 1_000.0, 0).unwrap();
48///     last = indicator.update(c);
49/// }
50/// assert!(last.is_some());
51/// ```
52#[derive(Debug, Clone)]
53pub struct TdMovingAverage {
54    st1: Sma,
55    st2: Sma,
56    period_st1: usize,
57    period_st2: usize,
58    last: Option<TdMovingAverageOutput>,
59}
60
61impl TdMovingAverage {
62    /// Construct TD Moving Averages with the given fast and slow periods.
63    ///
64    /// # Errors
65    ///
66    /// Returns [`Error::PeriodZero`] if either period is `0`, and
67    /// [`Error::InvalidPeriod`] if `period_st1 >= period_st2`.
68    pub fn new(period_st1: usize, period_st2: usize) -> Result<Self> {
69        if period_st1 == 0 || period_st2 == 0 {
70            return Err(Error::PeriodZero);
71        }
72        if period_st1 >= period_st2 {
73            return Err(Error::InvalidPeriod {
74                message: "TD moving average ST1 period must be strictly less than ST2",
75            });
76        }
77        Ok(Self {
78            st1: Sma::new(period_st1)?,
79            st2: Sma::new(period_st2)?,
80            period_st1,
81            period_st2,
82            last: None,
83        })
84    }
85
86    /// Configured `(period_st1, period_st2)`.
87    pub const fn periods(&self) -> (usize, usize) {
88        (self.period_st1, self.period_st2)
89    }
90
91    /// Current value if available.
92    pub const fn value(&self) -> Option<TdMovingAverageOutput> {
93        self.last
94    }
95}
96
97impl Indicator for TdMovingAverage {
98    type Input = Candle;
99    type Output = TdMovingAverageOutput;
100
101    fn update(&mut self, candle: Candle) -> Option<TdMovingAverageOutput> {
102        let price = candle.median_price();
103        let fast = self.st1.update(price);
104        let slow = self.st2.update(price);
105        if let (Some(st1), Some(st2)) = (fast, slow) {
106            let out = TdMovingAverageOutput { st1, st2 };
107            self.last = Some(out);
108            return Some(out);
109        }
110        None
111    }
112
113    fn reset(&mut self) {
114        self.st1.reset();
115        self.st2.reset();
116        self.last = None;
117    }
118
119    fn warmup_period(&self) -> usize {
120        self.period_st2
121    }
122
123    fn is_ready(&self) -> bool {
124        self.last.is_some()
125    }
126
127    fn name(&self) -> &'static str {
128        "TDMovingAverage"
129    }
130}
131
#[cfg(test)]
mod tests {
    use super::*;
    use crate::traits::BatchExt;
    use approx::assert_relative_eq;

    /// Candle whose median price `(high + low) / 2` equals `median`.
    fn c(median: f64) -> Candle {
        Candle::new_unchecked(median, median + 1.0, median - 1.0, median, 1_000.0, 0)
    }

    #[test]
    fn rejects_invalid_periods() {
        assert!(matches!(
            TdMovingAverage::new(0, 13),
            Err(Error::PeriodZero)
        ));
        // A zero slow window must also be reported as PeriodZero, not as an
        // ordering error.
        assert!(matches!(
            TdMovingAverage::new(5, 0),
            Err(Error::PeriodZero)
        ));
        assert!(matches!(
            TdMovingAverage::new(0, 0),
            Err(Error::PeriodZero)
        ));
        assert!(matches!(
            TdMovingAverage::new(13, 5),
            Err(Error::InvalidPeriod { .. })
        ));
        assert!(matches!(
            TdMovingAverage::new(5, 5),
            Err(Error::InvalidPeriod { .. })
        ));
    }

    #[test]
    fn accessors_and_metadata() {
        let td = TdMovingAverage::new(5, 13).unwrap();
        assert_eq!(td.periods(), (5, 13));
        assert_eq!(td.warmup_period(), 13);
        assert_eq!(td.name(), "TDMovingAverage");
        assert!(!td.is_ready());
        assert_eq!(td.value(), None);
    }

    #[test]
    fn first_emission_at_warmup_period() {
        let mut td = TdMovingAverage::new(2, 4).unwrap();
        let candles: Vec<Candle> = (0..8).map(|i| c(100.0 + f64::from(i))).collect();
        let out = td.batch(&candles);
        for v in out.iter().take(3) {
            assert!(v.is_none());
        }
        assert!(out[3].is_some());
    }

    #[test]
    fn emits_exact_sma_values_of_median_price() {
        let mut td = TdMovingAverage::new(2, 4).unwrap();
        let out = td.batch(&[c(100.0), c(101.0), c(102.0), c(103.0), c(104.0)]);
        assert!(out.iter().take(3).all(Option::is_none));

        // Bar 4: st1 = (102 + 103) / 2, st2 = (100 + 101 + 102 + 103) / 4.
        let first = out[3].expect("slow SMA is seeded on the 4th bar");
        assert_relative_eq!(first.st1, 102.5, epsilon = 1e-9);
        assert_relative_eq!(first.st2, 101.5, epsilon = 1e-9);

        // Bar 5: both windows slide forward by one bar.
        let second = out[4].expect("output after warmup");
        assert_relative_eq!(second.st1, 103.5, epsilon = 1e-9);
        assert_relative_eq!(second.st2, 102.5, epsilon = 1e-9);

        // The cached value mirrors the most recent emission.
        assert!(td.is_ready());
        assert_eq!(td.value(), Some(second));
    }

    #[test]
    fn fast_leads_slow_in_uptrend() {
        let mut td = TdMovingAverage::new(3, 7).unwrap();
        let candles: Vec<Candle> = (0..40).map(|i| c(100.0 + f64::from(i))).collect();
        let out = td.batch(&candles).into_iter().flatten().last().unwrap();
        assert!(out.st1 > out.st2, "fast MA should lead in an uptrend");
    }

    #[test]
    fn fast_below_slow_in_downtrend() {
        let mut td = TdMovingAverage::new(3, 7).unwrap();
        let candles: Vec<Candle> = (0..40).map(|i| c(200.0 - f64::from(i))).collect();
        let out = td.batch(&candles).into_iter().flatten().last().unwrap();
        assert!(out.st1 < out.st2, "fast MA should trail in a downtrend");
    }

    #[test]
    fn flat_series_equal_lines() {
        let mut td = TdMovingAverage::new(2, 4).unwrap();
        let out = td
            .batch(&[c(50.0); 10])
            .into_iter()
            .flatten()
            .last()
            .unwrap();
        assert_relative_eq!(out.st1, 50.0, epsilon = 1e-9);
        assert_relative_eq!(out.st2, 50.0, epsilon = 1e-9);
    }

    #[test]
    fn reset_clears_state() {
        let mut td = TdMovingAverage::new(2, 4).unwrap();
        td.batch(&(0..10).map(|i| c(100.0 + f64::from(i))).collect::<Vec<_>>());
        assert!(td.is_ready());
        td.reset();
        assert!(!td.is_ready());
        assert_eq!(td.value(), None);
        assert_eq!(td.update(c(100.0)), None);
    }

    #[test]
    fn reset_then_replay_matches_first_pass() {
        let candles: Vec<Candle> = (0..12).map(|i| c(100.0 + f64::from(i) * 1.5)).collect();
        let mut td = TdMovingAverage::new(3, 5).unwrap();
        let first_pass = td.batch(&candles);
        td.reset();
        let second_pass = td.batch(&candles);
        assert_eq!(first_pass, second_pass);
    }

    #[test]
    fn batch_equals_streaming() {
        let candles: Vec<Candle> = (0..80)
            .map(|i| c(100.0 + (f64::from(i) * 0.25).sin() * 9.0))
            .collect();
        let batch = TdMovingAverage::new(5, 13).unwrap().batch(&candles);
        let mut b = TdMovingAverage::new(5, 13).unwrap();
        let streamed: Vec<_> = candles.iter().map(|x| b.update(*x)).collect();
        assert_eq!(batch, streamed);
    }
}