// wickra_core/indicators/holt_winters.rs
use crate::error::{Error, Result};
use crate::traits::Indicator;

/// Holt's linear (double exponential) smoothing forecaster.
///
/// Tracks a smoothed *level* and *trend* of the input series and, once seeded,
/// emits the one-step-ahead forecast `level + trend` on every update.
/// Despite the name there is no seasonal component: only level and trend are
/// modelled.
///
/// The state is seeded from the first two finite inputs (`level` = second
/// input, `trend` = second − first). From the third finite input on, the Holt
/// recurrences are applied with smoothing factors `alpha` and `beta`.
#[derive(Debug, Clone, PartialEq)]
pub struct HoltWinters {
    /// Level smoothing factor; [`HoltWinters::new`] only accepts `(0.0, 1.0]`.
    alpha: f64,
    /// Trend smoothing factor; [`HoltWinters::new`] only accepts `(0.0, 1.0]`.
    beta: f64,
    /// `(level, trend)` once seeded, `None` while still warming up.
    state: Option<(f64, f64)>,
    /// First finite input, kept to derive the initial trend from the second one.
    prev_price: Option<f64>,
}

52impl HoltWinters {
53 pub fn new(alpha: f64, beta: f64) -> Result<Self> {
61 if !alpha.is_finite() || alpha <= 0.0 || alpha > 1.0 {
62 return Err(Error::InvalidPeriod {
63 message: "HoltWinters alpha must be in (0.0, 1.0]",
64 });
65 }
66 if !beta.is_finite() || beta <= 0.0 || beta > 1.0 {
67 return Err(Error::InvalidPeriod {
68 message: "HoltWinters beta must be in (0.0, 1.0]",
69 });
70 }
71 Ok(Self {
72 alpha,
73 beta,
74 state: None,
75 prev_price: None,
76 })
77 }
78
79 pub const fn alpha(&self) -> f64 {
81 self.alpha
82 }
83
84 pub const fn beta(&self) -> f64 {
86 self.beta
87 }
88
89 pub fn level(&self) -> Option<f64> {
91 self.state.map(|(level, _)| level)
92 }
93
94 pub fn trend(&self) -> Option<f64> {
96 self.state.map(|(_, trend)| trend)
97 }
98
99 pub fn value(&self) -> Option<f64> {
101 self.state.map(|(level, trend)| level + trend)
102 }
103}
104
105impl Indicator for HoltWinters {
106 type Input = f64;
107 type Output = f64;
108
109 fn update(&mut self, price: f64) -> Option<f64> {
110 if !price.is_finite() {
111 return self.value();
112 }
113 match self.state {
114 None => {
115 if let Some(prev) = self.prev_price {
116 let level = price;
118 let trend = price - prev;
119 self.state = Some((level, trend));
120 Some(level + trend)
121 } else {
122 self.prev_price = Some(price);
124 None
125 }
126 }
127 Some((level, trend)) => {
128 let level_new = self.alpha * price + (1.0 - self.alpha) * (level + trend);
129 let trend_new = self.beta * (level_new - level) + (1.0 - self.beta) * trend;
130 self.state = Some((level_new, trend_new));
131 Some(level_new + trend_new)
132 }
133 }
134 }
135
136 fn reset(&mut self) {
137 self.state = None;
138 self.prev_price = None;
139 }
140
141 fn warmup_period(&self) -> usize {
142 2
144 }
145
146 fn is_ready(&self) -> bool {
147 self.state.is_some()
148 }
149
150 fn name(&self) -> &'static str {
151 "HoltWinters"
152 }
153}
154
#[cfg(test)]
mod tests {
    use super::*;
    use crate::traits::BatchExt;
    use approx::assert_relative_eq;

    /// Independent, straight-line reference implementation of the same
    /// seeding rule and Holt recurrences, used to cross-check the indicator.
    fn naive(prices: &[f64], alpha: f64, beta: f64) -> Vec<Option<f64>> {
        let mut state: Option<(f64, f64)> = None;
        let mut prev: Option<f64> = None;
        let mut out = Vec::with_capacity(prices.len());
        for &price in prices {
            let v = match state {
                None => {
                    if let Some(p0) = prev {
                        let level = price;
                        let trend = price - p0;
                        state = Some((level, trend));
                        Some(level + trend)
                    } else {
                        prev = Some(price);
                        None
                    }
                }
                Some((level, trend)) => {
                    let ln = alpha * price + (1.0 - alpha) * (level + trend);
                    let tn = beta * (ln - level) + (1.0 - beta) * trend;
                    state = Some((ln, tn));
                    Some(ln + tn)
                }
            };
            out.push(v);
        }
        out
    }

    #[test]
    fn rejects_invalid_alpha() {
        for alpha in [0.0, -0.5, 1.5, f64::NAN, f64::INFINITY, f64::NEG_INFINITY] {
            assert!(
                matches!(HoltWinters::new(alpha, 0.1), Err(Error::InvalidPeriod { .. })),
                "alpha = {alpha} should be rejected"
            );
        }
    }

    #[test]
    fn rejects_invalid_beta() {
        for beta in [0.0, -0.1, 1.5, f64::NAN, f64::INFINITY, f64::NEG_INFINITY] {
            assert!(
                matches!(HoltWinters::new(0.2, beta), Err(Error::InvalidPeriod { .. })),
                "beta = {beta} should be rejected"
            );
        }
    }

    #[test]
    fn accepts_unit_interval_bounds() {
        // The upper bound 1.0 is inclusive; tiny positive factors are valid too.
        assert!(HoltWinters::new(1.0, 1.0).is_ok());
        assert!(HoltWinters::new(1e-9, 1e-9).is_ok());
    }

    #[test]
    fn accessors_and_metadata() {
        let hw = HoltWinters::new(0.2, 0.1).unwrap();
        assert_relative_eq!(hw.alpha(), 0.2, epsilon = 1e-12);
        assert_relative_eq!(hw.beta(), 0.1, epsilon = 1e-12);
        assert_eq!(hw.warmup_period(), 2);
        assert_eq!(hw.name(), "HoltWinters");
        // Nothing is exposed before the state has been seeded.
        assert!(!hw.is_ready());
        assert_eq!(hw.level(), None);
        assert_eq!(hw.trend(), None);
        assert_eq!(hw.value(), None);
    }

    #[test]
    fn warmup_then_seed_on_second_input() {
        let mut hw = HoltWinters::new(0.2, 0.1).unwrap();
        assert_eq!(hw.update(10.0), None);
        assert!(!hw.is_ready());
        // Seed: level = 12, trend = 12 - 10 = 2, forecast = 14.
        assert_relative_eq!(hw.update(12.0).unwrap(), 14.0, epsilon = 1e-12);
        assert!(hw.is_ready());
        assert_relative_eq!(hw.level().unwrap(), 12.0, epsilon = 1e-12);
        assert_relative_eq!(hw.trend().unwrap(), 2.0, epsilon = 1e-12);
    }

    #[test]
    fn unit_factors_follow_latest_change() {
        // With alpha = beta = 1 the level is the latest price and the trend the
        // latest price change, so every value below is exact.
        let mut hw = HoltWinters::new(1.0, 1.0).unwrap();
        assert_eq!(
            hw.batch(&[1.0, 3.0, 4.0, 8.0]),
            vec![None, Some(5.0), Some(5.0), Some(12.0)]
        );
    }

    #[test]
    fn linear_series_forecasts_exactly() {
        // A perfect line is reproduced exactly: the forecast at index i is the
        // next point, i + 2.
        let prices: Vec<f64> = (1..=20).map(f64::from).collect();
        let mut hw = HoltWinters::new(0.3, 0.4).unwrap();
        let out = hw.batch(&prices);
        assert!(out[0].is_none());
        for (i, v) in out.iter().enumerate().skip(1) {
            assert_relative_eq!(v.unwrap(), (i + 2) as f64, epsilon = 1e-9);
        }
    }

    #[test]
    fn constant_series_yields_constant() {
        let mut hw = HoltWinters::new(0.2, 0.1).unwrap();
        let out = hw.batch(&[42.0_f64; 30]);
        assert!(out[0].is_none());
        // Every post-seed value must be present; `flatten()` would hide a `None`.
        for v in &out[1..] {
            assert_relative_eq!(v.expect("ready after seeding"), 42.0, epsilon = 1e-9);
        }
    }

    #[test]
    fn matches_naive_recurrence() {
        let prices: Vec<f64> = (0..60)
            .map(|i| 100.0 + (f64::from(i) * 0.3).sin() * 10.0 + f64::from(i) * 0.2)
            .collect();
        let mut hw = HoltWinters::new(0.25, 0.15).unwrap();
        let got = hw.batch(&prices);
        let want = naive(&prices, 0.25, 0.15);
        assert_eq!(got.len(), want.len());
        for (g, w) in got.iter().zip(want.iter()) {
            assert_eq!(g.is_some(), w.is_some());
            if let (Some(a), Some(b)) = (g, w) {
                assert_relative_eq!(a, b, epsilon = 1e-9);
            }
        }
    }

    #[test]
    fn reset_clears_state() {
        let mut hw = HoltWinters::new(0.2, 0.1).unwrap();
        hw.batch(&(1..=20).map(f64::from).collect::<Vec<_>>());
        assert!(hw.is_ready());
        hw.reset();
        assert!(!hw.is_ready());
        assert_eq!(hw.value(), None);
        assert_eq!(hw.level(), None);
        assert_eq!(hw.trend(), None);
        // The seed sample is forgotten as well, so warm-up starts over.
        assert_eq!(hw.update(1.0), None);
    }

    #[test]
    fn batch_equals_streaming() {
        let prices: Vec<f64> = (1..=30).map(|i| f64::from(i) * 0.5).collect();
        let mut a = HoltWinters::new(0.3, 0.2).unwrap();
        let mut b = HoltWinters::new(0.3, 0.2).unwrap();
        assert_eq!(
            a.batch(&prices),
            prices.iter().map(|p| b.update(*p)).collect::<Vec<_>>()
        );
    }

    #[test]
    fn ignores_non_finite_input() {
        let mut hw = HoltWinters::new(0.2, 0.1).unwrap();
        // Before any finite input there is nothing to repeat.
        assert_eq!(hw.update(f64::NAN), None);
        hw.update(10.0);
        let ready = hw.update(12.0).expect("seeded on second finite input");
        // Once seeded, non-finite samples repeat the last forecast unchanged.
        assert_eq!(hw.update(f64::NAN), Some(ready));
        assert_eq!(hw.update(f64::INFINITY), Some(ready));
    }

    #[test]
    fn non_finite_during_warmup_keeps_seed_sample() {
        let mut hw = HoltWinters::new(0.2, 0.1).unwrap();
        assert_eq!(hw.update(10.0), None);
        // A NaN between the two seed samples must not replace or clear the first.
        assert_eq!(hw.update(f64::NAN), None);
        assert!(!hw.is_ready());
        assert_relative_eq!(hw.update(12.0).unwrap(), 14.0, epsilon = 1e-12);
    }
}