Skip to main content

wickra_core/indicators/
trend_label.rs

1//! Trend Label — the sign of the rolling least-squares slope.
2
3use std::collections::VecDeque;
4
5use crate::error::{Error, Result};
6use crate::traits::Indicator;
7
/// Trend Label — a discrete `{−1, 0, +1}` classification of the local trend from
/// the sign of the ordinary-least-squares slope over the last `period` values.
///
/// ```text
/// slope = Σ (tᵢ − t̄)(xᵢ − x̄) / Σ (tᵢ − t̄)²      (regress price on bar index)
/// label = +1 if slope > 0,  −1 if slope < 0,  0 if slope == 0
/// ```
///
/// The sign of the regression slope does not depend on the nominal price level.
/// In exact arithmetic it is unchanged when every value in the window is shifted
/// by the same constant, or multiplied by the same *positive* constant (a negative
/// factor flips it). That makes it a clean, comparable trend state across
/// instruments.
///
/// `+1` marks a rising regression line and `−1` a falling one. `0` marks a line
/// with exactly zero slope: a perfectly flat window, or a window that is symmetric
/// in time such as `[1, 5, 1]`.
///
/// It is the discrete companion to [`LinRegSlope`](crate::LinRegSlope), which
/// returns the continuous slope. Use the label when a feature pipeline wants a
/// categorical trend direction, and key any magnitude / dead-band tuning on the
/// raw slope itself.
///
/// # Streaming behaviour
///
/// * `update` returns `None` until `period` finite values have been seen; after
///   that every finite input yields `Some(label)`.
/// * Non-finite inputs (`NaN`, `±∞`) are ignored. They yield `None` and leave the
///   window untouched, so they neither advance nor restart the warm-up.
/// * Each `update` is `O(period)`: the slope numerator is recomputed from the
///   window. The denominator `Σ(tᵢ − t̄)²` is strictly positive for `period ≥ 2`,
///   so the sign is always well-defined.
///
/// # Example
///
/// ```
/// use wickra_core::{Indicator, TrendLabel};
///
/// let mut indicator = TrendLabel::new(10).unwrap();
/// let mut last = None;
/// for i in 0..20 {
///     last = indicator.update(100.0 + f64::from(i)); // strictly rising
/// }
/// assert_eq!(last, Some(1.0));
/// ```
#[derive(Debug, Clone)]
pub struct TrendLabel {
    /// Number of values in the regression window (validated to be `>= 2`).
    period: usize,
    /// The most recent finite inputs, oldest at the front; never longer than `period`.
    window: VecDeque<f64>,
}
45
46impl TrendLabel {
47    /// Construct a new Trend Label classifier.
48    ///
49    /// # Errors
50    /// Returns [`Error::InvalidPeriod`] if `period < 2` — a slope needs at least
51    /// two points.
52    pub fn new(period: usize) -> Result<Self> {
53        if period < 2 {
54            return Err(Error::InvalidPeriod {
55                message: "trend label needs period >= 2",
56            });
57        }
58        Ok(Self {
59            period,
60            window: VecDeque::with_capacity(period),
61        })
62    }
63
64    /// Configured period.
65    pub const fn period(&self) -> usize {
66        self.period
67    }
68}
69
70impl Indicator for TrendLabel {
71    type Input = f64;
72    type Output = f64;
73
74    fn update(&mut self, value: f64) -> Option<f64> {
75        if !value.is_finite() {
76            return None;
77        }
78        if self.window.len() == self.period {
79            self.window.pop_front();
80        }
81        self.window.push_back(value);
82        if self.window.len() < self.period {
83            return None;
84        }
85        let count = self.period as f64;
86        let mean_t = (count - 1.0) / 2.0;
87        let mean_x = self.window.iter().sum::<f64>() / count;
88        // Slope numerator: Σ (t − t̄)(x − x̄). The denominator Σ(t − t̄)² > 0 for
89        // period >= 2, so the slope sign equals the numerator sign.
90        let mut numerator = 0.0;
91        for (t, &x) in self.window.iter().enumerate() {
92            numerator += (t as f64 - mean_t) * (x - mean_x);
93        }
94        let label = if numerator > 0.0 {
95            1.0
96        } else if numerator < 0.0 {
97            -1.0
98        } else {
99            0.0
100        };
101        Some(label)
102    }
103
104    fn reset(&mut self) {
105        self.window.clear();
106    }
107
108    fn warmup_period(&self) -> usize {
109        self.period
110    }
111
112    fn is_ready(&self) -> bool {
113        self.window.len() == self.period
114    }
115
116    fn name(&self) -> &'static str {
117        "TrendLabel"
118    }
119}
120
#[cfg(test)]
mod tests {
    use super::*;
    use crate::traits::BatchExt;

    #[test]
    fn rejects_period_below_two() {
        assert!(matches!(
            TrendLabel::new(1),
            Err(Error::InvalidPeriod { .. })
        ));
        assert!(TrendLabel::new(2).is_ok());
    }

    #[test]
    fn accessors_and_metadata() {
        let tl = TrendLabel::new(10).unwrap();
        assert_eq!(tl.period(), 10);
        assert_eq!(tl.warmup_period(), 10);
        assert_eq!(tl.name(), "TrendLabel");
        assert!(!tl.is_ready());
    }

    #[test]
    fn rising_series_is_plus_one() {
        let mut tl = TrendLabel::new(10).unwrap();
        let prices: Vec<f64> = (0..20).map(f64::from).collect();
        assert_eq!(tl.batch(&prices).into_iter().flatten().last(), Some(1.0));
    }

    #[test]
    fn falling_series_is_minus_one() {
        let mut tl = TrendLabel::new(10).unwrap();
        let prices: Vec<f64> = (0..20).map(|i| 100.0 - f64::from(i)).collect();
        assert_eq!(tl.batch(&prices).into_iter().flatten().last(), Some(-1.0));
    }

    #[test]
    fn flat_series_is_zero() {
        let mut tl = TrendLabel::new(8).unwrap();
        for v in tl.batch(&[42.0; 16]).into_iter().flatten() {
            assert_eq!(v, 0.0);
        }
    }

    #[test]
    fn scale_invariant_sign() {
        // Multiplying the whole series by a *positive* constant cannot change the
        // trend sign (a negative factor would flip it).
        let prices: Vec<f64> = (0..30)
            .map(|i| 100.0 + (f64::from(i) * 0.4).sin() * 5.0)
            .collect();
        let small = TrendLabel::new(12).unwrap().batch(&prices);
        let scaled: Vec<f64> = prices.iter().map(|p| p * 1000.0).collect();
        let large = TrendLabel::new(12).unwrap().batch(&scaled);
        assert_eq!(small, large);
    }

    #[test]
    fn shift_invariant_sign() {
        // Adding a constant to every value leaves the regression slope unchanged.
        // Small integers with period 4 keep the arithmetic exact, so the
        // comparison (including the zero-slope windows) is not at the mercy of
        // rounding.
        let prices = [5.0, 3.0, 3.0, 5.0, 7.0, 7.0, 7.0, 7.0, 2.0, 9.0, 4.0, 4.0];
        let shifted: Vec<f64> = prices.iter().map(|p| p + 1000.0).collect();
        let base = TrendLabel::new(4).unwrap().batch(&prices);
        let moved = TrendLabel::new(4).unwrap().batch(&shifted);
        assert_eq!(base, moved);
    }

    #[test]
    fn output_is_ternary() {
        let mut tl = TrendLabel::new(14).unwrap();
        let prices: Vec<f64> = (0..200)
            .map(|i| 100.0 + (f64::from(i) * 0.3).sin() * 10.0)
            .collect();
        for v in tl.batch(&prices).into_iter().flatten() {
            assert!(v == -1.0 || v == 0.0 || v == 1.0, "non-ternary label {v}");
        }
    }

    #[test]
    fn warmup_emits_none_then_labels() {
        let mut tl = TrendLabel::new(4).unwrap();
        let outputs: Vec<_> = [3.0, 2.0, 1.0, 0.0, 5.0]
            .iter()
            .map(|&p| tl.update(p))
            .collect();
        assert_eq!(outputs, vec![None, None, None, Some(-1.0), Some(1.0)]);
    }

    #[test]
    fn non_finite_inputs_are_skipped() {
        // NaN / ±∞ return None and neither enter the window nor restart warm-up.
        let mut tl = TrendLabel::new(3).unwrap();
        assert_eq!(tl.update(1.0), None);
        assert_eq!(tl.update(f64::NAN), None);
        assert_eq!(tl.update(2.0), None);
        assert!(!tl.is_ready());
        assert_eq!(tl.update(f64::INFINITY), None);
        assert_eq!(tl.update(3.0), Some(1.0));
        assert_eq!(tl.update(f64::NEG_INFINITY), None);
        assert!(tl.is_ready());
    }

    #[test]
    fn reset_clears_state() {
        let mut tl = TrendLabel::new(5).unwrap();
        tl.batch(&[1.0, 2.0, 3.0, 4.0, 5.0]);
        assert!(tl.is_ready());
        tl.reset();
        assert!(!tl.is_ready());
        assert_eq!(tl.update(1.0), None);
    }

    #[test]
    fn batch_equals_streaming() {
        let prices: Vec<f64> = (0..60)
            .map(|i| 100.0 + (f64::from(i) * 0.3).sin() * 5.0)
            .collect();
        let batch = TrendLabel::new(14).unwrap().batch(&prices);
        let mut b = TrendLabel::new(14).unwrap();
        let streamed: Vec<_> = prices.iter().map(|p| b.update(*p)).collect();
        assert_eq!(batch, streamed);
    }
}