scirs2_series/state_space/
structural.rs1use super::linear_gaussian::LinearGaussianSSM;
17use crate::error::{Result, TimeSeriesError};
18
/// Trend part of a structural model: a level and, optionally, a slope.
///
/// When `slope_var` is strictly positive the trend carries two states
/// (level and slope); otherwise only the level state is present.
#[derive(Clone, Debug)]
pub struct TrendComponent {
    /// Level state value (the constructors initialise it to `0.0`).
    pub level: f64,
    /// Slope state value (the constructors initialise it to `0.0`).
    pub slope: f64,
    /// Variance of the level disturbance.
    pub level_var: f64,
    /// Variance of the slope disturbance; `0.0` disables the slope state.
    pub slope_var: f64,
}
35
36impl TrendComponent {
37 pub fn new(level_var: f64, slope_var: f64) -> Self {
39 Self {
40 level: 0.0,
41 slope: 0.0,
42 level_var: level_var.max(1e-10),
43 slope_var: slope_var.max(0.0),
44 }
45 }
46
47 pub fn local_level(level_var: f64) -> Self {
49 Self::new(level_var, 0.0)
50 }
51
52 pub fn state_dim(&self) -> usize {
54 if self.slope_var > 0.0 {
55 2
56 } else {
57 1
58 }
59 }
60}
61
/// Dummy-variable seasonal component of a structural model.
///
/// A component with period `s` is represented by `s - 1` latent states.
#[derive(Clone, Debug)]
pub struct SeasonalComponent {
    /// Seasonal period (at least `2` when built through `SeasonalComponent::new`).
    pub period: usize,
    /// Seasonal state values; `SeasonalComponent::new` sizes this to `period - 1` zeros.
    pub values: Vec<f64>,
    /// Variance of the seasonal disturbance.
    pub var: f64,
}
72
73impl SeasonalComponent {
74 pub fn new(period: usize, var: f64) -> Result<Self> {
76 if period < 2 {
77 return Err(TimeSeriesError::InvalidInput(
78 "Seasonal period must be >= 2".to_string(),
79 ));
80 }
81 Ok(Self {
82 period,
83 values: vec![0.0; period - 1],
84 var: var.max(1e-10),
85 })
86 }
87
88 pub fn state_dim(&self) -> usize {
90 self.period - 1
91 }
92}
93
94#[derive(Debug, Clone)]
100pub struct StructuralModel {
101 pub trend: TrendComponent,
103 pub seasonal: Option<SeasonalComponent>,
105 pub irregular_var: f64,
107}
108
109impl StructuralModel {
110 pub fn new(period: Option<usize>) -> Result<Self> {
114 let seasonal = match period {
115 Some(s) => Some(SeasonalComponent::new(s, 0.1)?),
116 None => None,
117 };
118 Ok(Self {
119 trend: TrendComponent::new(0.1, 0.01),
120 seasonal,
121 irregular_var: 0.5,
122 })
123 }
124
125 pub fn local_level(level_var: f64, obs_var: f64) -> Self {
127 Self {
128 trend: TrendComponent::local_level(level_var),
129 seasonal: None,
130 irregular_var: obs_var.max(1e-10),
131 }
132 }
133
134 pub fn local_linear_trend(level_var: f64, slope_var: f64, obs_var: f64) -> Self {
136 Self {
137 trend: TrendComponent::new(level_var, slope_var),
138 seasonal: None,
139 irregular_var: obs_var.max(1e-10),
140 }
141 }
142
143 pub fn state_dim(&self) -> usize {
145 let trend_d = self.trend.state_dim();
146 let seas_d = self.seasonal.as_ref().map_or(0, |s| s.state_dim());
147 trend_d + seas_d
148 }
149
150 pub fn to_ssm(&self) -> LinearGaussianSSM {
155 let n = self.state_dim();
156 let trend_d = self.trend.state_dim();
157 let seas_d = self.seasonal.as_ref().map_or(0, |s| s.state_dim());
158
159 let mut f = vec![vec![0.0f64; n]; n];
161
162 if self.trend.state_dim() == 1 {
164 f[0][0] = 1.0;
166 } else {
167 f[0][0] = 1.0;
169 f[0][1] = 1.0;
170 f[1][1] = 1.0;
171 }
172
173 if let Some(seas) = &self.seasonal {
175 let s = seas.state_dim(); let off = trend_d;
177 for j in 0..s {
179 f[off][off + j] = -1.0;
180 }
181 for i in 1..s {
183 f[off + i][off + i - 1] = 1.0;
184 }
185 }
186
187 let mut h = vec![vec![0.0f64; n]];
189 h[0][0] = 1.0; if seas_d > 0 {
191 h[0][trend_d] = 1.0; }
193
194 let mut q = vec![vec![0.0f64; n]; n];
196 q[0][0] = self.trend.level_var;
197 if self.trend.state_dim() == 2 {
198 q[1][1] = self.trend.slope_var;
199 }
200 if let Some(seas) = &self.seasonal {
201 q[trend_d][trend_d] = seas.var;
202 }
203
204 let r = vec![vec![self.irregular_var]];
206
207 let mu0 = vec![0.0f64; n];
209 let mut p0 = vec![vec![0.0f64; n]; n];
210 for i in 0..n {
211 p0[i][i] = 1e6;
212 }
213
214 LinearGaussianSSM {
215 dim_state: n,
216 dim_obs: 1,
217 f_mat: f,
218 h_mat: h,
219 q_mat: q,
220 r_mat: r,
221 mu0,
222 p0,
223 }
224 }
225
226 fn log_likelihood_from_params(&self, params: &[f64], data: &[f64]) -> f64 {
230 let mut model = self.clone();
231 model.apply_params(params);
232 let ssm = model.to_ssm();
233 let obs: Vec<Vec<f64>> = data.iter().map(|&y| vec![y]).collect();
234 ssm.filter(&obs)
235 .map_or(f64::NEG_INFINITY, |k| k.log_likelihood)
236 }
237
238 fn apply_params(&mut self, params: &[f64]) {
240 let mut idx = 0;
241 self.trend.level_var = params[idx].exp().max(1e-10);
243 idx += 1;
244 if self.trend.state_dim() == 2 {
246 self.trend.slope_var = params[idx].exp().max(1e-10);
247 idx += 1;
248 }
249 if let Some(seas) = &mut self.seasonal {
251 seas.var = params[idx].exp().max(1e-10);
252 idx += 1;
253 }
254 if idx < params.len() {
256 self.irregular_var = params[idx].exp().max(1e-10);
257 }
258 }
259
260 fn initial_params(&self) -> Vec<f64> {
262 let mut p = Vec::new();
263 p.push(self.trend.level_var.max(1e-10).ln());
264 if self.trend.state_dim() == 2 {
265 p.push(self.trend.slope_var.max(1e-10).ln());
266 }
267 if let Some(seas) = &self.seasonal {
268 p.push(seas.var.max(1e-10).ln());
269 }
270 p.push(self.irregular_var.max(1e-10).ln());
271 p
272 }
273
274 pub fn fit(&mut self, data: &[f64]) -> Result<f64> {
280 let n = data.len();
281 if n < 3 {
282 return Err(TimeSeriesError::InsufficientData {
283 message: "StructuralModel::fit requires at least 3 observations".to_string(),
284 required: 3,
285 actual: n,
286 });
287 }
288
289 let mut params = self.initial_params();
290 let np = params.len();
291 let max_outer = 100;
292 let tol = 1e-6;
293
294 let mut best_ll = self.log_likelihood_from_params(¶ms, data);
295
296 for _outer in 0..max_outer {
298 let prev_ll = best_ll;
299 for pi in 0..np {
300 let (best_v, best_local) = golden_section_search_1d(
302 |v| {
303 let mut p2 = params.clone();
304 p2[pi] = v;
305 self.log_likelihood_from_params(&p2, data)
306 },
307 params[pi] - 6.0,
308 params[pi] + 6.0,
309 30,
310 );
311 if best_local > best_ll {
312 params[pi] = best_v;
313 best_ll = best_local;
314 }
315 }
316 if (best_ll - prev_ll).abs() < tol {
317 break;
318 }
319 }
320
321 self.apply_params(¶ms);
323 Ok(best_ll)
324 }
325
326 pub fn decompose(&self, data: &[f64]) -> Result<(Vec<f64>, Vec<f64>, Vec<f64>)> {
330 let n = data.len();
331 if n == 0 {
332 return Ok((vec![], vec![], vec![]));
333 }
334
335 let ssm = self.to_ssm();
336 let obs: Vec<Vec<f64>> = data.iter().map(|&y| vec![y]).collect();
337 let (sm_means, _sm_covs) = ssm.smooth(&obs)?;
338
339 let trend_d = self.trend.state_dim();
340 let seas_d = self.seasonal.as_ref().map_or(0, |s| s.state_dim());
341
342 let mut trend_vec = Vec::with_capacity(n);
343 let mut seas_vec = Vec::with_capacity(n);
344 let mut irreg_vec = Vec::with_capacity(n);
345
346 for t in 0..n {
347 let level = sm_means[t][0];
348 let seas_val = if seas_d > 0 {
349 sm_means[t][trend_d]
350 } else {
351 0.0
352 };
353 let fitted = level + seas_val;
354 let irregular = data[t] - fitted;
355
356 trend_vec.push(level);
357 seas_vec.push(seas_val);
358 irreg_vec.push(irregular);
359 }
360
361 Ok((trend_vec, seas_vec, irreg_vec))
362 }
363}
364
/// Maximise a (presumed unimodal) scalar function on `[a, b]` by golden-section search.
///
/// Runs `n_iter` bracket reductions, each shrinking the bracket by a factor
/// of about 0.618. Values of `f` that are NaN are ranked as `-inf`, so a
/// numerically broken region never pulls the bracket away from finite values.
///
/// Returns `(x, f(x))` for the best point among the bracket midpoint and the
/// two final interior points. The midpoint wins ties.
fn golden_section_search_1d<F>(f: F, a: f64, b: f64, n_iter: usize) -> (f64, f64)
where
    F: Fn(f64) -> f64,
{
    // Ordering key that treats NaN as the worst possible value.
    let rank = |v: f64| if v.is_nan() { f64::NEG_INFINITY } else { v };
    // 1/phi = (sqrt(5) - 1) / 2 ≈ 0.618.
    let inv_phi = (5.0_f64.sqrt() - 1.0) / 2.0;

    let mut lo = a;
    let mut hi = b;
    let mut x1 = hi - inv_phi * (hi - lo);
    let mut x2 = lo + inv_phi * (hi - lo);
    let mut f1 = f(x1);
    let mut f2 = f(x2);

    for _ in 0..n_iter {
        if rank(f1) < rank(f2) {
            // Maximum lies in [x1, hi]; the old x2 becomes the new x1.
            lo = x1;
            x1 = x2;
            f1 = f2;
            x2 = lo + inv_phi * (hi - lo);
            f2 = f(x2);
        } else {
            // Maximum lies in [lo, x2]; the old x1 becomes the new x2.
            hi = x2;
            x2 = x1;
            f2 = f1;
            x1 = hi - inv_phi * (hi - lo);
            f1 = f(x1);
        }
    }

    let mid = (lo + hi) / 2.0;
    let f_mid = f(mid);
    [(x1, f1), (x2, f2)]
        .iter()
        .copied()
        .fold((mid, f_mid), |best, cand| {
            if rank(cand.1) > rank(best.1) {
                cand
            } else {
                best
            }
        })
}
402
#[cfg(test)]
mod tests {
    use super::*;

    /// Gentle upward trend with a small sinusoidal wiggle.
    fn trend_data(n: usize) -> Vec<f64> {
        (0..n)
            .map(|i| 1.0 + 0.05 * i as f64 + 0.1 * (i as f64 * 0.7).sin())
            .collect()
    }

    /// Slow trend plus a unit-amplitude sine of the given period and a small cosine.
    fn seasonal_data(n: usize, period: usize) -> Vec<f64> {
        (0..n)
            .map(|i| {
                let trend = 1.0 + 0.02 * i as f64;
                let seas = (2.0 * std::f64::consts::PI * i as f64 / period as f64).sin();
                trend + seas + 0.05 * (i as f64 * 1.3).cos()
            })
            .collect()
    }

    #[test]
    fn test_local_level_to_ssm() {
        let m = StructuralModel::local_level(0.1, 0.5);
        assert_eq!(m.state_dim(), 1);
        let ssm = m.to_ssm();
        assert_eq!(ssm.dim_state, 1);
        assert_eq!(ssm.f_mat[0][0], 1.0);
        assert!((ssm.q_mat[0][0] - 0.1).abs() < 1e-10);
        assert!((ssm.r_mat[0][0] - 0.5).abs() < 1e-10);
    }

    #[test]
    fn test_local_linear_trend_to_ssm() {
        let m = StructuralModel::local_linear_trend(0.1, 0.01, 0.5);
        assert_eq!(m.state_dim(), 2);
        let ssm = m.to_ssm();
        // level_{t+1} = level_t + slope_t ; slope_{t+1} = slope_t
        assert_eq!(ssm.f_mat[0][0], 1.0);
        assert_eq!(ssm.f_mat[0][1], 1.0);
        assert_eq!(ssm.f_mat[1][0], 0.0);
        assert_eq!(ssm.f_mat[1][1], 1.0);
    }

    #[test]
    fn test_seasonal_to_ssm() {
        // Default model: 2 trend states + (4 - 1) seasonal states.
        let m = StructuralModel::new(Some(4)).expect("ok");
        let ssm = m.to_ssm();
        assert_eq!(ssm.dim_state, m.state_dim());
        assert_eq!(ssm.dim_state, 5);
        // First seasonal row: minus the sum of the previous seasonal states.
        assert_eq!(ssm.f_mat[2][2..5], [-1.0, -1.0, -1.0]);
        // Remaining seasonal rows shift the states down by one.
        assert_eq!(ssm.f_mat[3][2], 1.0);
        assert_eq!(ssm.f_mat[4][3], 1.0);
        // Observation = level + current seasonal state.
        assert_eq!(ssm.h_mat[0], vec![1.0, 0.0, 1.0, 0.0, 0.0]);
        // Only the current seasonal state is disturbed.
        assert!((ssm.q_mat[2][2] - 0.1).abs() < 1e-12);
        assert_eq!(ssm.q_mat[3][3], 0.0);
        assert_eq!(ssm.q_mat[4][4], 0.0);
    }

    #[test]
    fn test_decompose_local_level() {
        let data = trend_data(40);
        let m = StructuralModel::local_level(0.2, 0.1);
        let (trend, seas, irreg) = m.decompose(&data).expect("decompose ok");
        assert_eq!(trend.len(), 40);
        assert_eq!(seas.len(), 40);
        assert_eq!(irreg.len(), 40);
        // No seasonal component => seasonal series is identically zero.
        assert!(seas.iter().all(|&s| s == 0.0));
        // The three parts must add back up to the data.
        for (i, &y) in data.iter().enumerate() {
            let recon = trend[i] + seas[i] + irreg[i];
            assert!((recon - y).abs() < 1e-6, "Reconstruction failed at {i}");
        }
    }

    #[test]
    fn test_decompose_seasonal() {
        let data = seasonal_data(48, 4);
        let m = StructuralModel::new(Some(4)).expect("ok");
        let (trend, _seas, irreg) = m.decompose(&data).expect("decompose ok");
        assert_eq!(trend.len(), 48);
        assert_eq!(irreg.len(), 48);
    }

    #[test]
    fn test_decompose_empty_input() {
        let m = StructuralModel::local_level(0.1, 0.5);
        let (trend, seas, irreg) = m.decompose(&[]).expect("decompose ok");
        assert!(trend.is_empty() && seas.is_empty() && irreg.is_empty());
    }

    #[test]
    fn test_fit_level_extraction() {
        // Nearly constant series around 5.0.
        let data: Vec<f64> = (0..30)
            .map(|i| 5.0 + 0.1 * ((i as f64) * 1.23).sin())
            .collect();
        let mut m = StructuralModel::local_level(0.05, 0.2);
        let ll = m.fit(&data).expect("fit ok");
        assert!(ll.is_finite());

        let (trend, _seas, _irreg) = m.decompose(&data).expect("decompose ok");
        // Skip the diffuse-prior burn-in before averaging.
        let level_mean: f64 = trend[10..30].iter().sum::<f64>() / 20.0;
        assert!(
            (level_mean - 5.0).abs() < 1.0,
            "Level mean {level_mean} far from 5.0"
        );
    }

    #[test]
    fn test_golden_section_maximises_quadratic() {
        let (x, fx) = golden_section_search_1d(|x: f64| -(x - 1.5).powi(2), -5.0, 5.0, 50);
        assert!((x - 1.5).abs() < 1e-6, "argmax {x} far from 1.5");
        assert!(fx <= 0.0 && fx > -1e-10);
    }

    #[test]
    fn test_seasonal_component_creation() {
        let s = SeasonalComponent::new(12, 0.05).expect("ok");
        assert_eq!(s.period, 12);
        assert_eq!(s.state_dim(), 11);
        assert_eq!(s.values.len(), 11);
    }

    #[test]
    fn test_new_with_period() {
        let m = StructuralModel::new(Some(7)).expect("ok");
        let seas = m.seasonal.as_ref().expect("some");
        assert_eq!(seas.period, 7);
        assert_eq!(seas.state_dim(), 6);
    }

    #[test]
    fn test_new_without_period() {
        let m = StructuralModel::new(None).expect("ok");
        assert!(m.seasonal.is_none());
    }
}