wickra_core/indicators/
trend_label.rs1use std::collections::VecDeque;
4
5use crate::error::{Error, Result};
6use crate::traits::Indicator;
7
8#[derive(Debug, Clone)]
41pub struct TrendLabel {
42 period: usize,
43 window: VecDeque<f64>,
44}
45
46impl TrendLabel {
47 pub fn new(period: usize) -> Result<Self> {
53 if period < 2 {
54 return Err(Error::InvalidPeriod {
55 message: "trend label needs period >= 2",
56 });
57 }
58 Ok(Self {
59 period,
60 window: VecDeque::with_capacity(period),
61 })
62 }
63
64 pub const fn period(&self) -> usize {
66 self.period
67 }
68}
69
70impl Indicator for TrendLabel {
71 type Input = f64;
72 type Output = f64;
73
74 fn update(&mut self, value: f64) -> Option<f64> {
75 if self.window.len() == self.period {
76 self.window.pop_front();
77 }
78 self.window.push_back(value);
79 if self.window.len() < self.period {
80 return None;
81 }
82 let count = self.period as f64;
83 let mean_t = (count - 1.0) / 2.0;
84 let mean_x = self.window.iter().sum::<f64>() / count;
85 let mut numerator = 0.0;
88 for (t, &x) in self.window.iter().enumerate() {
89 numerator += (t as f64 - mean_t) * (x - mean_x);
90 }
91 let label = if numerator > 0.0 {
92 1.0
93 } else if numerator < 0.0 {
94 -1.0
95 } else {
96 0.0
97 };
98 Some(label)
99 }
100
101 fn reset(&mut self) {
102 self.window.clear();
103 }
104
105 fn warmup_period(&self) -> usize {
106 self.period
107 }
108
109 fn is_ready(&self) -> bool {
110 self.window.len() == self.period
111 }
112
113 fn name(&self) -> &'static str {
114 "TrendLabel"
115 }
116}
117
#[cfg(test)]
mod tests {
    use super::*;
    use crate::traits::BatchExt;

    /// Feeds `prices` through a fresh labeller and keeps only the emitted labels
    /// (warmup `None`s are dropped).
    fn labels(period: usize, prices: &[f64]) -> Vec<f64> {
        TrendLabel::new(period)
            .unwrap()
            .batch(prices)
            .into_iter()
            .flatten()
            .collect()
    }

    #[test]
    fn rejects_period_below_two() {
        assert!(matches!(
            TrendLabel::new(0),
            Err(Error::InvalidPeriod { .. })
        ));
        assert!(matches!(
            TrendLabel::new(1),
            Err(Error::InvalidPeriod { .. })
        ));
        assert!(TrendLabel::new(2).is_ok());
    }

    #[test]
    fn accessors_and_metadata() {
        let tl = TrendLabel::new(10).unwrap();
        assert_eq!(tl.period(), 10);
        assert_eq!(tl.warmup_period(), 10);
        assert_eq!(tl.name(), "TrendLabel");
        assert!(!tl.is_ready());
    }

    #[test]
    fn warmup_emits_none_until_window_is_full() {
        let mut tl = TrendLabel::new(4).unwrap();
        assert_eq!(
            tl.batch(&[1.0, 2.0, 3.0, 4.0, 5.0]),
            vec![None, None, None, Some(1.0), Some(1.0)]
        );
    }

    #[test]
    fn rising_series_is_plus_one() {
        let prices: Vec<f64> = (0..20).map(f64::from).collect();
        let out = labels(10, &prices);
        // 20 inputs with a 10-wide window emit 11 labels; every window rises.
        assert_eq!(out.len(), 11);
        assert!(out.iter().all(|&v| v == 1.0));
    }

    #[test]
    fn falling_series_is_minus_one() {
        let prices: Vec<f64> = (0..20).map(|i| 100.0 - f64::from(i)).collect();
        let out = labels(10, &prices);
        assert_eq!(out.len(), 11);
        assert!(out.iter().all(|&v| v == -1.0));
    }

    #[test]
    fn flat_series_is_zero() {
        let out = labels(8, &[42.0; 16]);
        // Guard against a vacuous pass: 16 inputs, window 8 -> 9 labels.
        assert_eq!(out.len(), 9);
        assert!(out.iter().all(|&v| v == 0.0));
    }

    #[test]
    fn mirror_symmetric_window_is_zero() {
        assert_eq!(labels(3, &[1.0, 2.0, 1.0]), vec![0.0]);
        assert_eq!(labels(5, &[1.0, 3.0, 2.0, 3.0, 1.0]), vec![0.0]);
    }

    #[test]
    fn scale_invariant_sign() {
        let prices: Vec<f64> = (0..30)
            .map(|i| 100.0 + (f64::from(i) * 0.4).sin() * 5.0)
            .collect();
        let small = TrendLabel::new(12).unwrap().batch(&prices);
        let scaled: Vec<f64> = prices.iter().map(|p| p * 1000.0).collect();
        let large = TrendLabel::new(12).unwrap().batch(&scaled);
        assert_eq!(small, large);
    }

    #[test]
    fn output_is_ternary() {
        let prices: Vec<f64> = (0..200)
            .map(|i| 100.0 + (f64::from(i) * 0.3).sin() * 10.0)
            .collect();
        let out = labels(14, &prices);
        // Guard against a vacuous pass: 200 inputs, window 14 -> 187 labels.
        assert_eq!(out.len(), 187);
        for v in out {
            assert!(v == -1.0 || v == 0.0 || v == 1.0, "non-ternary label {v}");
        }
    }

    #[test]
    fn reset_clears_state() {
        let mut tl = TrendLabel::new(5).unwrap();
        tl.batch(&[1.0, 2.0, 3.0, 4.0, 5.0]);
        assert!(tl.is_ready());
        tl.reset();
        assert!(!tl.is_ready());
        assert_eq!(tl.update(1.0), None);
    }

    #[test]
    fn batch_equals_streaming() {
        let prices: Vec<f64> = (0..60)
            .map(|i| 100.0 + (f64::from(i) * 0.3).sin() * 5.0)
            .collect();
        let batch = TrendLabel::new(14).unwrap().batch(&prices);
        let mut b = TrendLabel::new(14).unwrap();
        let streamed: Vec<_> = prices.iter().map(|p| b.update(*p)).collect();
        assert_eq!(batch, streamed);
    }
}