/// How the seasonal component is combined with the level-plus-trend estimate.
///
/// With [`Seasonality::Additive`] the seasonal term is added
/// (`level + h * trend + season`); with [`Seasonality::Multiplicative`] it scales
/// the estimate (`(level + h * trend) * season`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Seasonality {
    /// The seasonal term is added to the level-plus-trend estimate.
    Additive,
    /// The seasonal term multiplies the level-plus-trend estimate; smoothing
    /// requires strictly positive observations in this mode.
    Multiplicative,
}
44
/// Per-time-step components produced by [`HoltWinters::smooth`].
///
/// When produced by `smooth`, every vector has exactly one entry per input
/// observation, indexed by time step.
#[derive(Clone, Debug)]
pub struct HoltWintersResult {
    /// Smoothed level at each time step.
    pub level: Vec<f64>,
    /// Smoothed trend (per-step change in level) at each time step.
    pub trend: Vec<f64>,
    /// Seasonal component at each time step.
    pub seasonal: Vec<f64>,
    /// In-sample fitted values; from index `period` onward each entry is the
    /// one-step-ahead forecast built from the previous step's components.
    pub fitted: Vec<f64>,
}
57
58impl HoltWintersResult {
59 pub fn forecast(&self, h: usize, period: usize, seasonality: Seasonality) -> f64 {
66 let last_l = *self.level.last().expect("level must be non-empty");
67 let last_t = *self.trend.last().expect("trend must be non-empty");
68
69 let s_len = self.seasonal.len();
71 let idx = s_len - period + ((h - 1) % period);
72 let s = self.seasonal[idx];
73
74 match seasonality {
75 Seasonality::Additive => last_l + h as f64 * last_t + s,
76 Seasonality::Multiplicative => (last_l + h as f64 * last_t) * s,
77 }
78 }
79}
80
81pub struct HoltWinters {
83 alpha: f64,
84 beta: f64,
85 gamma: f64,
86 period: usize,
87 seasonality: Seasonality,
88}
89
90impl HoltWinters {
91 pub fn new(
102 alpha: f64,
103 beta: f64,
104 gamma: f64,
105 period: usize,
106 seasonality: Seasonality,
107 ) -> Option<Self> {
108 if !alpha.is_finite() || alpha <= 0.0 || alpha >= 1.0 {
109 return None;
110 }
111 if !beta.is_finite() || beta <= 0.0 || beta >= 1.0 {
112 return None;
113 }
114 if !gamma.is_finite() || gamma <= 0.0 || gamma >= 1.0 {
115 return None;
116 }
117 if period < 2 {
118 return None;
119 }
120 Some(Self {
121 alpha,
122 beta,
123 gamma,
124 period,
125 seasonality,
126 })
127 }
128
129 pub fn period(&self) -> usize {
131 self.period
132 }
133
134 pub fn seasonality(&self) -> Seasonality {
136 self.seasonality
137 }
138
139 pub fn smooth(&self, data: &[f64]) -> Option<HoltWintersResult> {
145 let m = self.period;
146 let n = data.len();
147
148 if n < 2 * m {
149 return None;
150 }
151
152 if self.seasonality == Seasonality::Multiplicative && data.iter().any(|&x| x <= 0.0) {
154 return None;
155 }
156
157 let l0: f64 = data[..m].iter().sum::<f64>() / m as f64;
159 let t0: f64 = (0..m)
160 .map(|i| (data[m + i] - data[i]) / m as f64)
161 .sum::<f64>()
162 / m as f64;
163
164 let mut seasonal = vec![0.0; n];
166 match self.seasonality {
167 Seasonality::Additive => {
168 for i in 0..m {
169 seasonal[i] = data[i] - l0;
170 }
171 }
172 Seasonality::Multiplicative => {
173 for i in 0..m {
174 seasonal[i] = data[i] / l0;
175 }
176 }
177 }
178
179 let mut level = vec![0.0; n];
180 let mut trend = vec![0.0; n];
181 let mut fitted = vec![0.0; n];
182
183 for i in 0..m {
185 level[i] = l0;
186 trend[i] = t0;
187 fitted[i] = match self.seasonality {
188 Seasonality::Additive => l0 + seasonal[i],
189 Seasonality::Multiplicative => l0 * seasonal[i],
190 };
191 }
192
193 for t in m..n {
195 let s_prev = seasonal[t - m];
196
197 let l = match self.seasonality {
198 Seasonality::Additive => {
199 self.alpha * (data[t] - s_prev)
200 + (1.0 - self.alpha) * (level[t - 1] + trend[t - 1])
201 }
202 Seasonality::Multiplicative => {
203 self.alpha * (data[t] / s_prev)
204 + (1.0 - self.alpha) * (level[t - 1] + trend[t - 1])
205 }
206 };
207
208 let b = self.beta * (l - level[t - 1]) + (1.0 - self.beta) * trend[t - 1];
209
210 let s = match self.seasonality {
211 Seasonality::Additive => self.gamma * (data[t] - l) + (1.0 - self.gamma) * s_prev,
212 Seasonality::Multiplicative => {
213 self.gamma * (data[t] / l) + (1.0 - self.gamma) * s_prev
214 }
215 };
216
217 level[t] = l;
218 trend[t] = b;
219 seasonal[t] = s;
220
221 fitted[t] = match self.seasonality {
223 Seasonality::Additive => level[t - 1] + trend[t - 1] + s_prev,
224 Seasonality::Multiplicative => (level[t - 1] + trend[t - 1]) * s_prev,
225 };
226 }
227
228 Some(HoltWintersResult {
229 level,
230 trend,
231 seasonal,
232 fitted,
233 })
234 }
235}
236
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a period-4 model with the given factors and smooths `data`,
    /// unwrapping both the constructor and the smoothing result.
    fn fit_period4(
        alpha: f64,
        beta: f64,
        gamma: f64,
        seasonality: Seasonality,
        data: &[f64],
    ) -> HoltWintersResult {
        HoltWinters::new(alpha, beta, gamma, 4, seasonality)
            .unwrap()
            .smooth(data)
            .unwrap()
    }

    /// 24 points: linear trend `100 + 2t` plus a repeating additive offset.
    fn seasonal_additive_data() -> Vec<f64> {
        let offsets = [10.0, -5.0, -5.0, 0.0];
        (0..24usize)
            .map(|t| 100.0 + 2.0 * t as f64 + offsets[t % 4])
            .collect()
    }

    /// 24 points: linear trend `100 + 2t` scaled by a repeating seasonal factor.
    fn seasonal_multiplicative_data() -> Vec<f64> {
        let factors = [1.2, 0.8, 0.9, 1.1];
        (0..24usize)
            .map(|t| (100.0 + 2.0 * t as f64) * factors[t % 4])
            .collect()
    }

    #[test]
    fn test_hw_additive_basic() {
        let result = fit_period4(0.3, 0.1, 0.3, Seasonality::Additive, &seasonal_additive_data());
        let lengths = [
            result.level.len(),
            result.trend.len(),
            result.seasonal.len(),
            result.fitted.len(),
        ];
        assert_eq!(lengths, [24; 4]);
    }

    #[test]
    fn test_hw_additive_forecast() {
        let result = fit_period4(0.3, 0.1, 0.3, Seasonality::Additive, &seasonal_additive_data());
        let ahead = |h| result.forecast(h, 4, Seasonality::Additive);
        let (f1, f4) = (ahead(1), ahead(4));

        assert!(f1 > 100.0, "forecast(1) = {f1}");
        assert!(f4 > f1 - 20.0, "forecast(4) = {f4}");
    }

    #[test]
    fn test_hw_multiplicative_basic() {
        let result = fit_period4(
            0.3,
            0.1,
            0.3,
            Seasonality::Multiplicative,
            &seasonal_multiplicative_data(),
        );
        assert_eq!(result.level.len(), 24);
        assert_eq!(result.fitted.len(), 24);
    }

    #[test]
    fn test_hw_fitted_approximates_data() {
        let data = seasonal_additive_data();
        let result = fit_period4(0.5, 0.3, 0.5, Seasonality::Additive, &data);

        // Score only indices 8..24, i.e. after the first two seasons.
        let mape: f64 = data[8..]
            .iter()
            .zip(&result.fitted[8..])
            .map(|(&actual, &predicted)| ((predicted - actual) / actual).abs())
            .sum::<f64>()
            / 16.0;

        assert!(
            mape < 0.10,
            "mean absolute percentage error = {mape}, expected < 10%"
        );
    }

    #[test]
    fn test_hw_seasonal_pattern_detected() {
        let result = fit_period4(0.3, 0.1, 0.5, Seasonality::Additive, &seasonal_additive_data());

        // Phase within the final cycle that carries the largest seasonal index.
        let (max_idx, _) = result.seasonal[20..24]
            .iter()
            .enumerate()
            .max_by(|a, b| a.1.partial_cmp(b.1).unwrap())
            .unwrap();
        assert_eq!(max_idx, 0, "highest seasonal at wrong position");
    }

    #[test]
    fn test_hw_insufficient_data() {
        let hw = HoltWinters::new(0.3, 0.1, 0.3, 4, Seasonality::Additive).unwrap();
        // Two full seasons (2 * 4 points) is the minimum accepted length.
        assert!(hw.smooth(&[1.0; 7]).is_none());
        assert!(hw.smooth(&[1.0; 8]).is_some());
    }

    #[test]
    fn test_hw_multiplicative_rejects_negative() {
        let hw = HoltWinters::new(0.3, 0.1, 0.3, 4, Seasonality::Multiplicative).unwrap();
        let data = [1.0, 2.0, -1.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        assert!(hw.smooth(&data).is_none());
    }

    #[test]
    fn test_hw_invalid_params() {
        // (alpha, beta, gamma, period) — each row has one parameter out of range.
        let rejected = [
            (0.0, 0.5, 0.5, 4),
            (0.5, 1.0, 0.5, 4),
            (0.5, 0.5, 0.0, 4),
            (0.5, 0.5, 0.5, 1),
        ];
        for &(alpha, beta, gamma, period) in &rejected {
            assert!(HoltWinters::new(alpha, beta, gamma, period, Seasonality::Additive).is_none());
        }
    }

    #[test]
    fn test_hw_trend_detected() {
        let result = fit_period4(0.3, 0.3, 0.3, Seasonality::Additive, &seasonal_additive_data());
        // The generating series rises by 2.0 per step.
        let last_trend = result.trend[23];
        assert!(
            last_trend > 1.0 && last_trend < 4.0,
            "trend = {last_trend}, expected ~2.0"
        );
    }

    #[test]
    fn test_hw_level_tracks_mean() {
        let result = fit_period4(0.3, 0.1, 0.3, Seasonality::Additive, &seasonal_additive_data());
        // The generating trend at t = 23 is 100 + 2 * 23 = 146.
        let last_level = result.level[23];
        assert!(
            (last_level - 146.0).abs() < 10.0,
            "level = {last_level}, expected ~146"
        );
    }
}