wickra_core/indicators/
regime_label.rs1use std::collections::VecDeque;
4
5use crate::error::{Error, Result};
6use crate::indicators::rolling_quantile::quantile_sorted;
7use crate::traits::Indicator;
8
/// Streaming volatility-regime classifier over a price series.
///
/// Every accepted price after the first yields a log return. The sample standard deviation
/// of the most recent `vol_period` returns is the current volatility. That value is ranked
/// against the most recent `lookback` volatilities (itself included). The label is:
/// - `-1.0` when it lies strictly below their 0.25 quantile,
/// - `1.0` when it lies strictly above their 0.75 quantile,
/// - `0.0` otherwise.
///
/// Quantiles are computed by `quantile_sorted`.
#[derive(Clone, Debug)]
pub struct RegimeLabel {
    /// Number of log returns that feed one volatility estimate (validated `>= 2`).
    vol_period: usize,
    /// Number of volatility estimates the quartiles are taken over (validated `>= 2`).
    lookback: usize,

    /// Most recently accepted price; `None` until the first valid input arrives.
    prev_price: Option<f64>,
    /// The latest log returns, oldest at the front, holding at most `vol_period` items.
    ret_window: VecDeque<f64>,
    /// Running sum of the values currently in `ret_window`.
    ret_sum: f64,
    /// Running sum of the squares of the values currently in `ret_window`.
    ret_sum_sq: f64,

    /// The latest volatility estimates, oldest at the front, holding at most `lookback` items.
    vol_window: VecDeque<f64>,
    /// Reusable buffer that receives a sorted copy of `vol_window` on each update.
    scratch: Vec<f64>,

    /// Last label emitted; handed back unchanged when an input is ignored.
    last: Option<f64>,
}
62
63impl RegimeLabel {
64 pub fn new(vol_period: usize, lookback: usize) -> Result<Self> {
74 if vol_period < 2 {
75 return Err(Error::InvalidPeriod {
76 message: "regime label needs vol_period >= 2",
77 });
78 }
79 if lookback < 2 {
80 return Err(Error::InvalidPeriod {
81 message: "regime label needs lookback >= 2",
82 });
83 }
84 Ok(Self {
85 vol_period,
86 lookback,
87 prev_price: None,
88 ret_window: VecDeque::with_capacity(vol_period),
89 ret_sum: 0.0,
90 ret_sum_sq: 0.0,
91 vol_window: VecDeque::with_capacity(lookback),
92 scratch: Vec::with_capacity(lookback),
93 last: None,
94 })
95 }
96
97 pub const fn params(&self) -> (usize, usize) {
99 (self.vol_period, self.lookback)
100 }
101}
102
103impl Indicator for RegimeLabel {
104 type Input = f64;
105 type Output = f64;
106
107 fn update(&mut self, input: f64) -> Option<f64> {
108 if !input.is_finite() || input <= 0.0 {
109 return self.last;
110 }
111 let Some(prev) = self.prev_price else {
112 self.prev_price = Some(input);
113 return None;
114 };
115 self.prev_price = Some(input);
116 let r = (input / prev).ln();
117 if self.ret_window.len() == self.vol_period {
119 let old = self.ret_window.pop_front().expect("non-empty");
120 self.ret_sum -= old;
121 self.ret_sum_sq -= old * old;
122 }
123 self.ret_window.push_back(r);
124 self.ret_sum += r;
125 self.ret_sum_sq += r * r;
126 if self.ret_window.len() < self.vol_period {
127 return None;
128 }
129 let n = self.vol_period as f64;
130 let mean = self.ret_sum / n;
131 let var = ((self.ret_sum_sq - n * mean * mean) / (n - 1.0)).max(0.0);
132 let vol = var.sqrt();
133 if self.vol_window.len() == self.lookback {
135 self.vol_window.pop_front();
136 }
137 self.vol_window.push_back(vol);
138 if self.vol_window.len() < self.lookback {
139 return None;
140 }
141 self.scratch.clear();
143 self.scratch.extend(self.vol_window.iter().copied());
144 self.scratch.sort_by(f64::total_cmp);
145 let q1 = quantile_sorted(&self.scratch, 0.25);
146 let q3 = quantile_sorted(&self.scratch, 0.75);
147 let label = if vol < q1 {
148 -1.0
149 } else if vol > q3 {
150 1.0
151 } else {
152 0.0
153 };
154 self.last = Some(label);
155 Some(label)
156 }
157
158 fn reset(&mut self) {
159 self.prev_price = None;
160 self.ret_window.clear();
161 self.ret_sum = 0.0;
162 self.ret_sum_sq = 0.0;
163 self.vol_window.clear();
164 self.scratch.clear();
165 self.last = None;
166 }
167
168 fn warmup_period(&self) -> usize {
169 self.vol_period + self.lookback
172 }
173
174 fn is_ready(&self) -> bool {
175 self.last.is_some()
176 }
177
178 fn name(&self) -> &'static str {
179 "RegimeLabel"
180 }
181}
182
#[cfg(test)]
mod tests {
    use super::*;
    use crate::traits::BatchExt;

    #[test]
    fn rejects_bad_periods() {
        // Each case violates exactly one of the two `>= 2` requirements.
        for (vol_period, lookback) in [(1, 20), (5, 1)] {
            assert!(
                matches!(
                    RegimeLabel::new(vol_period, lookback),
                    Err(Error::InvalidPeriod { .. })
                ),
                "({vol_period}, {lookback}) should be rejected"
            );
        }
    }

    #[test]
    fn accessors_and_metadata() {
        let indicator = RegimeLabel::new(5, 20).unwrap();
        assert_eq!(indicator.params(), (5, 20));
        assert_eq!(indicator.warmup_period(), 25);
        assert_eq!(indicator.name(), "RegimeLabel");
        assert!(!indicator.is_ready());
    }

    #[test]
    fn detects_stressed_regime_on_volatility_spike() {
        let mut indicator = RegimeLabel::new(4, 8).unwrap();
        // A quiet, gently oscillating stretch...
        let mut prices: Vec<f64> = (0..24)
            .map(|i| 100.0 + (f64::from(i) * 0.7).sin() * 0.2)
            .collect();
        // ...followed by violent alternating +8% / -7% moves.
        let start = *prices.last().expect("quiet stretch is non-empty");
        let shock = (0..8).scan(start, |level, i| {
            *level *= if i % 2 == 0 { 1.08 } else { 0.93 };
            Some(*level)
        });
        prices.extend(shock);

        let labels = indicator.batch(&prices);
        let saw_stressed = labels.iter().flatten().any(|&label| label == 1.0);
        assert!(saw_stressed, "expected a stressed (+1) regime label");
    }

    #[test]
    fn detects_calm_regime_after_volatility_drop() {
        let mut indicator = RegimeLabel::new(4, 8).unwrap();
        // A choppy stretch of alternating +5% / -4% moves...
        let mut level = 100.0_f64;
        let mut prices: Vec<f64> = (0..24)
            .map(|i| {
                level *= if i % 2 == 0 { 1.05 } else { 0.96 };
                level
            })
            .collect();
        // ...then a near-flat tail around the final level.
        prices.extend((0..12).map(|i| level + (f64::from(i) * 0.7).sin() * 0.05));

        let labels = indicator.batch(&prices);
        let saw_calm = labels.iter().flatten().any(|&label| label == -1.0);
        assert!(saw_calm, "expected a calm (-1) regime label");
    }

    #[test]
    fn zero_volatility_is_neutral() {
        let mut indicator = RegimeLabel::new(4, 8).unwrap();
        let all_neutral = indicator
            .batch(&[100.0; 40])
            .into_iter()
            .flatten()
            .all(|label| label == 0.0);
        assert!(all_neutral);
    }

    #[test]
    fn output_is_ternary() {
        let mut indicator = RegimeLabel::new(5, 20).unwrap();
        let prices: Vec<f64> = (0..300)
            .map(|i| {
                let t = f64::from(i);
                100.0 + (t * 0.3).sin() * (1.0 + (t * 0.05).sin() * 5.0)
            })
            .collect();
        for v in indicator.batch(&prices).into_iter().flatten() {
            assert!([-1.0, 0.0, 1.0].contains(&v), "non-ternary label {v}");
        }
    }

    #[test]
    fn ignores_non_finite_and_non_positive() {
        let mut indicator = RegimeLabel::new(4, 6).unwrap();
        let prices: Vec<f64> = (0..40)
            .map(|i| 100.0 + (f64::from(i) * 0.5).sin() * 2.0)
            .collect();
        let final_label = *indicator
            .batch(&prices)
            .last()
            .expect("batch output is non-empty");
        assert!(final_label.is_some());
        for junk in [f64::NAN, -1.0, 0.0] {
            assert_eq!(indicator.update(junk), final_label);
        }
    }

    #[test]
    fn reset_clears_state() {
        let mut indicator = RegimeLabel::new(4, 6).unwrap();
        let ramp: Vec<f64> = (1..=40).map(f64::from).collect();
        let _ = indicator.batch(&ramp);
        assert!(indicator.is_ready());

        indicator.reset();
        assert!(!indicator.is_ready());
        assert_eq!(indicator.update(1.0), None);
    }

    #[test]
    fn batch_equals_streaming() {
        let prices: Vec<f64> = (1..=160)
            .map(|i| 100.0 + (f64::from(i) * 0.25).sin() * 4.0)
            .collect();
        let batch = RegimeLabel::new(5, 20).unwrap().batch(&prices);

        let mut streaming = RegimeLabel::new(5, 20).unwrap();
        let mut streamed = Vec::with_capacity(prices.len());
        for &price in &prices {
            streamed.push(streaming.update(price));
        }
        assert_eq!(batch, streamed);
    }
}