Skip to main content

scirs2_stats/inla/
types.rs

1//! Types for Integrated Nested Laplace Approximation (INLA)
2//!
3//! This module defines the core data structures for latent Gaussian models,
4//! INLA configuration, and result types.
5
6use scirs2_core::ndarray::{Array1, Array2};
7
/// Likelihood family for the observed data given the latent field.
///
/// Below, `eta_i` is the `i`-th entry of the linear predictor `eta = A x`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[non_exhaustive]
pub enum LikelihoodFamily {
    /// Gaussian likelihood: `y_i ~ N(eta_i, sigma^2)`.
    ///
    /// The precision `1/sigma^2` is supplied via
    /// [`LatentGaussianModel::observation_precision`].
    Gaussian,
    /// Poisson likelihood: `y_i ~ Poisson(exp(eta_i))` (log link).
    Poisson,
    /// Binomial likelihood: `y_i ~ Binomial(n_i, logistic(eta_i))`.
    ///
    /// The trial counts `n_i` are supplied via [`LatentGaussianModel::n_trials`].
    Binomial,
    /// Negative Binomial likelihood: `y_i ~ NegBin(r, p_i)` with log link.
    ///
    /// NOTE(review): where the dispersion `r` comes from is not visible in this
    /// type definition — confirm against the solver.
    NegativeBinomial,
}
21
22/// A latent Gaussian model specification
23///
24/// The model is:
25///   y | x, theta ~ product of p(y_i | eta_i, theta)
26///   x | theta ~ N(0, Q(theta)^{-1})
27///   eta = A * x  (linear predictor, where A is the design matrix)
28#[derive(Debug, Clone)]
29pub struct LatentGaussianModel {
30    /// Observed data vector (n x 1)
31    pub y: Array1<f64>,
32    /// Fixed effects design matrix (n x p), maps latent field to linear predictor
33    pub design_matrix: Array2<f64>,
34    /// GMRF precision matrix Q(theta) for the latent field (p x p)
35    pub precision_matrix: Array2<f64>,
36    /// Likelihood family for the observations
37    pub likelihood: LikelihoodFamily,
38    /// Number of trials for Binomial likelihood (one per observation).
39    /// Ignored for other likelihood families.
40    pub n_trials: Option<Array1<f64>>,
41    /// Observation precision (1/sigma^2) for Gaussian likelihood.
42    /// Ignored for other likelihood families.
43    pub observation_precision: Option<f64>,
44}
45
46impl LatentGaussianModel {
47    /// Create a new latent Gaussian model
48    ///
49    /// # Arguments
50    /// * `y` - Observation vector
51    /// * `design_matrix` - Design matrix mapping latent field to linear predictor
52    /// * `precision_matrix` - GMRF precision matrix Q(theta)
53    /// * `likelihood` - Likelihood family
54    pub fn new(
55        y: Array1<f64>,
56        design_matrix: Array2<f64>,
57        precision_matrix: Array2<f64>,
58        likelihood: LikelihoodFamily,
59    ) -> Self {
60        Self {
61            y,
62            design_matrix,
63            precision_matrix,
64            likelihood,
65            n_trials: None,
66            observation_precision: None,
67        }
68    }
69
70    /// Set the number of trials for Binomial likelihood
71    pub fn with_n_trials(mut self, n_trials: Array1<f64>) -> Self {
72        self.n_trials = Some(n_trials);
73        self
74    }
75
76    /// Set the observation precision for Gaussian likelihood
77    pub fn with_observation_precision(mut self, precision: f64) -> Self {
78        self.observation_precision = Some(precision);
79        self
80    }
81}
82
/// Integration strategy for marginalizing over the hyperparameters.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[non_exhaustive]
pub enum IntegrationStrategy {
    /// Full grid-based integration over the hyperparameter space.
    ///
    /// See `INLAConfig::n_hyperparameter_grid` for the number of grid points per
    /// hyperparameter dimension.
    Grid,
    /// Central Composite Design for efficient integration (evaluates the
    /// posterior at a structured design instead of a full grid).
    CCD,
    /// Simplified Laplace approximation (fastest but least accurate).
    SimplifiedLaplace,
}
94
95/// Configuration for the INLA algorithm
96#[derive(Debug, Clone)]
97pub struct INLAConfig {
98    /// Number of grid points per hyperparameter dimension
99    pub n_hyperparameter_grid: usize,
100    /// Integration strategy for marginalizing over hyperparameters
101    pub integration_strategy: IntegrationStrategy,
102    /// Maximum Newton-Raphson iterations for mode finding
103    pub max_newton_iter: usize,
104    /// Convergence tolerance for Newton-Raphson
105    pub newton_tol: f64,
106    /// Whether to use simplified Laplace for conditional marginals
107    pub simplified_laplace: bool,
108    /// Step size damping factor for Newton-Raphson (0 < damping <= 1)
109    pub newton_damping: f64,
110    /// Hyperparameter prior log-density (if None, flat prior is used)
111    pub hyperparameter_range: Option<(f64, f64)>,
112}
113
114impl Default for INLAConfig {
115    fn default() -> Self {
116        Self {
117            n_hyperparameter_grid: 25,
118            integration_strategy: IntegrationStrategy::Grid,
119            max_newton_iter: 100,
120            newton_tol: 1e-8,
121            simplified_laplace: false,
122            newton_damping: 1.0,
123            hyperparameter_range: None,
124        }
125    }
126}
127
/// Posterior distribution of a single hyperparameter, tabulated on a grid.
///
/// `log_densities[k]` is the value at `grid_points[k]`, so the two vectors are
/// expected to have the same length. This type does not enforce that.
#[derive(Debug, Clone, PartialEq)]
pub struct HyperparameterPosterior {
    /// Grid points at which the posterior was evaluated.
    pub grid_points: Vec<f64>,
    /// Unnormalized log-density value at each grid point.
    pub log_densities: Vec<f64>,
    /// Posterior mean (computed from the normalized density).
    pub mean: f64,
    /// Posterior variance (computed from the normalized density).
    pub variance: f64,
}
140
141/// Result of the INLA algorithm
142#[derive(Debug, Clone)]
143pub struct INLAResult {
144    /// Posterior marginal means for each latent field component
145    pub marginal_means: Array1<f64>,
146    /// Posterior marginal variances for each latent field component
147    pub marginal_variances: Array1<f64>,
148    /// Posterior distributions of hyperparameters
149    pub hyperparameter_posteriors: Vec<HyperparameterPosterior>,
150    /// Log marginal likelihood estimate log p(y)
151    pub log_marginal_likelihood: f64,
152    /// Whether the algorithm converged
153    pub converged: bool,
154    /// Number of Newton-Raphson iterations used at the mode
155    pub newton_iterations: usize,
156}
157
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::{array, Array2};

    #[test]
    fn test_default_config() {
        let config = INLAConfig::default();
        assert_eq!(config.n_hyperparameter_grid, 25);
        assert_eq!(config.max_newton_iter, 100);
        assert!((config.newton_tol - 1e-8).abs() < 1e-15);
        assert!(!config.simplified_laplace);
        assert_eq!(config.integration_strategy, IntegrationStrategy::Grid);
        // Undamped Newton steps and no explicit hyperparameter range by default.
        assert!((config.newton_damping - 1.0).abs() < 1e-15);
        assert!(config.hyperparameter_range.is_none());
    }

    #[test]
    fn test_latent_gaussian_model_new() {
        let y = array![1.0, 2.0, 3.0];
        let design = Array2::eye(3);
        let precision = Array2::eye(3);
        let model = LatentGaussianModel::new(
            y.clone(),
            design.clone(),
            precision.clone(),
            LikelihoodFamily::Gaussian,
        );
        assert_eq!(model.y, y);
        assert_eq!(model.design_matrix, design);
        assert_eq!(model.precision_matrix, precision);
        assert_eq!(model.likelihood, LikelihoodFamily::Gaussian);
        assert!(model.n_trials.is_none());
        assert!(model.observation_precision.is_none());
    }

    #[test]
    fn test_model_with_builders() {
        let y = array![1.0, 0.0, 1.0];
        let design = Array2::eye(3);
        let precision = Array2::eye(3);
        let n_trials = array![10.0, 10.0, 10.0];
        let model = LatentGaussianModel::new(y, design, precision, LikelihoodFamily::Binomial)
            .with_n_trials(n_trials.clone())
            .with_observation_precision(1.0);

        assert_eq!(model.likelihood, LikelihoodFamily::Binomial);
        assert_eq!(model.n_trials.as_ref().map(|t| t.len()), Some(3));
        assert_eq!(model.n_trials.as_ref(), Some(&n_trials));
        assert!((model.observation_precision.unwrap_or(0.0) - 1.0).abs() < 1e-15);
    }

    #[test]
    fn test_builder_only_sets_its_own_field() {
        let model = LatentGaussianModel::new(
            array![0.5, -0.5],
            Array2::eye(2),
            Array2::eye(2),
            LikelihoodFamily::Gaussian,
        )
        .with_observation_precision(4.0);

        assert!(model.n_trials.is_none());
        assert!((model.observation_precision.unwrap_or(0.0) - 4.0).abs() < 1e-15);
    }

    #[test]
    fn test_likelihood_variants() {
        let variants = [
            LikelihoodFamily::Gaussian,
            LikelihoodFamily::Poisson,
            LikelihoodFamily::Binomial,
            LikelihoodFamily::NegativeBinomial,
        ];
        let names = ["Gaussian", "Poisson", "Binomial", "NegativeBinomial"];
        for (i, (a, name)) in variants.iter().zip(names.iter()).enumerate() {
            assert_eq!(format!("{:?}", a), *name);
            // Every variant equals itself and differs from every other variant.
            for (j, b) in variants.iter().enumerate() {
                assert_eq!(a == b, i == j, "{:?} vs {:?}", a, b);
            }
        }
    }

    #[test]
    fn test_integration_strategy_variants() {
        let variants = [
            IntegrationStrategy::Grid,
            IntegrationStrategy::CCD,
            IntegrationStrategy::SimplifiedLaplace,
        ];
        let names = ["Grid", "CCD", "SimplifiedLaplace"];
        for (i, (a, name)) in variants.iter().zip(names.iter()).enumerate() {
            assert_eq!(format!("{:?}", a), *name);
            for (j, b) in variants.iter().enumerate() {
                assert_eq!(a == b, i == j, "{:?} vs {:?}", a, b);
            }
        }
    }
}