Skip to main content

so_tsa/
statespace.rs

//! State space models and Kalman filter
//!
//! This module implements state space models for time series analysis,
//! including the Kalman filter for estimation and forecasting.
//!
//! # State Space Representation
//!
//! General linear Gaussian state space model:
//!
//! ```text
//! Observation equation: yₜ = Zₜ αₜ + εₜ,          εₜ ∼ N(0, Hₜ)
//! State equation:       αₜ = Tₜ αₜ₋₁ + Rₜ ηₜ,     ηₜ ∼ N(0, Qₜ)
//! ```
//!
//! where:
//! - yₜ: observed time series
//! - αₜ: unobserved state vector
//! - Zₜ: observation matrix
//! - Tₜ: transition matrix
//! - Rₜ: selection matrix for state disturbances
//! - Hₜ: observation covariance matrix
//! - Qₜ: state disturbance covariance matrix
//!
//! # Common Models
//!
//! 1. **Local Level Model**: yₜ = μₜ + εₜ, μₜ = μₜ₋₁ + ηₜ
//! 2. **Local Linear Trend**: yₜ = μₜ + εₜ, μₜ = μₜ₋₁ + νₜ₋₁ + ηₜ, νₜ = νₜ₋₁ + ζₜ
//! 3. **Basic Structural Model**: Adds seasonal components
//! 4. **ARMA in State Space**: Any ARMA model can be represented in state space
28
29use ndarray::{Array1, Array2, Array3};
30use serde::{Deserialize, Serialize};
31use so_core::error::Result;
32use so_linalg;
33
34/// State space model specification
35#[derive(Debug, Clone)]
36pub struct StateSpaceModel {
37    /// Observation matrix Z (n_obs × n_states)
38    pub observation_matrix: Array2<f64>,
39    /// Transition matrix T (n_states × n_states)
40    pub transition_matrix: Array2<f64>,
41    /// Selection matrix R (n_states × n_disturbances)
42    pub selection_matrix: Array2<f64>,
43    /// Observation covariance H (n_obs × n_obs)
44    pub observation_cov: Array2<f64>,
45    /// State disturbance covariance Q (n_disturbances × n_disturbances)
46    pub state_cov: Array2<f64>,
47    /// Initial state mean α₀ (n_states)
48    pub initial_state_mean: Array1<f64>,
49    /// Initial state covariance P₀ (n_states × n_states)
50    pub initial_state_cov: Array2<f64>,
51}
52
53/// Kalman filter results
54#[derive(Debug, Clone, Serialize, Deserialize)]
55pub struct KalmanFilterResults {
56    /// Filtered state means (n_timesteps × n_states)
57    pub filtered_state_means: Array2<f64>,
58    /// Filtered state covariances (n_timesteps × n_states × n_states)
59    pub filtered_state_covs: Array3<f64>,
60    /// Predicted state means (n_timesteps × n_states)
61    pub predicted_state_means: Array2<f64>,
62    /// Predicted state covariances (n_timesteps × n_states × n_states)
63    pub predicted_state_covs: Array3<f64>,
64    /// Innovations (prediction errors)
65    pub innovations: Array1<f64>,
66    /// Innovation variances
67    pub innovation_variances: Array1<f64>,
68    /// Kalman gains (n_timesteps × n_states × n_obs)
69    pub kalman_gains: Array3<f64>,
70    /// Log-likelihood
71    pub log_likelihood: f64,
72}
73
74impl StateSpaceModel {
75    /// Create local level model (random walk plus noise)
76    pub fn local_level(obs_var: f64, level_var: f64) -> Self {
77        // y_t = μ_t + ε_t, ε_t ~ N(0, σ_ε²)
78        // μ_t = μ_{t-1} + η_t, η_t ~ N(0, σ_η²)
79
80        let observation_matrix = ndarray::array![[1.0]];
81        let transition_matrix = ndarray::array![[1.0]];
82        let selection_matrix = ndarray::array![[1.0]];
83        let observation_cov = ndarray::array![[obs_var]];
84        let state_cov = ndarray::array![[level_var]];
85        let initial_state_mean = ndarray::array![0.0];
86        let initial_state_cov = ndarray::array![[1e6]]; // Diffuse prior
87
88        Self {
89            observation_matrix,
90            transition_matrix,
91            selection_matrix,
92            observation_cov,
93            state_cov,
94            initial_state_mean,
95            initial_state_cov,
96        }
97    }
98
99    /// Create local linear trend model
100    pub fn local_linear_trend(obs_var: f64, level_var: f64, slope_var: f64) -> Self {
101        // y_t = μ_t + ε_t
102        // μ_t = μ_{t-1} + ν_{t-1} + η_t
103        // ν_t = ν_{t-1} + ζ_t
104
105        let observation_matrix = ndarray::array![[1.0, 0.0]];
106        let transition_matrix = ndarray::array![[1.0, 1.0], [0.0, 1.0]];
107        let selection_matrix = ndarray::array![[1.0, 0.0], [0.0, 1.0]];
108        let observation_cov = ndarray::array![[obs_var]];
109        let state_cov = ndarray::array![[level_var, 0.0], [0.0, slope_var]];
110        let initial_state_mean = ndarray::array![0.0, 0.0];
111        let initial_state_cov = ndarray::array![[1e6, 0.0], [0.0, 1e6]];
112
113        Self {
114            observation_matrix,
115            transition_matrix,
116            selection_matrix,
117            observation_cov,
118            state_cov,
119            initial_state_mean,
120            initial_state_cov,
121        }
122    }
123
124    /// Create ARMA(p, q) model in state space form
125    pub fn arma(ar_coef: &[f64], ma_coef: &[f64], sigma2: f64) -> Self {
126        let p = ar_coef.len();
127        let q = ma_coef.len();
128        let r = p.max(q + 1);
129        let n_states = r;
130
131        // Build transition matrix (companion form)
132        let mut transition = Array2::zeros((n_states, n_states));
133
134        if p > 0 {
135            // First row contains AR coefficients
136            for j in 0..p {
137                transition[(0, j)] = ar_coef[j];
138            }
139        }
140
141        // Sub-diagonal of ones
142        for i in 1..n_states {
143            transition[(i, i - 1)] = 1.0;
144        }
145
146        // Observation matrix
147        let mut observation = Array1::zeros(n_states);
148        observation[0] = 1.0;
149        if q > 0 {
150            // Include MA coefficients
151            for j in 0..q.min(n_states - 1) {
152                observation[j + 1] = ma_coef[j];
153            }
154        }
155        let observation_matrix = observation.insert_axis(ndarray::Axis(0));
156
157        // Selection matrix (for state disturbances)
158        let mut selection = Array2::zeros((n_states, 1));
159        selection[(0, 0)] = 1.0;
160        for j in 1..q.min(n_states - 1) {
161            selection[(j, 0)] = ma_coef[j - 1];
162        }
163
164        // Covariance matrices
165        let observation_cov = ndarray::array![[0.0]]; // No measurement error in standard ARMA
166        let state_cov = ndarray::array![[sigma2]];
167
168        // Initial state (diffuse)
169        let initial_state_mean = Array1::zeros(n_states);
170        let mut initial_state_cov = Array2::zeros((n_states, n_states));
171        for i in 0..n_states {
172            initial_state_cov[(i, i)] = 1e6;
173        }
174
175        Self {
176            observation_matrix,
177            transition_matrix: transition,
178            selection_matrix: selection,
179            observation_cov,
180            state_cov,
181            initial_state_mean,
182            initial_state_cov,
183        }
184    }
185
186    /// Apply Kalman filter to time series
187    pub fn filter(&self, y: &Array1<f64>) -> Result<KalmanFilterResults> {
188        let n = y.len();
189        let n_states = self.observation_matrix.ncols();
190        let n_obs = self.observation_matrix.nrows();
191
192        // Initialize arrays
193        let mut filtered_means = Array2::zeros((n, n_states));
194        let mut filtered_covs = Array3::zeros((n, n_states, n_states));
195        let mut predicted_means = Array2::zeros((n, n_states));
196        let mut predicted_covs = Array3::zeros((n, n_states, n_states));
197        let mut innovations = Array1::zeros(n);
198        let mut innovation_variances = Array1::zeros(n);
199        let mut kalman_gains = Array3::zeros((n, n_states, n_obs));
200
201        let mut log_likelihood = 0.0;
202
203        // Initial prediction (t = 0)
204        let mut pred_mean = self.initial_state_mean.clone();
205        let mut pred_cov = self.initial_state_cov.clone();
206
207        for t in 0..n {
208            // Store prediction
209            predicted_means.row_mut(t).assign(&pred_mean);
210            predicted_covs
211                .slice_mut(ndarray::s![t, .., ..])
212                .assign(&pred_cov);
213
214            // Innovation (prediction error)
215            let obs_pred = self.observation_matrix.dot(&pred_mean);
216            let innovation = y[t] - obs_pred[0];
217            innovations[t] = innovation;
218
219            // Innovation variance
220            let innovation_var = self
221                .observation_matrix
222                .dot(&pred_cov.dot(&self.observation_matrix.t()))
223                + &self.observation_cov;
224            let innovation_var_scalar = innovation_var[(0, 0)];
225            innovation_variances[t] = innovation_var_scalar;
226
227            // Log-likelihood contribution (ignoring constant)
228            if innovation_var_scalar > 0.0 {
229                log_likelihood += -0.5 * innovation_var_scalar.ln()
230                    - 0.5 * innovation.powi(2) / innovation_var_scalar;
231            }
232
233            // Kalman gain
234            let kalman_gain = if innovation_var_scalar > 0.0 {
235                pred_cov.dot(&self.observation_matrix.t()) / innovation_var_scalar
236            } else {
237                Array2::zeros((n_states, n_obs))
238            };
239
240            kalman_gains
241                .slice_mut(ndarray::s![t, .., ..])
242                .assign(&kalman_gain);
243
244            // Filtered state estimate
245            let filtered_mean = &pred_mean + kalman_gain.dot(&ndarray::array![innovation]);
246            let filtered_cov = &pred_cov - kalman_gain.dot(&self.observation_matrix.dot(&pred_cov));
247
248            // Store filtered estimates
249            filtered_means.row_mut(t).assign(&filtered_mean);
250            filtered_covs
251                .slice_mut(ndarray::s![t, .., ..])
252                .assign(&filtered_cov);
253
254            // Predict next state
255            if t < n - 1 {
256                pred_mean = self.transition_matrix.dot(&filtered_mean);
257                pred_cov = self
258                    .transition_matrix
259                    .dot(&filtered_cov.dot(&self.transition_matrix.t()))
260                    + self
261                        .selection_matrix
262                        .dot(&self.state_cov.dot(&self.selection_matrix.t()));
263            }
264        }
265
266        Ok(KalmanFilterResults {
267            filtered_state_means: filtered_means,
268            filtered_state_covs: filtered_covs,
269            predicted_state_means: predicted_means,
270            predicted_state_covs: predicted_covs,
271            innovations,
272            innovation_variances,
273            kalman_gains,
274            log_likelihood,
275        })
276    }
277
278    /// Apply Kalman smoother (Rauch-Tung-Striebel smoother)
279    pub fn smooth(&self, filter_results: &KalmanFilterResults) -> KalmanFilterResults {
280        let n = filter_results.filtered_state_means.nrows();
281        let n_states = self.observation_matrix.ncols();
282
283        // Initialize smoothed arrays
284        let mut smoothed_means = filter_results.filtered_state_means.clone();
285        let mut smoothed_covs = filter_results.filtered_state_covs.clone();
286
287        // Start from last time point
288        let mut smoother_gain = Array2::zeros((n_states, n_states));
289
290        for t in (0..n - 1).rev() {
291            // Smoother gain
292            let pred_cov = filter_results
293                .predicted_state_covs
294                .slice(ndarray::s![t + 1, .., ..]);
295            let filtered_cov = filter_results
296                .filtered_state_covs
297                .slice(ndarray::s![t, .., ..]);
298
299            let pred_cov_inv =
300                so_linalg::inv(&pred_cov.to_owned()).unwrap_or_else(|_| pred_cov.to_owned());
301            smoother_gain.assign(
302                &filtered_cov
303                    .dot(&self.transition_matrix.t())
304                    .dot(&pred_cov_inv),
305            );
306
307            // Smoothed state
308            let filtered_mean = filter_results.filtered_state_means.row(t);
309            let next_smoothed_mean = smoothed_means.row(t + 1);
310            let next_pred_mean = filter_results.predicted_state_means.row(t + 1);
311
312            let mut smoothed_mean = filtered_mean.to_owned();
313            smoothed_mean += &smoother_gain.dot(&(&next_smoothed_mean - &next_pred_mean));
314
315            // Smoothed covariance
316            let next_smoothed_cov = smoothed_covs.slice(ndarray::s![t + 1, .., ..]);
317            let next_pred_cov =
318                filter_results
319                    .predicted_state_covs
320                    .slice(ndarray::s![t + 1, .., ..]);
321
322            let mut smoothed_cov = filtered_cov.to_owned();
323            let diff_cov = &next_smoothed_cov - &next_pred_cov;
324            smoothed_cov += &smoother_gain.dot(&diff_cov.dot(&smoother_gain.t()));
325
326            // Store results
327            smoothed_means.row_mut(t).assign(&smoothed_mean);
328            smoothed_covs
329                .slice_mut(ndarray::s![t, .., ..])
330                .assign(&smoothed_cov);
331        }
332
333        KalmanFilterResults {
334            filtered_state_means: smoothed_means,
335            filtered_state_covs: smoothed_covs,
336            ..filter_results.clone()
337        }
338    }
339
340    /// Forecast future states
341    pub fn forecast(
342        &self,
343        filter_results: &KalmanFilterResults,
344        steps: usize,
345    ) -> (Array2<f64>, Array3<f64>) {
346        let n = filter_results.filtered_state_means.nrows();
347        let n_states = self.observation_matrix.ncols();
348
349        let mut forecast_means = Array2::zeros((steps, n_states));
350        let mut forecast_covs = Array3::zeros((steps, n_states, n_states));
351
352        // Start from last filtered state
353        let mut current_mean = filter_results.filtered_state_means.row(n - 1).to_owned();
354        let mut current_cov = filter_results
355            .filtered_state_covs
356            .slice(ndarray::s![n - 1, .., ..])
357            .to_owned();
358
359        for h in 0..steps {
360            // Predict state
361            current_mean = self.transition_matrix.dot(&current_mean);
362            current_cov = self
363                .transition_matrix
364                .dot(&current_cov.dot(&self.transition_matrix.t()))
365                + self
366                    .selection_matrix
367                    .dot(&self.state_cov.dot(&self.selection_matrix.t()));
368
369            // Store forecast
370            forecast_means.row_mut(h).assign(&current_mean);
371            forecast_covs
372                .slice_mut(ndarray::s![h, .., ..])
373                .assign(&current_cov);
374        }
375
376        (forecast_means, forecast_covs)
377    }
378
379    /// Calculate marginal log-likelihood
380    pub fn log_likelihood(&self, y: &Array1<f64>) -> Result<f64> {
381        let results = self.filter(y)?;
382        Ok(results.log_likelihood)
383    }
384
385    /// Estimate parameters via maximum likelihood
386    pub fn estimate(&mut self, _y: &Array1<f64>) -> Result<()> {
387        // This would implement MLE using EM algorithm or numerical optimization
388        // For now, just return the model as-is
389        Ok(())
390    }
391}
392
/// Stateless entry point for running the Kalman filter, smoother and
/// forecasts on a state space model.
///
/// The type carries no data, so it is trivially copyable and has a
/// `Default` implementation alongside its `new` constructor.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct KalmanFilter;
395
396impl KalmanFilter {
397    /// Create new Kalman filter
398    pub fn new() -> Self {
399        Self
400    }
401
402    /// Run filter on state space model
403    pub fn filter(&self, model: &StateSpaceModel, y: &Array1<f64>) -> Result<KalmanFilterResults> {
404        model.filter(y)
405    }
406
407    /// Run smoother on filtered results
408    pub fn smooth(
409        &self,
410        model: &StateSpaceModel,
411        results: &KalmanFilterResults,
412    ) -> KalmanFilterResults {
413        model.smooth(results)
414    }
415
416    /// Run filter and smoother
417    pub fn filter_smooth(
418        &self,
419        model: &StateSpaceModel,
420        y: &Array1<f64>,
421    ) -> Result<KalmanFilterResults> {
422        let filtered = model.filter(y)?;
423        Ok(model.smooth(&filtered))
424    }
425
426    /// Forecast future observations
427    pub fn forecast(
428        &self,
429        model: &StateSpaceModel,
430        results: &KalmanFilterResults,
431        steps: usize,
432    ) -> (Array1<f64>, Array1<f64>) {
433        let (state_means, state_covs) = model.forecast(results, steps);
434
435        let mut forecast_means = Array1::zeros(steps);
436        let mut forecast_variances = Array1::zeros(steps);
437
438        for h in 0..steps {
439            let state_mean = state_means.row(h);
440            let state_cov = state_covs.slice(ndarray::s![h, .., ..]);
441
442            // Forecast observation
443            let obs_mean = model.observation_matrix.dot(&state_mean);
444            forecast_means[h] = obs_mean[0];
445
446            // Forecast variance
447            let obs_var = model
448                .observation_matrix
449                .dot(&state_cov.dot(&model.observation_matrix.t()))
450                + &model.observation_cov;
451            forecast_variances[h] = obs_var[(0, 0)];
452        }
453
454        (forecast_means, forecast_variances)
455    }
456}