// scirs2_series/streaming/forecasting.rs
use std::collections::VecDeque;
use std::fmt::Debug;

use scirs2_core::ndarray::Array1;
use scirs2_core::numeric::{Float, FromPrimitive};

use crate::error::{Result, TimeSeriesError};

13#[derive(Debug)]
15pub struct StreamingForecaster<F: Float + Debug> {
16 alpha: F,
18 beta: Option<F>,
20 gamma: Option<F>,
22 seasonal_period: Option<usize>,
24 level: Option<F>,
26 trend: Option<F>,
28 seasonal: VecDeque<F>,
30 buffer: VecDeque<F>,
32 max_buffer_size: usize,
34 observation_count: usize,
36}
37
38impl<F: Float + Debug + Clone> StreamingForecaster<F> {
39 pub fn new(
41 alpha: F,
42 beta: Option<F>,
43 gamma: Option<F>,
44 seasonal_period: Option<usize>,
45 max_buffer_size: usize,
46 ) -> Result<Self> {
47 if alpha <= F::zero() || alpha > F::one() {
48 return Err(TimeSeriesError::InvalidParameter {
49 name: "alpha".to_string(),
50 message: "Alpha must be between 0 and 1".to_string(),
51 });
52 }
53
54 let seasonal = if let Some(_period) = seasonal_period {
55 VecDeque::with_capacity(_period)
56 } else {
57 VecDeque::new()
58 };
59
60 Ok(Self {
61 alpha,
62 beta,
63 gamma,
64 seasonal_period,
65 level: None,
66 trend: None,
67 seasonal,
68 buffer: VecDeque::with_capacity(max_buffer_size),
69 max_buffer_size,
70 observation_count: 0,
71 })
72 }
73
74 pub fn update(&mut self, value: F) -> Result<()> {
76 self.observation_count += 1;
77
78 if self.buffer.len() >= self.max_buffer_size {
80 self.buffer.pop_front();
81 }
82 self.buffer.push_back(value);
83
84 if self.level.is_none() {
86 self.level = Some(value);
87 if self.beta.is_some() {
88 self.trend = Some(F::zero());
89 }
90 if let Some(period) = self.seasonal_period {
91 for _ in 0..period {
92 self.seasonal.push_back(F::zero());
93 }
94 }
95 return Ok(());
96 }
97
98 let current_level = self.level.expect("Operation failed");
99 let mut new_level = value;
100
101 let _seasonal_component = if let Some(period) = self.seasonal_period {
103 if self.seasonal.len() >= period {
104 let seasonal_idx = (self.observation_count - 1) % period;
105 let seasonal_val = self.seasonal[seasonal_idx];
106 new_level = new_level - seasonal_val;
107 seasonal_val
108 } else {
109 F::zero()
110 }
111 } else {
112 F::zero()
113 };
114
115 self.level = Some(self.alpha * new_level + (F::one() - self.alpha) * current_level);
117
118 if let Some(beta) = self.beta {
120 if let Some(current_trend) = self.trend {
121 let new_trend = beta * (self.level.expect("Operation failed") - current_level)
122 + (F::one() - beta) * current_trend;
123 self.trend = Some(new_trend);
124 }
125 }
126
127 if let (Some(gamma), Some(period)) = (self.gamma, self.seasonal_period) {
129 if self.seasonal.len() >= period {
130 let seasonal_idx = (self.observation_count - 1) % period;
131 let current_seasonal = self.seasonal[seasonal_idx];
132 let new_seasonal = gamma * (value - self.level.expect("Operation failed"))
133 + (F::one() - gamma) * current_seasonal;
134 self.seasonal[seasonal_idx] = new_seasonal;
135 }
136 }
137
138 Ok(())
139 }
140
141 pub fn forecast(&self, steps: usize) -> Result<Array1<F>> {
143 if self.level.is_none() {
144 return Err(TimeSeriesError::InvalidModel(
145 "Model not initialized with any data".to_string(),
146 ));
147 }
148
149 let mut forecasts = Array1::zeros(steps);
150 let level = self.level.expect("Operation failed");
151 let trend = self.trend.unwrap_or(F::zero());
152
153 for h in 0..steps {
154 let h_f = F::from(h + 1).expect("Failed to convert to float");
155 let mut forecast = level + trend * h_f;
156
157 if let Some(period) = self.seasonal_period {
159 if !self.seasonal.is_empty() {
160 let seasonal_idx = (self.observation_count + h) % period;
161 if seasonal_idx < self.seasonal.len() {
162 forecast = forecast + self.seasonal[seasonal_idx];
163 }
164 }
165 }
166
167 forecasts[h] = forecast;
168 }
169
170 Ok(forecasts)
171 }
172
173 pub fn get_state(&self) -> ModelState<F> {
175 ModelState {
176 level: self.level,
177 trend: self.trend,
178 seasonal_components: self.seasonal.iter().cloned().collect(),
179 observation_count: self.observation_count,
180 buffer_size: self.buffer.len(),
181 }
182 }
183}
184
185#[derive(Debug, Clone)]
187pub struct ModelState<F: Float> {
188 pub level: Option<F>,
190 pub trend: Option<F>,
192 pub seasonal_components: Vec<F>,
194 pub observation_count: usize,
196 pub buffer_size: usize,
198}