Skip to main content

wickra_core/indicators/
regime_label.rs

1//! Regime Label — volatility-quantile classification of the current bar.
2
3use std::collections::VecDeque;
4
5use crate::error::{Error, Result};
6use crate::indicators::rolling_quantile::quantile_sorted;
7use crate::traits::Indicator;
8
9/// Regime Label — a discrete `{−1, 0, +1}` classification of the current
10/// volatility regime by where the latest rolling volatility falls within its
11/// own recent distribution.
12///
13/// ```text
14/// σₜ    = sample stddev of the last `vol_period` log returns
15/// q1,q3 = 25th / 75th percentile of the last `lookback` σ readings
16/// label = −1 if σₜ < q1   (calm regime)
17///         +1 if σₜ > q3   (stressed regime)
18///          0 otherwise    (normal regime)
19/// ```
20///
21/// This is the canonical rolling-volatility-quantile regime split: rather than
22/// thresholding absolute volatility (which is not comparable across instruments
23/// or epochs), it asks whether *today's* volatility is unusually low or high
24/// **relative to its own recent history**. `−1` is a calm regime, `+1` a
25/// stressed / high-volatility regime, `0` the normal middle. Because the latest
26/// reading is included in its own reference window, a freshly elevated
27/// volatility prints `+1` until the window catches up to the new level — it
28/// flags the *transition*, not just the absolute level. When the recent
29/// volatilities are all equal (`q1 == q3`, e.g. a constant drift) there is no
30/// spread to classify against and the label is `0`.
31///
32/// Each `update` is `O(vol_period + lookback log lookback)`. Non-finite and
33/// non-positive prices are ignored.
34///
35/// # Example
36///
37/// ```
38/// use wickra_core::{Indicator, RegimeLabel};
39///
40/// let mut indicator = RegimeLabel::new(5, 20).unwrap();
41/// let mut last = None;
42/// for i in 0..60 {
43///     last = indicator.update(100.0 + (f64::from(i) * 0.5).sin());
44/// }
45/// assert!(last.is_some());
46/// ```
47#[derive(Debug, Clone)]
48pub struct RegimeLabel {
49    vol_period: usize,
50    lookback: usize,
51    prev_price: Option<f64>,
52    /// Trailing window of the last `vol_period` log returns.
53    ret_window: VecDeque<f64>,
54    ret_sum: f64,
55    ret_sum_sq: f64,
56    /// Trailing window of the last `lookback` volatility readings.
57    vol_window: VecDeque<f64>,
58    /// Reusable scratch buffer for the quantile sort.
59    scratch: Vec<f64>,
60    last: Option<f64>,
61}
62
63impl RegimeLabel {
64    /// Construct a new Regime Label classifier.
65    ///
66    /// `vol_period` is the window for the rolling volatility; `lookback` is the
67    /// window of volatility readings whose quartiles set the regime bands.
68    ///
69    /// # Errors
70    /// Returns [`Error::InvalidPeriod`] if `vol_period < 2` (the sample standard
71    /// deviation needs at least two returns) or if `lookback < 2` (the quartile
72    /// split needs at least two readings).
73    pub fn new(vol_period: usize, lookback: usize) -> Result<Self> {
74        if vol_period < 2 {
75            return Err(Error::InvalidPeriod {
76                message: "regime label needs vol_period >= 2",
77            });
78        }
79        if lookback < 2 {
80            return Err(Error::InvalidPeriod {
81                message: "regime label needs lookback >= 2",
82            });
83        }
84        Ok(Self {
85            vol_period,
86            lookback,
87            prev_price: None,
88            ret_window: VecDeque::with_capacity(vol_period),
89            ret_sum: 0.0,
90            ret_sum_sq: 0.0,
91            vol_window: VecDeque::with_capacity(lookback),
92            scratch: Vec::with_capacity(lookback),
93            last: None,
94        })
95    }
96
97    /// Configured `(vol_period, lookback)`.
98    pub const fn params(&self) -> (usize, usize) {
99        (self.vol_period, self.lookback)
100    }
101}
102
103impl Indicator for RegimeLabel {
104    type Input = f64;
105    type Output = f64;
106
107    fn update(&mut self, input: f64) -> Option<f64> {
108        if !input.is_finite() || input <= 0.0 {
109            return self.last;
110        }
111        let Some(prev) = self.prev_price else {
112            self.prev_price = Some(input);
113            return None;
114        };
115        self.prev_price = Some(input);
116        let r = (input / prev).ln();
117        // Roll the return window and its running moments.
118        if self.ret_window.len() == self.vol_period {
119            let old = self.ret_window.pop_front().expect("non-empty");
120            self.ret_sum -= old;
121            self.ret_sum_sq -= old * old;
122        }
123        self.ret_window.push_back(r);
124        self.ret_sum += r;
125        self.ret_sum_sq += r * r;
126        if self.ret_window.len() < self.vol_period {
127            return None;
128        }
129        let n = self.vol_period as f64;
130        let mean = self.ret_sum / n;
131        let var = ((self.ret_sum_sq - n * mean * mean) / (n - 1.0)).max(0.0);
132        let vol = var.sqrt();
133        // Roll the volatility window.
134        if self.vol_window.len() == self.lookback {
135            self.vol_window.pop_front();
136        }
137        self.vol_window.push_back(vol);
138        if self.vol_window.len() < self.lookback {
139            return None;
140        }
141        // Classify the latest volatility against the quartiles of the window.
142        self.scratch.clear();
143        self.scratch.extend(self.vol_window.iter().copied());
144        self.scratch.sort_by(f64::total_cmp);
145        let q1 = quantile_sorted(&self.scratch, 0.25);
146        let q3 = quantile_sorted(&self.scratch, 0.75);
147        let label = if vol < q1 {
148            -1.0
149        } else if vol > q3 {
150            1.0
151        } else {
152            0.0
153        };
154        self.last = Some(label);
155        Some(label)
156    }
157
158    fn reset(&mut self) {
159        self.prev_price = None;
160        self.ret_window.clear();
161        self.ret_sum = 0.0;
162        self.ret_sum_sq = 0.0;
163        self.vol_window.clear();
164        self.scratch.clear();
165        self.last = None;
166    }
167
168    fn warmup_period(&self) -> usize {
169        // One price seeds `prev`, `vol_period` returns yield the first vol, then
170        // `lookback` vols fill the regime window.
171        self.vol_period + self.lookback
172    }
173
174    fn is_ready(&self) -> bool {
175        self.last.is_some()
176    }
177
178    fn name(&self) -> &'static str {
179        "RegimeLabel"
180    }
181}
182
#[cfg(test)]
mod tests {
    use super::*;
    use crate::traits::BatchExt;

    #[test]
    fn rejects_bad_periods() {
        assert!(matches!(
            RegimeLabel::new(1, 20),
            Err(Error::InvalidPeriod { .. })
        ));
        assert!(matches!(
            RegimeLabel::new(5, 1),
            Err(Error::InvalidPeriod { .. })
        ));
    }

    #[test]
    fn accessors_and_metadata() {
        let rl = RegimeLabel::new(5, 20).unwrap();
        assert_eq!(rl.params(), (5, 20));
        assert_eq!(rl.warmup_period(), 25);
        assert_eq!(rl.name(), "RegimeLabel");
        assert!(!rl.is_ready());
    }

    #[test]
    fn first_output_lands_on_warmup_period() {
        // The first label must appear on exactly the `warmup_period()`-th
        // input — not a bar earlier, not a bar later.
        let mut rl = RegimeLabel::new(3, 4).unwrap();
        let prices: Vec<f64> = (0..12)
            .map(|i| 100.0 + (f64::from(i) * 0.9).sin() * 3.0)
            .collect();
        let out = rl.batch(&prices);
        let first = out.iter().position(Option::is_some);
        assert_eq!(first, Some(rl.warmup_period() - 1));
    }

    #[test]
    fn ignored_inputs_do_not_advance_warmup() {
        // (3, 4) warms up after 7 valid prices; interleaved junk must not count
        // toward that, and must keep returning `None` while warming up.
        let mut rl = RegimeLabel::new(3, 4).unwrap();
        assert_eq!(rl.warmup_period(), 7);
        for k in 0..6_u32 {
            assert_eq!(rl.update(100.0 + f64::from(k)), None);
            assert_eq!(rl.update(f64::NAN), None);
            assert_eq!(rl.update(-1.0), None);
        }
        assert!(rl.update(107.0).is_some());
    }

    #[test]
    fn detects_stressed_regime_on_volatility_spike() {
        // Calm warmup, then a burst of large moves: the elevated volatility
        // prints +1 while the lookback window still holds the calm readings.
        let mut rl = RegimeLabel::new(4, 8).unwrap();
        let mut prices: Vec<f64> = (0..24)
            .map(|i| 100.0 + (f64::from(i) * 0.7).sin() * 0.2)
            .collect();
        let mut base = *prices.last().unwrap();
        for i in 0..8 {
            base *= if i % 2 == 0 { 1.08 } else { 0.93 };
            prices.push(base);
        }
        let out = rl.batch(&prices);
        assert!(
            out.iter().flatten().any(|&v| v == 1.0),
            "expected a stressed (+1) regime label"
        );
    }

    #[test]
    fn detects_calm_regime_after_volatility_drop() {
        // Volatile warmup, then a calm tail: the depressed volatility prints -1.
        let mut rl = RegimeLabel::new(4, 8).unwrap();
        let mut prices: Vec<f64> = Vec::new();
        let mut base = 100.0;
        for i in 0..24 {
            base *= if i % 2 == 0 { 1.05 } else { 0.96 };
            prices.push(base);
        }
        for i in 0..12 {
            prices.push(base + (f64::from(i) * 0.7).sin() * 0.05);
        }
        let out = rl.batch(&prices);
        assert!(
            out.iter().flatten().any(|&v| v == -1.0),
            "expected a calm (-1) regime label"
        );
    }

    #[test]
    fn zero_volatility_is_neutral() {
        // A constant price has exactly-zero returns => zero volatility on every
        // window => q1 == q3 == 0 => neutral 0 throughout. (A geometric drift is
        // *conceptually* constant-vol too, but floating-point rounding of the
        // log returns leaves ~1e-16 dispersion, so the exactly-flat series is
        // the clean way to pin the q1 == q3 branch.)
        let mut rl = RegimeLabel::new(4, 8).unwrap();
        for v in rl.batch(&[100.0; 40]).into_iter().flatten() {
            assert_eq!(v, 0.0);
        }
    }

    #[test]
    fn output_is_ternary() {
        let mut rl = RegimeLabel::new(5, 20).unwrap();
        let prices: Vec<f64> = (0..300)
            .map(|i| {
                100.0
                    + (f64::from(i) * 0.3).sin() * (1.0 + (f64::from(i) * 0.05).sin() * 5.0)
            })
            .collect();
        for v in rl.batch(&prices).into_iter().flatten() {
            assert!(v == -1.0 || v == 0.0 || v == 1.0, "non-ternary label {v}");
        }
    }

    #[test]
    fn ignores_non_finite_and_non_positive() {
        let mut rl = RegimeLabel::new(4, 6).unwrap();
        let prices: Vec<f64> = (0..40)
            .map(|i| 100.0 + (f64::from(i) * 0.5).sin() * 2.0)
            .collect();
        let out = rl.batch(&prices);
        let last = *out.last().unwrap();
        assert!(last.is_some());
        assert_eq!(rl.update(f64::NAN), last);
        assert_eq!(rl.update(-1.0), last);
        assert_eq!(rl.update(0.0), last);
    }

    #[test]
    fn reset_clears_state() {
        let mut rl = RegimeLabel::new(4, 6).unwrap();
        rl.batch(&(1..=40).map(f64::from).collect::<Vec<_>>());
        assert!(rl.is_ready());
        rl.reset();
        assert!(!rl.is_ready());
        assert_eq!(rl.update(1.0), None);
    }

    #[test]
    fn batch_equals_streaming() {
        let prices: Vec<f64> = (1..=160)
            .map(|i| 100.0 + (f64::from(i) * 0.25).sin() * 4.0)
            .collect();
        let batch = RegimeLabel::new(5, 20).unwrap().batch(&prices);
        let mut b = RegimeLabel::new(5, 20).unwrap();
        let streamed: Vec<_> = prices.iter().map(|p| b.update(*p)).collect();
        assert_eq!(batch, streamed);
    }
}