Skip to main content

wickra_core/indicators/
holt_winters.rs

1//! Holt's linear (double exponential) smoothing.
2
3use crate::error::{Error, Result};
4use crate::traits::Indicator;
5
/// Holt's linear method — double exponential smoothing with a level and a
/// trend component.
///
/// Despite the type name, this is the **non-seasonal** variant: the recurrence
/// below tracks only a level and a trend, with no seasonal term.
///
/// A single [`Ema`](crate::Ema) tracks only a *level* and therefore lags any
/// sustained trend. Holt's method adds a second smoothed state, the trend, and
/// reports the one-step-ahead forecast `level + trend`, which removes that lag
/// on trending data while still smoothing noise.
///
/// ```text
/// level_t = α · price_t        + (1 − α) · (level_{t-1} + trend_{t-1})
/// trend_t = β · (level_t − level_{t-1}) + (1 − β) · trend_{t-1}
/// output  = level_t + trend_t          (one-step-ahead forecast)
/// ```
///
/// `α ∈ (0, 1]` is the level smoothing constant and `β ∈ (0, 1]` the trend
/// smoothing constant. The state is seeded from the first two inputs
/// (`level = price_1`, `trend = price_1 − price_0`), so the first output lands
/// on the **second** input.
///
/// On a perfectly linear series the forecast is exact from the second bar
/// onward (for any `α`, `β`): if the level equals the current value and the
/// trend equals the slope, both invariants are preserved and `level + trend`
/// equals the next value.
///
/// Non-finite inputs (`NaN`, `±∞`) are skipped: they leave the state untouched
/// and `update` returns the current forecast (`None` before seeding).
///
/// # Example
///
/// ```
/// use wickra_core::{HoltWinters, Indicator};
///
/// let mut indicator = HoltWinters::new(0.2, 0.1).unwrap();
/// let mut last = None;
/// for i in 0..80 {
///     last = indicator.update(100.0 + f64::from(i));
/// }
/// // The input is a unit-slope ramp ending at 179, so the one-step-ahead
/// // forecast is exactly the next ramp value, 180.
/// let forecast = last.expect("seeded after two inputs");
/// assert!((forecast - 180.0).abs() < 1e-9);
/// ```
#[derive(Debug, Clone)]
pub struct HoltWinters {
    alpha: f64,
    beta: f64,
    /// `(level, trend)` once seeded.
    state: Option<(f64, f64)>,
    /// First input, held until the second arrives to seed the trend.
    prev_price: Option<f64>,
}
51
52impl HoltWinters {
53    /// Construct Holt's linear smoother with level constant `alpha` and trend
54    /// constant `beta`.
55    ///
56    /// # Errors
57    ///
58    /// Returns [`Error::InvalidPeriod`] if either constant is non-finite or
59    /// outside `(0.0, 1.0]`.
60    pub fn new(alpha: f64, beta: f64) -> Result<Self> {
61        if !alpha.is_finite() || alpha <= 0.0 || alpha > 1.0 {
62            return Err(Error::InvalidPeriod {
63                message: "HoltWinters alpha must be in (0.0, 1.0]",
64            });
65        }
66        if !beta.is_finite() || beta <= 0.0 || beta > 1.0 {
67            return Err(Error::InvalidPeriod {
68                message: "HoltWinters beta must be in (0.0, 1.0]",
69            });
70        }
71        Ok(Self {
72            alpha,
73            beta,
74            state: None,
75            prev_price: None,
76        })
77    }
78
79    /// Level smoothing constant `alpha`.
80    pub const fn alpha(&self) -> f64 {
81        self.alpha
82    }
83
84    /// Trend smoothing constant `beta`.
85    pub const fn beta(&self) -> f64 {
86        self.beta
87    }
88
89    /// Current smoothed level, if seeded.
90    pub fn level(&self) -> Option<f64> {
91        self.state.map(|(level, _)| level)
92    }
93
94    /// Current smoothed trend, if seeded.
95    pub fn trend(&self) -> Option<f64> {
96        self.state.map(|(_, trend)| trend)
97    }
98
99    /// Current one-step-ahead forecast `level + trend`, if seeded.
100    pub fn value(&self) -> Option<f64> {
101        self.state.map(|(level, trend)| level + trend)
102    }
103}
104
105impl Indicator for HoltWinters {
106    type Input = f64;
107    type Output = f64;
108
109    fn update(&mut self, price: f64) -> Option<f64> {
110        if !price.is_finite() {
111            return self.value();
112        }
113        match self.state {
114            None => {
115                if let Some(prev) = self.prev_price {
116                    // Second input: seed level and trend.
117                    let level = price;
118                    let trend = price - prev;
119                    self.state = Some((level, trend));
120                    Some(level + trend)
121                } else {
122                    // First input: hold it to seed the trend next time.
123                    self.prev_price = Some(price);
124                    None
125                }
126            }
127            Some((level, trend)) => {
128                let level_new = self.alpha * price + (1.0 - self.alpha) * (level + trend);
129                let trend_new = self.beta * (level_new - level) + (1.0 - self.beta) * trend;
130                self.state = Some((level_new, trend_new));
131                Some(level_new + trend_new)
132            }
133        }
134    }
135
136    fn reset(&mut self) {
137        self.state = None;
138        self.prev_price = None;
139    }
140
141    fn warmup_period(&self) -> usize {
142        // Two inputs are needed to seed the level and the trend.
143        2
144    }
145
146    fn is_ready(&self) -> bool {
147        self.state.is_some()
148    }
149
150    fn name(&self) -> &'static str {
151        "HoltWinters"
152    }
153}
154
#[cfg(test)]
mod tests {
    use super::*;
    use crate::traits::BatchExt;
    use approx::assert_relative_eq;

    /// Straight transcription of the seeding rule and the steady-state
    /// recurrence, independent of the struct, used as a reference oracle.
    fn naive(prices: &[f64], alpha: f64, beta: f64) -> Vec<Option<f64>> {
        let mut state: Option<(f64, f64)> = None;
        let mut prev: Option<f64> = None;
        let mut out = Vec::with_capacity(prices.len());
        for &price in prices {
            let v = match state {
                None => {
                    if let Some(p0) = prev {
                        let level = price;
                        let trend = price - p0;
                        state = Some((level, trend));
                        Some(level + trend)
                    } else {
                        prev = Some(price);
                        None
                    }
                }
                Some((level, trend)) => {
                    let ln = alpha * price + (1.0 - alpha) * (level + trend);
                    let tn = beta * (ln - level) + (1.0 - beta) * trend;
                    state = Some((ln, tn));
                    Some(ln + tn)
                }
            };
            out.push(v);
        }
        out
    }

    #[test]
    fn rejects_invalid_alpha() {
        for &alpha in &[0.0, -0.1, 1.5, f64::NAN, f64::INFINITY, f64::NEG_INFINITY] {
            assert!(
                matches!(
                    HoltWinters::new(alpha, 0.1),
                    Err(Error::InvalidPeriod { .. })
                ),
                "alpha = {} should be rejected",
                alpha
            );
        }
    }

    #[test]
    fn rejects_invalid_beta() {
        for &beta in &[0.0, -0.1, 1.5, f64::NAN, f64::INFINITY, f64::NEG_INFINITY] {
            assert!(
                matches!(
                    HoltWinters::new(0.2, beta),
                    Err(Error::InvalidPeriod { .. })
                ),
                "beta = {} should be rejected",
                beta
            );
        }
    }

    /// The upper bound of `(0.0, 1.0]` is inclusive. With `alpha = beta = 1`
    /// the level is the latest price and the trend the latest first difference.
    #[test]
    fn accepts_inclusive_upper_bound() {
        let mut hw = HoltWinters::new(1.0, 1.0).unwrap();
        assert_eq!(hw.update(10.0), None);
        assert_relative_eq!(hw.update(12.0).unwrap(), 14.0, epsilon = 1e-12);
        // level = 11, trend = 11 - 12 = -1 -> forecast 10.
        assert_relative_eq!(hw.update(11.0).unwrap(), 10.0, epsilon = 1e-12);
        assert_relative_eq!(hw.level().unwrap(), 11.0, epsilon = 1e-12);
        assert_relative_eq!(hw.trend().unwrap(), -1.0, epsilon = 1e-12);
    }

    /// Cover the const accessors `alpha` + `beta` and the Indicator-impl
    /// `warmup_period` + `name`.
    #[test]
    fn accessors_and_metadata() {
        let hw = HoltWinters::new(0.2, 0.1).unwrap();
        assert_relative_eq!(hw.alpha(), 0.2, epsilon = 1e-12);
        assert_relative_eq!(hw.beta(), 0.1, epsilon = 1e-12);
        assert_eq!(hw.warmup_period(), 2);
        assert_eq!(hw.name(), "HoltWinters");
    }

    #[test]
    fn accessors_are_none_until_seeded() {
        let mut hw = HoltWinters::new(0.2, 0.1).unwrap();
        assert!(!hw.is_ready());
        assert_eq!((hw.level(), hw.trend(), hw.value()), (None, None, None));
        // A single input is parked for the trend seed but creates no state.
        assert_eq!(hw.update(5.0), None);
        assert!(!hw.is_ready());
        assert_eq!((hw.level(), hw.trend(), hw.value()), (None, None, None));
    }

    #[test]
    fn warmup_then_seed_on_second_input() {
        let mut hw = HoltWinters::new(0.2, 0.1).unwrap();
        assert_eq!(hw.update(10.0), None);
        // Second input seeds level = 12, trend = 12 - 10 = 2 -> forecast 14.
        assert_relative_eq!(hw.update(12.0).unwrap(), 14.0, epsilon = 1e-12);
        assert_relative_eq!(hw.level().unwrap(), 12.0, epsilon = 1e-12);
        assert_relative_eq!(hw.trend().unwrap(), 2.0, epsilon = 1e-12);
    }

    #[test]
    fn linear_series_forecasts_exactly() {
        // On a perfect ramp the one-step forecast equals the next value, for
        // any alpha/beta, from the second bar onward.
        let prices: Vec<f64> = (1..=20).map(f64::from).collect();
        let mut hw = HoltWinters::new(0.3, 0.4).unwrap();
        let out = hw.batch(&prices);
        assert!(out[0].is_none());
        for (i, v) in out.iter().enumerate().skip(1) {
            // forecast at index i is the price at index i + 1 = (i + 2).
            assert_relative_eq!(v.unwrap(), (i + 2) as f64, epsilon = 1e-9);
        }
    }

    #[test]
    fn constant_series_yields_constant() {
        let mut hw = HoltWinters::new(0.2, 0.1).unwrap();
        let out = hw.batch(&[42.0_f64; 30]);
        assert!(out[0].is_none());
        // Every bar after the first must produce a value; a missing output is
        // a failure, not something to skip over.
        for (i, v) in out.iter().enumerate().skip(1) {
            let v = v.unwrap_or_else(|| panic!("missing output at index {}", i));
            assert_relative_eq!(v, 42.0, epsilon = 1e-9);
        }
    }

    #[test]
    fn matches_naive_recurrence() {
        let prices: Vec<f64> = (0..60)
            .map(|i| 100.0 + (f64::from(i) * 0.3).sin() * 10.0 + f64::from(i) * 0.2)
            .collect();
        let mut hw = HoltWinters::new(0.25, 0.15).unwrap();
        let got = hw.batch(&prices);
        let want = naive(&prices, 0.25, 0.15);
        assert_eq!(got.len(), want.len());
        for (g, w) in got.iter().zip(want.iter()) {
            assert_eq!(g.is_some(), w.is_some());
            if let (Some(a), Some(b)) = (g, w) {
                assert_relative_eq!(a, b, epsilon = 1e-9);
            }
        }
    }

    #[test]
    fn reset_clears_state() {
        let mut hw = HoltWinters::new(0.2, 0.1).unwrap();
        hw.batch(&(1..=20).map(f64::from).collect::<Vec<_>>());
        assert!(hw.is_ready());
        hw.reset();
        assert!(!hw.is_ready());
        assert_eq!(hw.value(), None);
        assert_eq!(hw.update(1.0), None);
    }

    #[test]
    fn batch_equals_streaming() {
        let prices: Vec<f64> = (1..=30).map(|i| f64::from(i) * 0.5).collect();
        let mut a = HoltWinters::new(0.3, 0.2).unwrap();
        let mut b = HoltWinters::new(0.3, 0.2).unwrap();
        assert_eq!(
            a.batch(&prices),
            prices.iter().map(|p| b.update(*p)).collect::<Vec<_>>()
        );
    }

    #[test]
    fn ignores_non_finite_input() {
        let mut hw = HoltWinters::new(0.2, 0.1).unwrap();
        // Non-finite before any state returns None and parks nothing.
        assert_eq!(hw.update(f64::NAN), None);
        assert!(!hw.is_ready());
        hw.update(10.0);
        let ready = hw.update(12.0).expect("seeded on second finite input");
        // Non-finite after seeding returns the current forecast unchanged.
        assert_eq!(hw.update(f64::NAN), Some(ready));
        assert_eq!(hw.update(f64::INFINITY), Some(ready));
        assert_eq!(hw.update(f64::NEG_INFINITY), Some(ready));
        // Skipped inputs must not perturb the state: the stream continues
        // exactly as if they had never arrived.
        let mut clean = HoltWinters::new(0.2, 0.1).unwrap();
        clean.update(10.0);
        clean.update(12.0);
        assert_eq!(hw.update(13.0), clean.update(13.0));
    }

    #[test]
    fn non_finite_during_warmup_is_skipped() {
        let mut hw = HoltWinters::new(0.2, 0.1).unwrap();
        assert_eq!(hw.update(10.0), None);
        // Between the first and second finite inputs: still no forecast, and
        // the parked first price must survive.
        assert_eq!(hw.update(f64::NAN), None);
        assert_relative_eq!(hw.update(12.0).unwrap(), 14.0, epsilon = 1e-12);
    }
}