Skip to main content

u_analytics/smoothing/
holt_winters.rs

//! Holt-Winters Triple Exponential Smoothing.
//!
//! Extends Holt's method with a seasonal component, supporting both
//! additive and multiplicative seasonality.
//!
//! # Algorithm (Additive)
//!
//! ```text
//! Level:    L_t = α (x_t - S_{t-m}) + (1 - α)(L_{t-1} + T_{t-1})
//! Trend:    T_t = β (L_t - L_{t-1}) + (1 - β) T_{t-1}
//! Season:   S_t = γ (x_t - L_t) + (1 - γ) S_{t-m}
//! Forecast: F_{t+h} = L_t + h T_t + S_{t-m+h_m}
//! ```
//!
//! # Algorithm (Multiplicative)
//!
//! ```text
//! Level:    L_t = α (x_t / S_{t-m}) + (1 - α)(L_{t-1} + T_{t-1})
//! Trend:    T_t = β (L_t - L_{t-1}) + (1 - β) T_{t-1}
//! Season:   S_t = γ (x_t / L_t) + (1 - γ) S_{t-m}
//! Forecast: F_{t+h} = (L_t + h T_t) S_{t-m+h_m}
//! ```
//!
//! # Parameters
//!
//! - α ∈ (0, 1): level smoothing
//! - β ∈ (0, 1): trend smoothing
//! - γ ∈ (0, 1): seasonal smoothing
//! - m: seasonal period (e.g., 12 for monthly data with yearly cycle)
//! - h: forecast horizon in steps (h ≥ 1); in the forecast formulas
//!   `h_m = ((h - 1) mod m) + 1`, so horizons beyond one season reuse the
//!   seasonal factors of the most recent complete cycle
//!
//! # Reference
//!
//! Winters, P.R. (1960). "Forecasting Sales by Exponentially Weighted
//! Moving Averages", *Management Science* 6(3), pp. 324-342.
35
/// Seasonality type.
///
/// Selects how the seasonal component is combined with level and trend:
/// by addition/subtraction or by multiplication/division.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Seasonality {
    /// Additive: seasonal effect is added to the trend.
    Additive,
    /// Multiplicative: seasonal effect multiplies the trend.
    Multiplicative,
}
44
/// Result of Holt-Winters smoothing.
///
/// All four series are indexed by observation time and have the same
/// length as the input data passed to the smoother.
#[derive(Debug, Clone, PartialEq)]
pub struct HoltWintersResult {
    /// Level estimates (the first `period` entries hold the initial level).
    pub level: Vec<f64>,
    /// Trend estimates (the first `period` entries hold the initial trend).
    pub trend: Vec<f64>,
    /// Seasonal factors, one per observation (length = data length).
    /// Entry `i` is the seasonal factor estimated for time `i`; the first
    /// `period` entries are the initial seasonal factors.
    pub seasonal: Vec<f64>,
    /// Fitted values (one-step-ahead in-sample forecasts).
    pub fitted: Vec<f64>,
}
57
58impl HoltWintersResult {
59    /// Returns a forecast h steps ahead from the last observation.
60    ///
61    /// # Parameters
62    /// - `h`: steps ahead (1-indexed)
63    /// - `period`: seasonal period m
64    /// - `seasonality`: additive or multiplicative
65    pub fn forecast(&self, h: usize, period: usize, seasonality: Seasonality) -> f64 {
66        let last_l = *self.level.last().expect("level must be non-empty");
67        let last_t = *self.trend.last().expect("trend must be non-empty");
68
69        // Seasonal factor: use the most recent completed cycle
70        let s_len = self.seasonal.len();
71        let idx = s_len - period + ((h - 1) % period);
72        let s = self.seasonal[idx];
73
74        match seasonality {
75            Seasonality::Additive => last_l + h as f64 * last_t + s,
76            Seasonality::Multiplicative => (last_l + h as f64 * last_t) * s,
77        }
78    }
79}
80
81/// Holt-Winters Triple Exponential Smoothing.
82pub struct HoltWinters {
83    alpha: f64,
84    beta: f64,
85    gamma: f64,
86    period: usize,
87    seasonality: Seasonality,
88}
89
90impl HoltWinters {
91    /// Creates a new Holt-Winters smoother.
92    ///
93    /// # Parameters
94    /// - `alpha`: level smoothing constant ∈ (0, 1)
95    /// - `beta`: trend smoothing constant ∈ (0, 1)
96    /// - `gamma`: seasonal smoothing constant ∈ (0, 1)
97    /// - `period`: seasonal period (must be ≥ 2)
98    /// - `seasonality`: additive or multiplicative
99    ///
100    /// Returns `None` if parameters are invalid.
101    pub fn new(
102        alpha: f64,
103        beta: f64,
104        gamma: f64,
105        period: usize,
106        seasonality: Seasonality,
107    ) -> Option<Self> {
108        if !alpha.is_finite() || alpha <= 0.0 || alpha >= 1.0 {
109            return None;
110        }
111        if !beta.is_finite() || beta <= 0.0 || beta >= 1.0 {
112            return None;
113        }
114        if !gamma.is_finite() || gamma <= 0.0 || gamma >= 1.0 {
115            return None;
116        }
117        if period < 2 {
118            return None;
119        }
120        Some(Self {
121            alpha,
122            beta,
123            gamma,
124            period,
125            seasonality,
126        })
127    }
128
129    /// Returns the seasonal period.
130    pub fn period(&self) -> usize {
131        self.period
132    }
133
134    /// Returns the seasonality type.
135    pub fn seasonality(&self) -> Seasonality {
136        self.seasonality
137    }
138
139    /// Applies Holt-Winters smoothing to the data.
140    ///
141    /// Requires at least `2 * period` data points for initialization.
142    /// Returns `None` if data is insufficient or if multiplicative
143    /// seasonality is used with non-positive data.
144    pub fn smooth(&self, data: &[f64]) -> Option<HoltWintersResult> {
145        let m = self.period;
146        let n = data.len();
147
148        if n < 2 * m {
149            return None;
150        }
151
152        // Multiplicative: all data must be positive
153        if self.seasonality == Seasonality::Multiplicative && data.iter().any(|&x| x <= 0.0) {
154            return None;
155        }
156
157        // --- Initialization ---
158        let l0: f64 = data[..m].iter().sum::<f64>() / m as f64;
159        let t0: f64 = (0..m)
160            .map(|i| (data[m + i] - data[i]) / m as f64)
161            .sum::<f64>()
162            / m as f64;
163
164        // seasonal[i] = seasonal factor for time i
165        let mut seasonal = vec![0.0; n];
166        match self.seasonality {
167            Seasonality::Additive => {
168                for i in 0..m {
169                    seasonal[i] = data[i] - l0;
170                }
171            }
172            Seasonality::Multiplicative => {
173                for i in 0..m {
174                    seasonal[i] = data[i] / l0;
175                }
176            }
177        }
178
179        let mut level = vec![0.0; n];
180        let mut trend = vec![0.0; n];
181        let mut fitted = vec![0.0; n];
182
183        // Set initial values for times 0..m-1
184        for i in 0..m {
185            level[i] = l0;
186            trend[i] = t0;
187            fitted[i] = match self.seasonality {
188                Seasonality::Additive => l0 + seasonal[i],
189                Seasonality::Multiplicative => l0 * seasonal[i],
190            };
191        }
192
193        // Main smoothing loop
194        for t in m..n {
195            let s_prev = seasonal[t - m];
196
197            let l = match self.seasonality {
198                Seasonality::Additive => {
199                    self.alpha * (data[t] - s_prev)
200                        + (1.0 - self.alpha) * (level[t - 1] + trend[t - 1])
201                }
202                Seasonality::Multiplicative => {
203                    self.alpha * (data[t] / s_prev)
204                        + (1.0 - self.alpha) * (level[t - 1] + trend[t - 1])
205                }
206            };
207
208            let b = self.beta * (l - level[t - 1]) + (1.0 - self.beta) * trend[t - 1];
209
210            let s = match self.seasonality {
211                Seasonality::Additive => self.gamma * (data[t] - l) + (1.0 - self.gamma) * s_prev,
212                Seasonality::Multiplicative => {
213                    self.gamma * (data[t] / l) + (1.0 - self.gamma) * s_prev
214                }
215            };
216
217            level[t] = l;
218            trend[t] = b;
219            seasonal[t] = s;
220
221            // One-step-ahead fitted value
222            fitted[t] = match self.seasonality {
223                Seasonality::Additive => level[t - 1] + trend[t - 1] + s_prev,
224                Seasonality::Multiplicative => (level[t - 1] + trend[t - 1]) * s_prev,
225            };
226        }
227
228        Some(HoltWintersResult {
229            level,
230            trend,
231            seasonal,
232            fitted,
233        })
234    }
235}
236
#[cfg(test)]
mod tests {
    use super::*;

    /// 24 points of the trend `100 + 2t` plus the additive period-4
    /// seasonal pattern `[10, -5, -5, 0]`.
    fn seasonal_additive_data() -> Vec<f64> {
        let pattern = [10.0, -5.0, -5.0, 0.0];
        (0..24)
            .map(|t| 100.0 + 2.0 * t as f64 + pattern[t % 4])
            .collect()
    }

    /// 24 points of the trend `100 + 2t` scaled by the multiplicative
    /// period-4 seasonal pattern `[1.2, 0.8, 0.9, 1.1]`.
    fn seasonal_multiplicative_data() -> Vec<f64> {
        let pattern = [1.2, 0.8, 0.9, 1.1];
        (0..24)
            .map(|t| (100.0 + 2.0 * t as f64) * pattern[t % 4])
            .collect()
    }

    #[test]
    fn test_hw_accessors() {
        let hw = HoltWinters::new(0.3, 0.1, 0.3, 4, Seasonality::Multiplicative).unwrap();
        assert_eq!(hw.period(), 4);
        assert_eq!(hw.seasonality(), Seasonality::Multiplicative);
    }

    #[test]
    fn test_hw_additive_basic() {
        let data = seasonal_additive_data();
        let hw = HoltWinters::new(0.3, 0.1, 0.3, 4, Seasonality::Additive).unwrap();
        let result = hw.smooth(&data).unwrap();

        assert_eq!(result.level.len(), 24);
        assert_eq!(result.trend.len(), 24);
        assert_eq!(result.seasonal.len(), 24);
        assert_eq!(result.fitted.len(), 24);

        // During the initial period the fitted value is the initial level
        // plus the initial seasonal factor, which reconstructs the data.
        for i in 0..4 {
            assert!(
                (result.fitted[i] - data[i]).abs() < 1e-9,
                "fitted[{i}] = {}, data[{i}] = {}",
                result.fitted[i],
                data[i]
            );
        }
    }

    #[test]
    fn test_hw_additive_forecast() {
        let data = seasonal_additive_data();
        let hw = HoltWinters::new(0.3, 0.1, 0.3, 4, Seasonality::Additive).unwrap();
        let result = hw.smooth(&data).unwrap();

        // Forecast should continue the pattern
        let f1 = result.forecast(1, 4, Seasonality::Additive);
        let f4 = result.forecast(4, 4, Seasonality::Additive);

        // Both should be in reasonable range
        assert!(f1 > 100.0, "forecast(1) = {f1}");
        assert!(f4 > f1 - 20.0, "forecast(4) = {f4}");
    }

    #[test]
    fn test_hw_multiplicative_basic() {
        let data = seasonal_multiplicative_data();
        let hw = HoltWinters::new(0.3, 0.1, 0.3, 4, Seasonality::Multiplicative).unwrap();
        let result = hw.smooth(&data).unwrap();

        assert_eq!(result.level.len(), 24);
        assert_eq!(result.trend.len(), 24);
        assert_eq!(result.seasonal.len(), 24);
        assert_eq!(result.fitted.len(), 24);

        // Strictly positive input must give finite, positive fitted values.
        assert!(result.fitted.iter().all(|v| v.is_finite() && *v > 0.0));
    }

    #[test]
    fn test_hw_fitted_approximates_data() {
        let data = seasonal_additive_data();
        let hw = HoltWinters::new(0.5, 0.3, 0.5, 4, Seasonality::Additive).unwrap();
        let result = hw.smooth(&data).unwrap();

        // After warm-up, fitted should be close to data
        let mape: f64 = (8..24)
            .map(|i| ((result.fitted[i] - data[i]) / data[i]).abs())
            .sum::<f64>()
            / 16.0;

        assert!(
            mape < 0.10,
            "mean absolute percentage error = {mape}, expected < 10%"
        );
    }

    #[test]
    fn test_hw_seasonal_pattern_detected() {
        let data = seasonal_additive_data();
        let hw = HoltWinters::new(0.3, 0.1, 0.5, 4, Seasonality::Additive).unwrap();
        let result = hw.smooth(&data).unwrap();

        // Last seasonal factors should reflect the pattern [10, -5, -5, 0]
        let last_cycle = &result.seasonal[20..24];

        // Highest seasonal should be at position 0 mod 4 (pattern = +10)
        let max_idx = last_cycle
            .iter()
            .enumerate()
            .max_by(|a, b| a.1.partial_cmp(b.1).expect("seasonal factors are comparable"))
            .expect("last cycle is non-empty")
            .0;
        assert_eq!(max_idx, 0, "highest seasonal at wrong position");
    }

    #[test]
    fn test_hw_insufficient_data() {
        let hw = HoltWinters::new(0.3, 0.1, 0.3, 4, Seasonality::Additive).unwrap();
        // Need at least 2*period = 8 data points
        assert!(hw.smooth(&[]).is_none());
        assert!(hw.smooth(&[1.0; 7]).is_none());
        assert!(hw.smooth(&[1.0; 8]).is_some());
    }

    #[test]
    fn test_hw_multiplicative_rejects_negative() {
        let hw = HoltWinters::new(0.3, 0.1, 0.3, 4, Seasonality::Multiplicative).unwrap();
        let data = vec![1.0, 2.0, -1.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        assert!(hw.smooth(&data).is_none());
    }

    #[test]
    fn test_hw_multiplicative_rejects_zero() {
        let hw = HoltWinters::new(0.3, 0.1, 0.3, 4, Seasonality::Multiplicative).unwrap();
        let data = vec![1.0, 2.0, 0.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        assert!(hw.smooth(&data).is_none());
    }

    #[test]
    fn test_hw_invalid_params() {
        assert!(HoltWinters::new(0.0, 0.5, 0.5, 4, Seasonality::Additive).is_none());
        assert!(HoltWinters::new(1.0, 0.5, 0.5, 4, Seasonality::Additive).is_none());
        assert!(HoltWinters::new(0.5, 1.0, 0.5, 4, Seasonality::Additive).is_none());
        assert!(HoltWinters::new(0.5, 0.5, 0.0, 4, Seasonality::Additive).is_none());
        assert!(HoltWinters::new(0.5, 0.5, 0.5, 1, Seasonality::Additive).is_none());
        assert!(HoltWinters::new(0.5, 0.5, 0.5, 0, Seasonality::Additive).is_none());
    }

    #[test]
    fn test_hw_non_finite_params() {
        assert!(HoltWinters::new(f64::NAN, 0.5, 0.5, 4, Seasonality::Additive).is_none());
        assert!(HoltWinters::new(0.5, f64::INFINITY, 0.5, 4, Seasonality::Additive).is_none());
        assert!(HoltWinters::new(0.5, 0.5, f64::NEG_INFINITY, 4, Seasonality::Additive).is_none());
    }

    #[test]
    fn test_hw_trend_detected() {
        let data = seasonal_additive_data();
        let hw = HoltWinters::new(0.3, 0.3, 0.3, 4, Seasonality::Additive).unwrap();
        let result = hw.smooth(&data).unwrap();

        // True trend is +2.0 per step; the estimate should land near it
        let last_trend = result.trend[23];
        assert!(
            last_trend > 1.0 && last_trend < 4.0,
            "trend = {last_trend}, expected ~2.0"
        );
    }

    #[test]
    fn test_hw_level_tracks_mean() {
        let data = seasonal_additive_data();
        let hw = HoltWinters::new(0.3, 0.1, 0.3, 4, Seasonality::Additive).unwrap();
        let result = hw.smooth(&data).unwrap();

        // At t=23, true level ≈ 100 + 2*23 = 146
        let last_level = result.level[23];
        assert!(
            (last_level - 146.0).abs() < 10.0,
            "level = {last_level}, expected ~146"
        );
    }
}