Skip to main content

wickra_core/indicators/
trend_label.rs

1//! Trend Label — the sign of the rolling least-squares slope.
2
3use std::collections::VecDeque;
4
5use crate::error::{Error, Result};
6use crate::traits::Indicator;
7
/// Trend Label — a discrete `{−1, 0, +1}` classification of the local trend from
/// the sign of the ordinary-least-squares slope over the last `period` values.
///
/// ```text
/// slope = Σ (tᵢ − t̄)(xᵢ − x̄) / Σ (tᵢ − t̄)²      (regress price on bar index)
/// label = +1 if slope > 0,  −1 if slope < 0,  0 if slope == 0
/// ```
///
/// The sign of the regression slope is *scale-invariant*. Multiplying every value
/// by a positive constant, or shifting them all by the same offset, cannot change
/// it (up to floating-point rounding). That makes it a clean, comparable trend
/// state across instruments.
///
/// `+1` marks a rising regression line, `−1` a falling one, and `0` a flat one.
/// A flat regression line does not require a constant window. Any window that is
/// symmetric about its midpoint, such as `1, 0, 1`, also has zero slope and is
/// labelled `0`.
///
/// It is the discrete companion to [`LinRegSlope`](crate::LinRegSlope), which
/// returns the continuous slope. Use the label when a feature pipeline wants a
/// categorical trend direction. Key any magnitude / dead-band tuning on the raw
/// slope itself.
///
/// Each `update` is `O(period)`: the slope numerator is recomputed from the
/// window. The denominator `Σ(tᵢ − t̄)²` is strictly positive for `period ≥ 2`,
/// so for finite inputs the label is fully determined by the numerator's sign.
///
/// # Example
///
/// ```
/// use wickra_core::{Indicator, TrendLabel};
///
/// let mut indicator = TrendLabel::new(10).unwrap();
/// let mut last = None;
/// for i in 0..20 {
///     last = indicator.update(100.0 + f64::from(i)); // strictly rising
/// }
/// assert_eq!(last, Some(1.0));
///
/// // A symmetric "V" has a flat regression line even though it is not constant.
/// let mut v = TrendLabel::new(3).unwrap();
/// assert_eq!(v.update(1.0), None);
/// assert_eq!(v.update(0.0), None);
/// assert_eq!(v.update(1.0), Some(0.0));
/// ```
#[derive(Debug, Clone)]
pub struct TrendLabel {
    /// Window length; `new` rejects anything below 2.
    period: usize,
    /// The most recent inputs (at most `period`), oldest at the front.
    window: VecDeque<f64>,
}
45
46impl TrendLabel {
47    /// Construct a new Trend Label classifier.
48    ///
49    /// # Errors
50    /// Returns [`Error::InvalidPeriod`] if `period < 2` — a slope needs at least
51    /// two points.
52    pub fn new(period: usize) -> Result<Self> {
53        if period < 2 {
54            return Err(Error::InvalidPeriod {
55                message: "trend label needs period >= 2",
56            });
57        }
58        Ok(Self {
59            period,
60            window: VecDeque::with_capacity(period),
61        })
62    }
63
64    /// Configured period.
65    pub const fn period(&self) -> usize {
66        self.period
67    }
68}
69
70impl Indicator for TrendLabel {
71    type Input = f64;
72    type Output = f64;
73
74    fn update(&mut self, value: f64) -> Option<f64> {
75        if self.window.len() == self.period {
76            self.window.pop_front();
77        }
78        self.window.push_back(value);
79        if self.window.len() < self.period {
80            return None;
81        }
82        let count = self.period as f64;
83        let mean_t = (count - 1.0) / 2.0;
84        let mean_x = self.window.iter().sum::<f64>() / count;
85        // Slope numerator: Σ (t − t̄)(x − x̄). The denominator Σ(t − t̄)² > 0 for
86        // period >= 2, so the slope sign equals the numerator sign.
87        let mut numerator = 0.0;
88        for (t, &x) in self.window.iter().enumerate() {
89            numerator += (t as f64 - mean_t) * (x - mean_x);
90        }
91        let label = if numerator > 0.0 {
92            1.0
93        } else if numerator < 0.0 {
94            -1.0
95        } else {
96            0.0
97        };
98        Some(label)
99    }
100
101    fn reset(&mut self) {
102        self.window.clear();
103    }
104
105    fn warmup_period(&self) -> usize {
106        self.period
107    }
108
109    fn is_ready(&self) -> bool {
110        self.window.len() == self.period
111    }
112
113    fn name(&self) -> &'static str {
114        "TrendLabel"
115    }
116}
117
#[cfg(test)]
mod tests {
    use super::*;
    use crate::traits::BatchExt;

    #[test]
    fn rejects_period_below_two() {
        assert!(matches!(
            TrendLabel::new(1),
            Err(Error::InvalidPeriod { .. })
        ));
        assert!(TrendLabel::new(2).is_ok());
    }

    #[test]
    fn accessors_and_metadata() {
        let tl = TrendLabel::new(10).unwrap();
        assert_eq!(tl.period(), 10);
        assert_eq!(tl.warmup_period(), 10);
        assert_eq!(tl.name(), "TrendLabel");
        assert!(!tl.is_ready());
    }

    #[test]
    fn rising_series_is_plus_one() {
        let mut tl = TrendLabel::new(10).unwrap();
        let prices: Vec<f64> = (0..20).map(f64::from).collect();
        assert_eq!(tl.batch(&prices).into_iter().flatten().last(), Some(1.0));
    }

    #[test]
    fn falling_series_is_minus_one() {
        let mut tl = TrendLabel::new(10).unwrap();
        let prices: Vec<f64> = (0..20).map(|i| 100.0 - f64::from(i)).collect();
        assert_eq!(tl.batch(&prices).into_iter().flatten().last(), Some(-1.0));
    }

    #[test]
    fn flat_series_is_zero() {
        let mut tl = TrendLabel::new(8).unwrap();
        for v in tl.batch(&[42.0; 16]).into_iter().flatten() {
            assert_eq!(v, 0.0);
        }
    }

    #[test]
    fn scale_invariant_sign() {
        // Multiplying the whole series by a constant cannot change the trend sign.
        let prices: Vec<f64> = (0..30)
            .map(|i| 100.0 + (f64::from(i) * 0.4).sin() * 5.0)
            .collect();
        let small = TrendLabel::new(12).unwrap().batch(&prices);
        let scaled: Vec<f64> = prices.iter().map(|p| p * 1000.0).collect();
        let large = TrendLabel::new(12).unwrap().batch(&scaled);
        assert_eq!(small, large);
    }

    #[test]
    fn output_is_ternary() {
        let mut tl = TrendLabel::new(14).unwrap();
        let prices: Vec<f64> = (0..200)
            .map(|i| 100.0 + (f64::from(i) * 0.3).sin() * 10.0)
            .collect();
        for v in tl.batch(&prices).into_iter().flatten() {
            assert!(v == -1.0 || v == 0.0 || v == 1.0, "non-ternary label {v}");
        }
    }

    #[test]
    fn reset_clears_state() {
        let mut tl = TrendLabel::new(5).unwrap();
        tl.batch(&[1.0, 2.0, 3.0, 4.0, 5.0]);
        assert!(tl.is_ready());
        tl.reset();
        assert!(!tl.is_ready());
        assert_eq!(tl.update(1.0), None);
    }

    #[test]
    fn batch_equals_streaming() {
        let prices: Vec<f64> = (0..60)
            .map(|i| 100.0 + (f64::from(i) * 0.3).sin() * 5.0)
            .collect();
        let batch = TrendLabel::new(14).unwrap().batch(&prices);
        let mut b = TrendLabel::new(14).unwrap();
        let streamed: Vec<_> = prices.iter().map(|p| b.update(*p)).collect();
        assert_eq!(batch, streamed);
    }

    #[test]
    fn emits_none_until_window_is_full() {
        // The first `period - 1` updates must not produce a label.
        let mut tl = TrendLabel::new(4).unwrap();
        for v in [1.0, 2.0, 3.0] {
            assert_eq!(tl.update(v), None);
            assert!(!tl.is_ready());
        }
        assert_eq!(tl.update(4.0), Some(1.0));
        assert!(tl.is_ready());
    }

    #[test]
    fn minimal_period_two() {
        // With two points the label is just the sign of the last difference.
        let mut tl = TrendLabel::new(2).unwrap();
        assert_eq!(tl.update(5.0), None);
        assert_eq!(tl.update(3.0), Some(-1.0));
        assert_eq!(tl.update(3.0), Some(0.0));
        assert_eq!(tl.update(7.0), Some(1.0));
    }

    #[test]
    fn oldest_value_is_evicted() {
        let mut tl = TrendLabel::new(3).unwrap();
        assert_eq!(tl.update(1.0), None);
        assert_eq!(tl.update(2.0), None);
        assert_eq!(tl.update(3.0), Some(1.0));
        // Window is now [2, 3, 0]; the leading 1.0 must no longer contribute.
        assert_eq!(tl.update(0.0), Some(-1.0));
    }

    #[test]
    fn symmetric_window_has_zero_slope() {
        // A non-constant window symmetric about its midpoint has a flat
        // regression line, so it is labelled 0.
        let mut tl = TrendLabel::new(3).unwrap();
        assert_eq!(tl.update(1.0), None);
        assert_eq!(tl.update(0.0), None);
        assert_eq!(tl.update(1.0), Some(0.0));
    }
}