Skip to main content

scirs2_series/state_space/
structural.rs

1//! Structural Time Series Models (Basic Structural Model, Harvey 1989).
2//!
3//! The unobserved components model:
4//!   y_t = μ_t + γ_t + ε_t,  ε_t ~ N(0, σ²_ε)
5//!
6//! Level (local linear trend):
7//!   μ_t = μ_{t-1} + β_{t-1} + η_t,  η_t ~ N(0, σ²_η)
8//!   β_t = β_{t-1} + ζ_t,              ζ_t ~ N(0, σ²_ζ)
9//!
10//! Seasonal (dummy-variable form, period s):
11//!   Σ_{j=0}^{s-1} γ_{t-j} = ω_t,  ω_t ~ N(0, σ²_ω)
12//!
//! The model is cast in state-space form and estimated via the Kalman filter
//! log-likelihood, maximised by coordinate ascent over the log-variance
//! parameters with a bounded golden-section line search along each coordinate.
15
16use super::linear_gaussian::LinearGaussianSSM;
17use crate::error::{Result, TimeSeriesError};
18
19// ---------------------------------------------------------------------------
20// Component structs
21// ---------------------------------------------------------------------------
22
/// Trend part of a structural time series model: a level with an optional
/// stochastic slope.
#[derive(Debug, Clone)]
pub struct TrendComponent {
    /// Current level μ
    pub level: f64,
    /// Current slope β (0 for local-level model)
    pub slope: f64,
    /// Variance σ²_η of the level disturbance
    pub level_var: f64,
    /// Variance σ²_ζ of the slope disturbance (0 means deterministic slope)
    pub slope_var: f64,
}

impl TrendComponent {
    /// Build a trend from its level and slope disturbance variances.
    ///
    /// Level and slope start at 0. `level_var` is floored at `1e-10` and
    /// `slope_var` at `0`.
    pub fn new(level_var: f64, slope_var: f64) -> Self {
        let level_var = f64::max(level_var, 1e-10);
        let slope_var = f64::max(slope_var, 0.0);
        Self {
            level: 0.0,
            slope: 0.0,
            level_var,
            slope_var,
        }
    }

    /// Local level (random walk) trend: slope = 0 and σ²_ζ = 0.
    pub fn local_level(level_var: f64) -> Self {
        let no_slope_var = 0.0;
        Self::new(level_var, no_slope_var)
    }

    /// Number of trend states: 2 when the slope variance is positive
    /// (local linear trend), otherwise 1 (local level).
    pub fn state_dim(&self) -> usize {
        let has_stochastic_slope = self.slope_var > 0.0;
        1 + usize::from(has_stochastic_slope)
    }
}
61
62/// Seasonal component in dummy-variable form.
63#[derive(Debug, Clone)]
64pub struct SeasonalComponent {
65    /// Seasonal period s (must be >= 2)
66    pub period: usize,
67    /// Seasonal state values γ_1, ..., γ_{s-1}
68    pub values: Vec<f64>,
69    /// Seasonal disturbance variance σ²_ω
70    pub var: f64,
71}
72
73impl SeasonalComponent {
74    /// Create a seasonal component with given period and variance.
75    pub fn new(period: usize, var: f64) -> Result<Self> {
76        if period < 2 {
77            return Err(TimeSeriesError::InvalidInput(
78                "Seasonal period must be >= 2".to_string(),
79            ));
80        }
81        Ok(Self {
82            period,
83            values: vec![0.0; period - 1],
84            var: var.max(1e-10),
85        })
86    }
87
88    /// State dimension = period - 1.
89    pub fn state_dim(&self) -> usize {
90        self.period - 1
91    }
92}
93
94// ---------------------------------------------------------------------------
95// StructuralModel
96// ---------------------------------------------------------------------------
97
98/// Basic Structural Model (BSM) combining trend, seasonality, and irregular.
99#[derive(Debug, Clone)]
100pub struct StructuralModel {
101    /// Trend component (level + optional slope)
102    pub trend: TrendComponent,
103    /// Optional seasonal component
104    pub seasonal: Option<SeasonalComponent>,
105    /// Irregular (observation noise) variance σ²_ε
106    pub irregular_var: f64,
107}
108
109impl StructuralModel {
110    /// Create a new structural model.
111    ///
112    /// If `period` is `Some(s)`, a seasonal component is included.
113    pub fn new(period: Option<usize>) -> Result<Self> {
114        let seasonal = match period {
115            Some(s) => Some(SeasonalComponent::new(s, 0.1)?),
116            None => None,
117        };
118        Ok(Self {
119            trend: TrendComponent::new(0.1, 0.01),
120            seasonal,
121            irregular_var: 0.5,
122        })
123    }
124
125    /// Build a local-level model (no slope, no seasonality).
126    pub fn local_level(level_var: f64, obs_var: f64) -> Self {
127        Self {
128            trend: TrendComponent::local_level(level_var),
129            seasonal: None,
130            irregular_var: obs_var.max(1e-10),
131        }
132    }
133
134    /// Build a local linear trend model (with slope, no seasonality).
135    pub fn local_linear_trend(level_var: f64, slope_var: f64, obs_var: f64) -> Self {
136        Self {
137            trend: TrendComponent::new(level_var, slope_var),
138            seasonal: None,
139            irregular_var: obs_var.max(1e-10),
140        }
141    }
142
143    /// Total state dimension.
144    pub fn state_dim(&self) -> usize {
145        let trend_d = self.trend.state_dim();
146        let seas_d = self.seasonal.as_ref().map_or(0, |s| s.state_dim());
147        trend_d + seas_d
148    }
149
150    /// Convert to a `LinearGaussianSSM` in state-space form.
151    ///
152    /// State vector layout:
153    ///   [μ, (β), γ₁, γ₂, ..., γ_{s-1}]
154    pub fn to_ssm(&self) -> LinearGaussianSSM {
155        let n = self.state_dim();
156        let trend_d = self.trend.state_dim();
157        let seas_d = self.seasonal.as_ref().map_or(0, |s| s.state_dim());
158
159        // Transition matrix F
160        let mut f = vec![vec![0.0f64; n]; n];
161
162        // Trend block
163        if self.trend.state_dim() == 1 {
164            // Local level: μ_t = μ_{t-1}
165            f[0][0] = 1.0;
166        } else {
167            // Local linear trend: [μ_t; β_t] = [[1,1],[0,1]] [μ_{t-1}; β_{t-1}]
168            f[0][0] = 1.0;
169            f[0][1] = 1.0;
170            f[1][1] = 1.0;
171        }
172
173        // Seasonal block: dummy-variable form
174        if let Some(seas) = &self.seasonal {
175            let s = seas.state_dim(); // period - 1
176            let off = trend_d;
177            // First row: [-1, -1, ..., -1]
178            for j in 0..s {
179                f[off][off + j] = -1.0;
180            }
181            // Remaining: shift register [I_{s-1} | 0]
182            for i in 1..s {
183                f[off + i][off + i - 1] = 1.0;
184            }
185        }
186
187        // Observation matrix H: [1, 0, (1, 0, ...)]
188        let mut h = vec![vec![0.0f64; n]];
189        h[0][0] = 1.0; // level
190        if seas_d > 0 {
191            h[0][trend_d] = 1.0; // first seasonal state
192        }
193
194        // Process noise covariance Q
195        let mut q = vec![vec![0.0f64; n]; n];
196        q[0][0] = self.trend.level_var;
197        if self.trend.state_dim() == 2 {
198            q[1][1] = self.trend.slope_var;
199        }
200        if let Some(seas) = &self.seasonal {
201            q[trend_d][trend_d] = seas.var;
202        }
203
204        // Measurement noise covariance R
205        let r = vec![vec![self.irregular_var]];
206
207        // Initial state: diffuse
208        let mu0 = vec![0.0f64; n];
209        let mut p0 = vec![vec![0.0f64; n]; n];
210        for i in 0..n {
211            p0[i][i] = 1e6;
212        }
213
214        LinearGaussianSSM {
215            dim_state: n,
216            dim_obs: 1,
217            f_mat: f,
218            h_mat: h,
219            q_mat: q,
220            r_mat: r,
221            mu0,
222            p0,
223        }
224    }
225
226    /// Compute Kalman filter log-likelihood for the given parameter vector.
227    ///
228    /// `params` = [log(σ²_η), log(σ²_ζ)?, log(σ²_ω)?, log(σ²_ε)]
229    fn log_likelihood_from_params(&self, params: &[f64], data: &[f64]) -> f64 {
230        let mut model = self.clone();
231        model.apply_params(params);
232        let ssm = model.to_ssm();
233        let obs: Vec<Vec<f64>> = data.iter().map(|&y| vec![y]).collect();
234        ssm.filter(&obs)
235            .map_or(f64::NEG_INFINITY, |k| k.log_likelihood)
236    }
237
238    /// Apply a parameter vector to update variances.
239    fn apply_params(&mut self, params: &[f64]) {
240        let mut idx = 0;
241        // level variance
242        self.trend.level_var = params[idx].exp().max(1e-10);
243        idx += 1;
244        // slope variance (if present)
245        if self.trend.state_dim() == 2 {
246            self.trend.slope_var = params[idx].exp().max(1e-10);
247            idx += 1;
248        }
249        // seasonal variance (if present)
250        if let Some(seas) = &mut self.seasonal {
251            seas.var = params[idx].exp().max(1e-10);
252            idx += 1;
253        }
254        // obs variance
255        if idx < params.len() {
256            self.irregular_var = params[idx].exp().max(1e-10);
257        }
258    }
259
260    /// Extract initial parameter vector (log-scale) for optimisation.
261    fn initial_params(&self) -> Vec<f64> {
262        let mut p = Vec::new();
263        p.push(self.trend.level_var.max(1e-10).ln());
264        if self.trend.state_dim() == 2 {
265            p.push(self.trend.slope_var.max(1e-10).ln());
266        }
267        if let Some(seas) = &self.seasonal {
268            p.push(seas.var.max(1e-10).ln());
269        }
270        p.push(self.irregular_var.max(1e-10).ln());
271        p
272    }
273
274    /// Fit the model by maximising the Kalman filter log-likelihood.
275    ///
276    /// Uses a simple coordinate-ascent / line-search optimiser on log-scale
277    /// variance parameters (no external optimiser dependency).
278    /// Returns the maximised log-likelihood.
279    pub fn fit(&mut self, data: &[f64]) -> Result<f64> {
280        let n = data.len();
281        if n < 3 {
282            return Err(TimeSeriesError::InsufficientData {
283                message: "StructuralModel::fit requires at least 3 observations".to_string(),
284                required: 3,
285                actual: n,
286            });
287        }
288
289        let mut params = self.initial_params();
290        let np = params.len();
291        let max_outer = 100;
292        let tol = 1e-6;
293
294        let mut best_ll = self.log_likelihood_from_params(&params, data);
295
296        // Nelder-Mead style coordinate ascent on log-variance parameters
297        for _outer in 0..max_outer {
298            let prev_ll = best_ll;
299            for pi in 0..np {
300                // Golden section search along this coordinate
301                let (best_v, best_local) = golden_section_search_1d(
302                    |v| {
303                        let mut p2 = params.clone();
304                        p2[pi] = v;
305                        self.log_likelihood_from_params(&p2, data)
306                    },
307                    params[pi] - 6.0,
308                    params[pi] + 6.0,
309                    30,
310                );
311                if best_local > best_ll {
312                    params[pi] = best_v;
313                    best_ll = best_local;
314                }
315            }
316            if (best_ll - prev_ll).abs() < tol {
317                break;
318            }
319        }
320
321        // Apply final parameters
322        self.apply_params(&params);
323        Ok(best_ll)
324    }
325
326    /// Decompose the series into trend, seasonal, and irregular components.
327    ///
328    /// Returns `(trend, seasonal, irregular)`, each of length T.
329    pub fn decompose(&self, data: &[f64]) -> Result<(Vec<f64>, Vec<f64>, Vec<f64>)> {
330        let n = data.len();
331        if n == 0 {
332            return Ok((vec![], vec![], vec![]));
333        }
334
335        let ssm = self.to_ssm();
336        let obs: Vec<Vec<f64>> = data.iter().map(|&y| vec![y]).collect();
337        let (sm_means, _sm_covs) = ssm.smooth(&obs)?;
338
339        let trend_d = self.trend.state_dim();
340        let seas_d = self.seasonal.as_ref().map_or(0, |s| s.state_dim());
341
342        let mut trend_vec = Vec::with_capacity(n);
343        let mut seas_vec = Vec::with_capacity(n);
344        let mut irreg_vec = Vec::with_capacity(n);
345
346        for t in 0..n {
347            let level = sm_means[t][0];
348            let seas_val = if seas_d > 0 {
349                sm_means[t][trend_d]
350            } else {
351                0.0
352            };
353            let fitted = level + seas_val;
354            let irregular = data[t] - fitted;
355
356            trend_vec.push(level);
357            seas_vec.push(seas_val);
358            irreg_vec.push(irregular);
359        }
360
361        Ok((trend_vec, seas_vec, irreg_vec))
362    }
363}
364
365// ---------------------------------------------------------------------------
366// Golden section search (1D maximisation)
367// ---------------------------------------------------------------------------
368
/// Golden section search for the maximiser of `f` on the interval `[a, b]`.
///
/// The bracket is shrunk `n_iter` times, keeping the side that holds the larger
/// of the two interior values. Returns the midpoint of the final bracket
/// together with `f` evaluated there. `f` is called `n_iter + 3` times in total.
fn golden_section_search_1d<F>(f: F, a: f64, b: f64, n_iter: usize) -> (f64, f64)
where
    F: Fn(f64) -> f64,
{
    let ratio = (5.0_f64.sqrt() - 1.0) / 2.0; // 0.618...
    let (mut left, mut right) = (a, b);

    // Two interior probes, `probe_lo` < `probe_hi` for a < b.
    let mut probe_lo = right - ratio * (right - left);
    let mut probe_hi = left + ratio * (right - left);
    let mut val_lo = f(probe_lo);
    let mut val_hi = f(probe_hi);

    let mut remaining = n_iter;
    while remaining > 0 {
        remaining -= 1;
        if val_lo < val_hi {
            // Maximum lies to the right of `probe_lo`: drop [left, probe_lo)
            // and reuse the upper probe as the new lower probe.
            left = probe_lo;
            (probe_lo, val_lo) = (probe_hi, val_hi);
            probe_hi = left + ratio * (right - left);
            val_hi = f(probe_hi);
        } else {
            // Maximum lies to the left of `probe_hi`: drop (probe_hi, right]
            // and reuse the lower probe as the new upper probe.
            right = probe_hi;
            (probe_hi, val_hi) = (probe_lo, val_lo);
            probe_lo = right - ratio * (right - left);
            val_lo = f(probe_lo);
        }
    }

    let mid = (left + right) / 2.0;
    (mid, f(mid))
}
402
403// ---------------------------------------------------------------------------
404// Tests
405// ---------------------------------------------------------------------------
406
#[cfg(test)]
mod tests {
    use super::*;

    /// Linear trend with a small deterministic sinusoidal wiggle.
    fn trend_data(n: usize) -> Vec<f64> {
        (0..n)
            .map(|i| 1.0 + 0.05 * i as f64 + 0.1 * (i as f64 * 0.7).sin())
            .collect()
    }

    /// Linear trend plus a sinusoidal seasonal pattern of the given period and
    /// a small deterministic perturbation.
    fn seasonal_data(n: usize, period: usize) -> Vec<f64> {
        (0..n)
            .map(|i| {
                let trend = 1.0 + 0.02 * i as f64;
                let seas = (2.0 * std::f64::consts::PI * i as f64 / period as f64).sin();
                trend + seas + 0.05 * (i as f64 * 1.3).cos()
            })
            .collect()
    }

    #[test]
    fn test_local_level_to_ssm() {
        let m = StructuralModel::local_level(0.1, 0.5);
        assert_eq!(m.state_dim(), 1);
        let ssm = m.to_ssm();
        assert_eq!(ssm.dim_state, 1);
        assert_eq!(ssm.dim_obs, 1);
        assert_eq!(ssm.f_mat[0][0], 1.0);
        assert_eq!(ssm.h_mat, vec![vec![1.0]]);
        assert!((ssm.q_mat[0][0] - 0.1).abs() < 1e-10);
        assert!((ssm.r_mat[0][0] - 0.5).abs() < 1e-10);
    }

    #[test]
    fn test_local_linear_trend_to_ssm() {
        let m = StructuralModel::local_linear_trend(0.1, 0.01, 0.5);
        assert_eq!(m.state_dim(), 2);
        let ssm = m.to_ssm();
        // F = [[1,1],[0,1]]
        assert_eq!(ssm.f_mat[0][0], 1.0);
        assert_eq!(ssm.f_mat[0][1], 1.0);
        assert_eq!(ssm.f_mat[1][0], 0.0);
        assert_eq!(ssm.f_mat[1][1], 1.0);
        // Only the level is observed.
        assert_eq!(ssm.h_mat, vec![vec![1.0, 0.0]]);
        // Q = diag(level_var, slope_var)
        assert!((ssm.q_mat[0][0] - 0.1).abs() < 1e-10);
        assert!((ssm.q_mat[1][1] - 0.01).abs() < 1e-10);
        assert_eq!(ssm.q_mat[0][1], 0.0);
        assert_eq!(ssm.q_mat[1][0], 0.0);
    }

    #[test]
    fn test_seasonal_to_ssm() {
        let m = StructuralModel::new(Some(4)).expect("ok");
        let ssm = m.to_ssm();
        // new() sets slope_var = 0.01 > 0, so the trend block has 2 states;
        // period 4 adds 3 seasonal states, for 5 in total.
        assert_eq!(m.state_dim(), 5);
        assert_eq!(ssm.dim_state, 5);
        // Seasonal block: first row all -1, then a shift register.
        assert_eq!(ssm.f_mat[2], vec![0.0, 0.0, -1.0, -1.0, -1.0]);
        assert_eq!(ssm.f_mat[3], vec![0.0, 0.0, 1.0, 0.0, 0.0]);
        assert_eq!(ssm.f_mat[4], vec![0.0, 0.0, 0.0, 1.0, 0.0]);
        // Observation picks the level and the first seasonal state.
        assert_eq!(ssm.h_mat, vec![vec![1.0, 0.0, 1.0, 0.0, 0.0]]);
        // Seasonal disturbance enters only the first seasonal state.
        assert!((ssm.q_mat[2][2] - 0.1).abs() < 1e-10);
        assert_eq!(ssm.q_mat[3][3], 0.0);
        assert_eq!(ssm.q_mat[4][4], 0.0);
    }

    #[test]
    fn test_decompose_local_level() {
        let data = trend_data(40);
        let m = StructuralModel::local_level(0.2, 0.1);
        let (trend, seas, irreg) = m.decompose(&data).expect("decompose ok");
        assert_eq!(trend.len(), 40);
        assert_eq!(seas.len(), 40);
        assert_eq!(irreg.len(), 40);
        // Seasonal should be zero (no seasonal component)
        for &s in &seas {
            assert_eq!(s, 0.0);
        }
        // Reconstruction check: trend + seasonal + irregular ≈ data
        for i in 0..40 {
            let recon = trend[i] + seas[i] + irreg[i];
            assert!(
                (recon - data[i]).abs() < 1e-6,
                "Reconstruction failed at {i}"
            );
        }
    }

    #[test]
    fn test_decompose_empty_input() {
        let m = StructuralModel::local_level(0.2, 0.1);
        let (trend, seas, irreg) = m.decompose(&[]).expect("decompose ok");
        assert!(trend.is_empty());
        assert!(seas.is_empty());
        assert!(irreg.is_empty());
    }

    #[test]
    fn test_decompose_seasonal() {
        let data = seasonal_data(48, 4);
        let m = StructuralModel::new(Some(4)).expect("ok");
        let (trend, seas, irreg) = m.decompose(&data).expect("decompose ok");
        assert_eq!(trend.len(), 48);
        assert_eq!(seas.len(), 48);
        assert_eq!(irreg.len(), 48);
    }

    #[test]
    fn test_fit_level_extraction() {
        // Constant level plus a small deterministic wiggle: the estimated
        // level should stay near the constant.
        let data: Vec<f64> = (0..30)
            .map(|i| 5.0 + 0.1 * ((i as f64) * 1.23).sin())
            .collect();
        let mut m = StructuralModel::local_level(0.05, 0.2);
        let ll = m.fit(&data).expect("fit ok");
        assert!(ll.is_finite());

        let (trend, _seas, _irreg) = m.decompose(&data).expect("decompose ok");
        // After fitting, the smoothed level should stay close to 5.0
        let level_mean: f64 = trend[10..30].iter().sum::<f64>() / 20.0;
        assert!(
            (level_mean - 5.0).abs() < 1.0,
            "Level mean {level_mean} far from 5.0"
        );
    }

    #[test]
    fn test_fit_rejects_short_series() {
        let mut m = StructuralModel::local_level(0.05, 0.2);
        assert!(m.fit(&[1.0, 2.0]).is_err());
        assert!(m.fit(&[]).is_err());
    }

    #[test]
    fn test_seasonal_component_creation() {
        let s = SeasonalComponent::new(12, 0.05).expect("ok");
        assert_eq!(s.period, 12);
        assert_eq!(s.state_dim(), 11);
        assert_eq!(s.values.len(), 11);
    }

    #[test]
    fn test_invalid_period_is_rejected() {
        assert!(SeasonalComponent::new(0, 0.05).is_err());
        assert!(SeasonalComponent::new(1, 0.05).is_err());
        assert!(StructuralModel::new(Some(1)).is_err());
    }

    #[test]
    fn test_new_with_period() {
        let m = StructuralModel::new(Some(7)).expect("ok");
        assert!(m.seasonal.is_some());
        let seas = m.seasonal.as_ref().expect("some");
        assert_eq!(seas.period, 7);
        assert_eq!(seas.state_dim(), 6);
    }

    #[test]
    fn test_new_without_period() {
        let m = StructuralModel::new(None).expect("ok");
        assert!(m.seasonal.is_none());
        // Default trend has a positive slope variance, hence two states.
        assert_eq!(m.state_dim(), 2);
    }
}