wickra_core/indicators/
trend_label.rs1use std::collections::VecDeque;
4
5use crate::error::{Error, Result};
6use crate::traits::Indicator;
7
/// Rolling trend classifier.
///
/// Each full window of `period` samples is labelled rising (`1.0`), falling
/// (`-1.0`) or flat (`0.0`) according to the sign of the least-squares slope
/// of the samples against their position in the window.
#[derive(Clone, Debug)]
pub struct TrendLabel {
    /// Window length; `new` rejects anything below 2.
    period: usize,
    /// Most recent finite samples, oldest at the front.
    window: VecDeque<f64>,
}
45
46impl TrendLabel {
47 pub fn new(period: usize) -> Result<Self> {
53 if period < 2 {
54 return Err(Error::InvalidPeriod {
55 message: "trend label needs period >= 2",
56 });
57 }
58 Ok(Self {
59 period,
60 window: VecDeque::with_capacity(period),
61 })
62 }
63
64 pub const fn period(&self) -> usize {
66 self.period
67 }
68}
69
70impl Indicator for TrendLabel {
71 type Input = f64;
72 type Output = f64;
73
74 fn update(&mut self, value: f64) -> Option<f64> {
75 if !value.is_finite() {
76 return None;
77 }
78 if self.window.len() == self.period {
79 self.window.pop_front();
80 }
81 self.window.push_back(value);
82 if self.window.len() < self.period {
83 return None;
84 }
85 let count = self.period as f64;
86 let mean_t = (count - 1.0) / 2.0;
87 let mean_x = self.window.iter().sum::<f64>() / count;
88 let mut numerator = 0.0;
91 for (t, &x) in self.window.iter().enumerate() {
92 numerator += (t as f64 - mean_t) * (x - mean_x);
93 }
94 let label = if numerator > 0.0 {
95 1.0
96 } else if numerator < 0.0 {
97 -1.0
98 } else {
99 0.0
100 };
101 Some(label)
102 }
103
104 fn reset(&mut self) {
105 self.window.clear();
106 }
107
108 fn warmup_period(&self) -> usize {
109 self.period
110 }
111
112 fn is_ready(&self) -> bool {
113 self.window.len() == self.period
114 }
115
116 fn name(&self) -> &'static str {
117 "TrendLabel"
118 }
119}
120
#[cfg(test)]
mod tests {
    use super::*;
    use crate::traits::BatchExt;

    /// Sine-wave price path: `100 + sin(i * step) * amp` for `i` in `0..len`.
    fn wave(len: i32, step: f64, amp: f64) -> Vec<f64> {
        (0..len)
            .map(|i| 100.0 + (f64::from(i) * step).sin() * amp)
            .collect()
    }

    /// Runs `prices` through a fresh labeller and returns the last label emitted.
    fn final_label(period: usize, prices: &[f64]) -> Option<f64> {
        TrendLabel::new(period)
            .unwrap()
            .batch(prices)
            .into_iter()
            .flatten()
            .last()
    }

    #[test]
    fn rejects_period_below_two() {
        let too_short = TrendLabel::new(1);
        assert!(matches!(too_short, Err(Error::InvalidPeriod { .. })));
        assert!(TrendLabel::new(2).is_ok());
    }

    #[test]
    fn accessors_and_metadata() {
        let labeller = TrendLabel::new(10).unwrap();
        assert!(!labeller.is_ready());
        assert_eq!(labeller.name(), "TrendLabel");
        assert_eq!(labeller.warmup_period(), 10);
        assert_eq!(labeller.period(), 10);
    }

    #[test]
    fn rising_series_is_plus_one() {
        let rising: Vec<f64> = (0..20).map(f64::from).collect();
        assert_eq!(final_label(10, &rising), Some(1.0));
    }

    #[test]
    fn falling_series_is_minus_one() {
        let falling: Vec<f64> = (0..20).map(|i| 100.0 - f64::from(i)).collect();
        assert_eq!(final_label(10, &falling), Some(-1.0));
    }

    #[test]
    fn flat_series_is_zero() {
        let labels = TrendLabel::new(8).unwrap().batch(&[42.0; 16]);
        assert!(labels.into_iter().flatten().all(|label| label == 0.0));
    }

    #[test]
    fn scale_invariant_sign() {
        let prices = wave(30, 0.4, 5.0);
        let scaled: Vec<f64> = prices.iter().map(|p| p * 1000.0).collect();
        let small = TrendLabel::new(12).unwrap().batch(&prices);
        let large = TrendLabel::new(12).unwrap().batch(&scaled);
        assert_eq!(small, large);
    }

    #[test]
    fn output_is_ternary() {
        let prices = wave(200, 0.3, 10.0);
        let labels = TrendLabel::new(14).unwrap().batch(&prices);
        for label in labels.into_iter().flatten() {
            assert!([-1.0, 0.0, 1.0].contains(&label), "non-ternary label {label}");
        }
    }

    #[test]
    fn reset_clears_state() {
        let mut labeller = TrendLabel::new(5).unwrap();
        labeller.batch(&[1.0, 2.0, 3.0, 4.0, 5.0]);
        assert!(labeller.is_ready());
        labeller.reset();
        assert!(!labeller.is_ready());
        assert_eq!(labeller.update(1.0), None);
    }

    #[test]
    fn batch_equals_streaming() {
        let prices = wave(60, 0.3, 5.0);
        let batched = TrendLabel::new(14).unwrap().batch(&prices);
        let mut streaming = TrendLabel::new(14).unwrap();
        let mut streamed = Vec::with_capacity(prices.len());
        for &price in &prices {
            streamed.push(streaming.update(price));
        }
        assert_eq!(batched, streamed);
    }
}