Skip to main content

wickra_core/indicators/
trend_strength_index.rs

1//! Trend Strength Index — the signed coefficient of determination of a linear
2//! regression of price against time.
3
4use std::collections::VecDeque;
5
6use crate::error::{Error, Result};
7use crate::traits::Indicator;
8
/// Trend Strength Index (TSI).
///
/// Over a rolling window of the last `period` prices, fits an ordinary
/// least-squares line using the bar index as the regressor and reports the
/// fit's coefficient of determination `r^2`, carrying the sign of the slope.
///
/// ```text
/// x   = 0, 1, ..., period-1          y = close prices in the window
/// r^2 = (n·Σxy − Σx·Σy)^2 / [ (n·Σx² − (Σx)²) · (n·Σy² − (Σy)²) ]
/// TSI = sign(slope) · r^2            sign(slope) = sign(n·Σxy − Σx·Σy)
/// ```
///
/// On its own, `r^2` in `[0, 1]` says how much of the window's price movement
/// a straight line accounts for: how trend-like the segment is, independent
/// of direction. Attaching the slope's sign maps it onto `[-1, 1]`:
///
/// - close to `+1`: a strong, clean uptrend;
/// - close to `-1`: a strong, clean downtrend;
/// - close to `0`: a flat or choppy market with no linear structure.
///
/// When every price in the window is the same (`y` has zero variance) the
/// trend is undefined and the indicator reports `0`.
///
/// # Example
///
/// ```
/// use wickra_core::{Indicator, TrendStrengthIndex};
///
/// let mut tsi = TrendStrengthIndex::new(20).unwrap();
/// let last = (0..40)
///     .map(|i| tsi.update(100.0 + f64::from(i)))
///     .last()
///     .flatten();
/// // A straight ramp is fitted perfectly, so r^2 = 1 with a positive slope.
/// assert!((last.unwrap() - 1.0).abs() < 1e-9);
/// ```
#[derive(Debug, Clone)]
pub struct TrendStrengthIndex {
    /// Number of prices in the regression window (`new` only accepts values >= 2).
    period: usize,
    /// The most recent prices, oldest at the front, holding at most `period` entries.
    buf: VecDeque<f64>,
}
44
45impl TrendStrengthIndex {
46    /// Construct a Trend Strength Index over the given window.
47    ///
48    /// # Errors
49    ///
50    /// Returns [`Error::PeriodZero`] if `period == 0`, or [`Error::InvalidPeriod`]
51    /// if `period == 1` (a regression needs at least two points).
52    pub fn new(period: usize) -> Result<Self> {
53        if period == 0 {
54            return Err(Error::PeriodZero);
55        }
56        if period == 1 {
57            return Err(Error::InvalidPeriod {
58                message: "period must be >= 2 for a regression",
59            });
60        }
61        Ok(Self {
62            period,
63            buf: VecDeque::with_capacity(period),
64        })
65    }
66
67    /// Configured window length.
68    pub const fn period(&self) -> usize {
69        self.period
70    }
71}
72
73impl Indicator for TrendStrengthIndex {
74    type Input = f64;
75    type Output = f64;
76
77    fn update(&mut self, price: f64) -> Option<f64> {
78        if !price.is_finite() {
79            return None;
80        }
81        self.buf.push_back(price);
82        if self.buf.len() > self.period {
83            self.buf.pop_front();
84        }
85        if self.buf.len() < self.period {
86            return None;
87        }
88
89        let count = self.period as f64;
90        let mut sum_x = 0.0;
91        let mut sum_xx = 0.0;
92        let mut sum_y = 0.0;
93        let mut sum_yy = 0.0;
94        let mut sum_xy = 0.0;
95        for (idx, &price) in self.buf.iter().enumerate() {
96            let x = idx as f64;
97            sum_x += x;
98            sum_xx += x * x;
99            sum_y += price;
100            sum_yy += price * price;
101            sum_xy += x * price;
102        }
103
104        let cov = count.mul_add(sum_xy, -(sum_x * sum_y));
105        let var_x = count.mul_add(sum_xx, -(sum_x * sum_x));
106        let var_y = count.mul_add(sum_yy, -(sum_y * sum_y));
107        if var_y <= 0.0 {
108            return Some(0.0);
109        }
110        let r2 = (cov * cov) / (var_x * var_y);
111        Some(if cov >= 0.0 { r2 } else { -r2 })
112    }
113
114    fn reset(&mut self) {
115        self.buf.clear();
116    }
117
118    fn warmup_period(&self) -> usize {
119        self.period
120    }
121
122    fn is_ready(&self) -> bool {
123        self.buf.len() >= self.period
124    }
125
126    fn name(&self) -> &'static str {
127        "TrendStrengthIndex"
128    }
129}
130
#[cfg(test)]
mod tests {
    use super::*;
    use crate::traits::BatchExt;
    use approx::assert_relative_eq;

    /// Feeds `inputs` through a fresh indicator of length `period` and returns
    /// the final output, panicking if the indicator never warmed up.
    fn final_output(period: usize, inputs: &[f64]) -> f64 {
        let mut tsi = TrendStrengthIndex::new(period).expect("period is valid");
        tsi.batch(inputs)
            .last()
            .copied()
            .flatten()
            .expect("window should be full")
    }

    /// The integers `0..len` as `f64`s.
    fn ramp(len: i32) -> Vec<f64> {
        (0..len).map(f64::from).collect()
    }

    #[test]
    fn rejects_invalid_period() {
        let zero = TrendStrengthIndex::new(0);
        assert!(matches!(zero, Err(Error::PeriodZero)));
        let one = TrendStrengthIndex::new(1);
        assert!(matches!(one, Err(Error::InvalidPeriod { .. })));
    }

    #[test]
    fn accessors_and_metadata() {
        let tsi = TrendStrengthIndex::new(20).unwrap();
        assert_eq!(tsi.name(), "TrendStrengthIndex");
        assert_eq!((tsi.period(), tsi.warmup_period()), (20, 20));
        assert!(!tsi.is_ready());
    }

    #[test]
    fn warmup_emits_at_period() {
        let mut tsi = TrendStrengthIndex::new(4).unwrap();
        let out = tsi.batch(&ramp(6));
        assert!(out[2].is_none());
        assert!(out[3].is_some());
    }

    #[test]
    fn perfect_uptrend_is_plus_one() {
        assert_relative_eq!(final_output(10, &ramp(10)), 1.0, epsilon = 1e-9);
    }

    #[test]
    fn perfect_downtrend_is_minus_one() {
        let falling: Vec<f64> = ramp(10).iter().map(|v| 100.0 - v).collect();
        assert_relative_eq!(final_output(10, &falling), -1.0, epsilon = 1e-9);
    }

    #[test]
    fn flat_market_returns_zero() {
        assert_relative_eq!(final_output(8, &[42.0; 12]), 0.0, epsilon = 1e-12);
    }

    #[test]
    fn noisy_trend_is_between() {
        // Upward drift plus a +3 bump on every odd bar: clearly positive, but
        // a straight line no longer fits perfectly.
        let inputs: Vec<f64> = (0..12)
            .map(|i: i32| f64::from(i) + if i % 2 == 1 { 3.0 } else { 0.0 })
            .collect();
        let last = final_output(12, &inputs);
        assert!(last > 0.0 && last < 1.0, "tsi {last} should be in (0, 1)");
    }

    #[test]
    fn reset_clears_state() {
        let mut tsi = TrendStrengthIndex::new(10).unwrap();
        let _ = tsi.batch(&ramp(10));
        assert!(tsi.is_ready());
        tsi.reset();
        assert!(!tsi.is_ready());
    }

    #[test]
    fn batch_equals_streaming() {
        let inputs: Vec<f64> = (0..80)
            .map(|i| 100.0 + (f64::from(i) * 0.2).sin() * 5.0)
            .collect();
        let mut batch_tsi = TrendStrengthIndex::new(15).unwrap();
        let batched = batch_tsi.batch(&inputs);
        let mut stream_tsi = TrendStrengthIndex::new(15).unwrap();
        let streamed: Vec<Option<f64>> = inputs.iter().map(|&x| stream_tsi.update(x)).collect();
        assert_eq!(batched, streamed);
    }
}