1use scirs2_core::ndarray::{Array1, Array2};
7
/// Observation likelihood family for the response `y` of a latent Gaussian model.
///
/// The enum is fieldless and `Copy`, so it also derives `Eq` and `Hash` and can be
/// used as a key in hash-based collections. It is `#[non_exhaustive]`: `match`es
/// outside this crate must include a wildcard arm.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[non_exhaustive]
pub enum LikelihoodFamily {
    /// Gaussian (normal) observations.
    Gaussian,
    /// Poisson count observations.
    Poisson,
    /// Binomial observations; presumably paired with per-observation trial
    /// counts (`n_trials`) on the model — confirm against the fitting code.
    Binomial,
    /// Negative binomial count observations.
    NegativeBinomial,
}
21
22#[derive(Debug, Clone)]
29pub struct LatentGaussianModel {
30 pub y: Array1<f64>,
32 pub design_matrix: Array2<f64>,
34 pub precision_matrix: Array2<f64>,
36 pub likelihood: LikelihoodFamily,
38 pub n_trials: Option<Array1<f64>>,
41 pub observation_precision: Option<f64>,
44}
45
46impl LatentGaussianModel {
47 pub fn new(
55 y: Array1<f64>,
56 design_matrix: Array2<f64>,
57 precision_matrix: Array2<f64>,
58 likelihood: LikelihoodFamily,
59 ) -> Self {
60 Self {
61 y,
62 design_matrix,
63 precision_matrix,
64 likelihood,
65 n_trials: None,
66 observation_precision: None,
67 }
68 }
69
70 pub fn with_n_trials(mut self, n_trials: Array1<f64>) -> Self {
72 self.n_trials = Some(n_trials);
73 self
74 }
75
76 pub fn with_observation_precision(mut self, precision: f64) -> Self {
78 self.observation_precision = Some(precision);
79 self
80 }
81}
82
/// Strategy used to integrate over the hyperparameter posterior.
///
/// The enum is fieldless and `Copy`, so it also derives `Eq` and `Hash` for use in
/// hash-based collections. It is `#[non_exhaustive]`: `match`es outside this crate
/// must include a wildcard arm.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[non_exhaustive]
pub enum IntegrationStrategy {
    /// Evaluate the hyperparameter posterior on a grid.
    Grid,
    /// Presumably a central composite design over hyperparameter space — confirm.
    ///
    /// NOTE(review): idiomatic Rust naming would be `Ccd`; renaming is a breaking
    /// change, so the all-caps name is kept.
    CCD,
    /// Simplified Laplace integration strategy.
    SimplifiedLaplace,
}
94
95#[derive(Debug, Clone)]
97pub struct INLAConfig {
98 pub n_hyperparameter_grid: usize,
100 pub integration_strategy: IntegrationStrategy,
102 pub max_newton_iter: usize,
104 pub newton_tol: f64,
106 pub simplified_laplace: bool,
108 pub newton_damping: f64,
110 pub hyperparameter_range: Option<(f64, f64)>,
112}
113
114impl Default for INLAConfig {
115 fn default() -> Self {
116 Self {
117 n_hyperparameter_grid: 25,
118 integration_strategy: IntegrationStrategy::Grid,
119 max_newton_iter: 100,
120 newton_tol: 1e-8,
121 simplified_laplace: false,
122 newton_damping: 1.0,
123 hyperparameter_range: None,
124 }
125 }
126}
127
/// Posterior summary for a single hyperparameter, evaluated on a grid.
///
/// `grid_points` and `log_densities` are presumably index-aligned and of equal
/// length — confirm against the code that fills them. Derives `PartialEq` so
/// summaries can be compared directly.
#[derive(Debug, Clone, PartialEq)]
pub struct HyperparameterPosterior {
    /// Hyperparameter values at which the posterior was evaluated.
    pub grid_points: Vec<f64>,
    /// Log posterior density at each grid point (whether normalized is not
    /// established here).
    pub log_densities: Vec<f64>,
    /// Posterior mean of the hyperparameter.
    pub mean: f64,
    /// Posterior variance of the hyperparameter.
    pub variance: f64,
}
140
141#[derive(Debug, Clone)]
143pub struct INLAResult {
144 pub marginal_means: Array1<f64>,
146 pub marginal_variances: Array1<f64>,
148 pub hyperparameter_posteriors: Vec<HyperparameterPosterior>,
150 pub log_marginal_likelihood: f64,
152 pub converged: bool,
154 pub newton_iterations: usize,
156}
157
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::{array, Array2};

    /// Every field of the default configuration is pinned.
    #[test]
    fn test_default_config() {
        let config = INLAConfig::default();
        assert_eq!(config.n_hyperparameter_grid, 25);
        assert_eq!(config.max_newton_iter, 100);
        assert!((config.newton_tol - 1e-8).abs() < 1e-15);
        assert!(!config.simplified_laplace);
        assert_eq!(config.integration_strategy, IntegrationStrategy::Grid);
        assert!((config.newton_damping - 1.0).abs() < 1e-15);
        assert!(config.hyperparameter_range.is_none());
    }

    /// `new` stores its arguments and leaves the optional settings unset.
    #[test]
    fn test_latent_gaussian_model_new() {
        let y = array![1.0, 2.0, 3.0];
        let design = Array2::eye(3);
        let precision = Array2::eye(3);
        let model =
            LatentGaussianModel::new(y.clone(), design, precision, LikelihoodFamily::Gaussian);
        assert_eq!(model.y, y);
        assert_eq!(model.design_matrix, Array2::<f64>::eye(3));
        assert_eq!(model.precision_matrix, Array2::<f64>::eye(3));
        assert_eq!(model.likelihood, LikelihoodFamily::Gaussian);
        assert!(model.n_trials.is_none());
        assert!(model.observation_precision.is_none());
    }

    /// Builder methods store exactly the values they are given.
    #[test]
    fn test_model_with_builders() {
        let y = array![1.0, 0.0, 1.0];
        let design = Array2::eye(3);
        let precision = Array2::eye(3);
        let n_trials = array![10.0, 10.0, 10.0];
        let model = LatentGaussianModel::new(y, design, precision, LikelihoodFamily::Binomial)
            .with_n_trials(n_trials.clone())
            .with_observation_precision(1.0);

        assert_eq!(model.likelihood, LikelihoodFamily::Binomial);
        // Compare contents rather than just the length, so the stored counts are pinned.
        assert_eq!(model.n_trials.as_ref(), Some(&n_trials));
        assert!((model.observation_precision.unwrap_or(0.0) - 1.0).abs() < 1e-15);
    }

    /// Each likelihood variant has the expected `Debug` name.
    #[test]
    fn test_likelihood_variants() {
        let cases = [
            (LikelihoodFamily::Gaussian, "Gaussian"),
            (LikelihoodFamily::Poisson, "Poisson"),
            (LikelihoodFamily::Binomial, "Binomial"),
            (LikelihoodFamily::NegativeBinomial, "NegativeBinomial"),
        ];
        for (variant, name) in &cases {
            assert_eq!(format!("{:?}", variant), *name);
        }
    }

    /// Each integration strategy has the expected `Debug` name.
    #[test]
    fn test_integration_strategy_variants() {
        let cases = [
            (IntegrationStrategy::Grid, "Grid"),
            (IntegrationStrategy::CCD, "CCD"),
            (IntegrationStrategy::SimplifiedLaplace, "SimplifiedLaplace"),
        ];
        for (variant, name) in &cases {
            assert_eq!(format!("{:?}", variant), *name);
        }
    }
}