Skip to main content

wickra_core/indicators/
trend_strength_index.rs

//! Trend Strength Index — the signed coefficient of determination (`r^2`) of a
//! least-squares linear regression of price against bar index.
3
4use std::collections::VecDeque;
5
6use crate::error::{Error, Result};
7use crate::traits::Indicator;
8
/// Trend Strength Index (TSI): how well a straight line explains the most
/// recent prices, signed by the direction of that line.
///
/// Over the latest `period` closes, an ordinary-least-squares line is fitted
/// with the bar offset `x = 0, 1, …, period - 1` as the regressor and the close
/// as the response. The indicator is that fit's coefficient of determination,
/// carrying the sign of the fitted slope:
///
/// ```text
/// r^2  = (n·Σxy − Σx·Σy)^2 / [ (n·Σx² − (Σx)²)(n·Σy² − (Σy)²) ]
/// TSI  = sign(slope) · r^2          (slope sign = sign of n·Σxy − Σx·Σy)
/// ```
///
/// On its own, `r^2 ∈ [0, 1]` only says how *linear* — how trend-like — the
/// window is, whatever the direction. Attaching the slope sign makes the
/// reading directional and maps it onto `[-1, 1]`:
///
/// * close to `+1` — a strong, clean advance;
/// * close to `-1` — a strong, clean decline;
/// * close to `0` — sideways or noisy action with no linear structure.
///
/// When every price in the window is identical, `y` has zero variance, no
/// trend is defined, and the indicator reports `0`.
///
/// # Example
///
/// ```
/// use wickra_core::{Indicator, TrendStrengthIndex};
///
/// let mut tsi = TrendStrengthIndex::new(20).unwrap();
/// let mut latest = None;
/// for step in 0..40 {
///     latest = tsi.update(100.0 + f64::from(step));
/// }
/// // A straight ramp is fitted perfectly, so the reading is +1.
/// assert!((latest.unwrap() - 1.0).abs() < 1e-9);
/// ```
#[derive(Debug, Clone)]
pub struct TrendStrengthIndex {
    /// Number of most-recent prices included in each regression.
    period: usize,
    /// Rolling window of the latest prices, oldest at the front.
    buf: VecDeque<f64>,
}
44
45impl TrendStrengthIndex {
46    /// Construct a Trend Strength Index over the given window.
47    ///
48    /// # Errors
49    ///
50    /// Returns [`Error::PeriodZero`] if `period == 0`, or [`Error::InvalidPeriod`]
51    /// if `period == 1` (a regression needs at least two points).
52    pub fn new(period: usize) -> Result<Self> {
53        if period == 0 {
54            return Err(Error::PeriodZero);
55        }
56        if period == 1 {
57            return Err(Error::InvalidPeriod {
58                message: "period must be >= 2 for a regression",
59            });
60        }
61        Ok(Self {
62            period,
63            buf: VecDeque::with_capacity(period),
64        })
65    }
66
67    /// Configured window length.
68    pub const fn period(&self) -> usize {
69        self.period
70    }
71}
72
73impl Indicator for TrendStrengthIndex {
74    type Input = f64;
75    type Output = f64;
76
77    fn update(&mut self, price: f64) -> Option<f64> {
78        self.buf.push_back(price);
79        if self.buf.len() > self.period {
80            self.buf.pop_front();
81        }
82        if self.buf.len() < self.period {
83            return None;
84        }
85
86        let count = self.period as f64;
87        let mut sum_x = 0.0;
88        let mut sum_xx = 0.0;
89        let mut sum_y = 0.0;
90        let mut sum_yy = 0.0;
91        let mut sum_xy = 0.0;
92        for (idx, &price) in self.buf.iter().enumerate() {
93            let x = idx as f64;
94            sum_x += x;
95            sum_xx += x * x;
96            sum_y += price;
97            sum_yy += price * price;
98            sum_xy += x * price;
99        }
100
101        let cov = count.mul_add(sum_xy, -(sum_x * sum_y));
102        let var_x = count.mul_add(sum_xx, -(sum_x * sum_x));
103        let var_y = count.mul_add(sum_yy, -(sum_y * sum_y));
104        if var_y <= 0.0 {
105            return Some(0.0);
106        }
107        let r2 = (cov * cov) / (var_x * var_y);
108        Some(if cov >= 0.0 { r2 } else { -r2 })
109    }
110
111    fn reset(&mut self) {
112        self.buf.clear();
113    }
114
115    fn warmup_period(&self) -> usize {
116        self.period
117    }
118
119    fn is_ready(&self) -> bool {
120        self.buf.len() >= self.period
121    }
122
123    fn name(&self) -> &'static str {
124        "TrendStrengthIndex"
125    }
126}
127
128#[cfg(test)]
129mod tests {
130    use super::*;
131    use crate::traits::BatchExt;
132    use approx::assert_relative_eq;
133
134    #[test]
135    fn rejects_invalid_period() {
136        assert!(matches!(TrendStrengthIndex::new(0), Err(Error::PeriodZero)));
137        assert!(matches!(
138            TrendStrengthIndex::new(1),
139            Err(Error::InvalidPeriod { .. })
140        ));
141    }
142
143    #[test]
144    fn accessors_and_metadata() {
145        let tsi = TrendStrengthIndex::new(20).unwrap();
146        assert_eq!(tsi.period(), 20);
147        assert_eq!(tsi.warmup_period(), 20);
148        assert_eq!(tsi.name(), "TrendStrengthIndex");
149        assert!(!tsi.is_ready());
150    }
151
152    #[test]
153    fn warmup_emits_at_period() {
154        let mut tsi = TrendStrengthIndex::new(4).unwrap();
155        let inputs: Vec<f64> = (0..6).map(f64::from).collect();
156        let out = tsi.batch(&inputs);
157        assert!(out[2].is_none());
158        assert!(out[3].is_some());
159    }
160
161    #[test]
162    fn perfect_uptrend_is_plus_one() {
163        let mut tsi = TrendStrengthIndex::new(10).unwrap();
164        let inputs: Vec<f64> = (0..10).map(f64::from).collect();
165        let last = tsi.batch(&inputs).last().unwrap().unwrap();
166        assert_relative_eq!(last, 1.0, epsilon = 1e-9);
167    }
168
169    #[test]
170    fn perfect_downtrend_is_minus_one() {
171        let mut tsi = TrendStrengthIndex::new(10).unwrap();
172        let inputs: Vec<f64> = (0..10).map(|i| 100.0 - f64::from(i)).collect();
173        let last = tsi.batch(&inputs).last().unwrap().unwrap();
174        assert_relative_eq!(last, -1.0, epsilon = 1e-9);
175    }
176
177    #[test]
178    fn flat_market_returns_zero() {
179        let mut tsi = TrendStrengthIndex::new(8).unwrap();
180        let inputs = [42.0; 12];
181        let last = tsi.batch(&inputs).last().unwrap().unwrap();
182        assert_relative_eq!(last, 0.0, epsilon = 1e-12);
183    }
184
185    #[test]
186    fn noisy_trend_is_between() {
187        // An upward drift with noise: positive but not a perfect fit.
188        let mut tsi = TrendStrengthIndex::new(12).unwrap();
189        let inputs: Vec<f64> = (0..12)
190            .map(|i| f64::from(i) + if i % 2 == 0 { 0.0 } else { 3.0 })
191            .collect();
192        let last = tsi.batch(&inputs).last().unwrap().unwrap();
193        assert!(last > 0.0 && last < 1.0, "tsi {last} should be in (0, 1)");
194    }
195
196    #[test]
197    fn reset_clears_state() {
198        let mut tsi = TrendStrengthIndex::new(10).unwrap();
199        let inputs: Vec<f64> = (0..10).map(f64::from).collect();
200        tsi.batch(&inputs);
201        assert!(tsi.is_ready());
202        tsi.reset();
203        assert!(!tsi.is_ready());
204    }
205
206    #[test]
207    fn batch_equals_streaming() {
208        let inputs: Vec<f64> = (0..80)
209            .map(|i| 100.0 + (f64::from(i) * 0.2).sin() * 5.0)
210            .collect();
211        let mut a = TrendStrengthIndex::new(15).unwrap();
212        let mut b = TrendStrengthIndex::new(15).unwrap();
213        assert_eq!(
214            a.batch(&inputs),
215            inputs.iter().map(|x| b.update(*x)).collect::<Vec<_>>()
216        );
217    }
218}