1use scirs2_core::ndarray::{Array1, Array2};
10
11use super::laplace;
12use super::types::{
13 HyperparameterPosterior, INLAConfig, IntegrationStrategy, LatentGaussianModel, LikelihoodFamily,
14};
15use crate::error::StatsError;
16
/// Result of evaluating the approximate posterior at one hyperparameter value.
#[derive(Debug, Clone)]
pub struct HyperparameterPoint {
    /// Hyperparameter value(s). [`evaluate_hyperparameter`] stores a single log-scale
    /// multiplier of the prior precision matrix here.
    pub theta: Vec<f64>,
    /// Unnormalised log posterior: Laplace log marginal likelihood plus log hyperprior.
    pub log_posterior: f64,
    /// Mode of the latent field located by the Laplace approximation at this `theta`.
    pub mode: Array1<f64>,
    /// Latent marginal variances obtained from `laplace::inverse_diagonal` applied to the
    /// negative Hessian at the mode (presumably the diagonal of its inverse).
    pub marginal_variances: Array1<f64>,
}
29
30pub fn evaluate_hyperparameter(
44 theta: f64,
45 model: &LatentGaussianModel,
46 config: &INLAConfig,
47) -> Result<HyperparameterPoint, StatsError> {
48 let scale = theta.exp();
50 let scaled_precision = &model.precision_matrix * scale;
51
52 let mode_result = laplace::find_mode(
54 &scaled_precision,
55 &model.y,
56 &model.design_matrix,
57 model.likelihood,
58 model.n_trials.as_ref(),
59 model.observation_precision,
60 config.max_newton_iter,
61 config.newton_tol,
62 config.newton_damping,
63 )?;
64
65 let log_marginal = laplace::laplace_log_marginal_likelihood(&mode_result, &scaled_precision)?;
67
68 let log_prior_theta = log_hyperprior(theta, config);
70
71 let marginal_vars = laplace::inverse_diagonal(&mode_result.neg_hessian)?;
73
74 Ok(HyperparameterPoint {
75 theta: vec![theta],
76 log_posterior: log_marginal + log_prior_theta,
77 mode: mode_result.mode,
78 marginal_variances: marginal_vars,
79 })
80}
81
82fn log_hyperprior(theta: f64, config: &INLAConfig) -> f64 {
86 match config.hyperparameter_range {
87 Some((lo, hi)) => {
88 let mid = (lo + hi) / 2.0;
90 let scale = (hi - lo) / 4.0; if scale <= 0.0 {
92 return 0.0;
93 }
94 -0.5 * ((theta - mid) / scale).powi(2)
95 }
96 None => 0.0, }
98}
99
100pub fn explore_hyperparameter_grid(
112 model: &LatentGaussianModel,
113 config: &INLAConfig,
114) -> Result<Vec<HyperparameterPoint>, StatsError> {
115 let n_grid = config.n_hyperparameter_grid;
116 if n_grid == 0 {
117 return Err(StatsError::InvalidArgument(
118 "Number of hyperparameter grid points must be positive".to_string(),
119 ));
120 }
121
122 let (lo, hi) = config.hyperparameter_range.unwrap_or((-3.0, 3.0));
124
125 let grid_points = create_grid(lo, hi, n_grid);
126
127 let mut results = Vec::with_capacity(n_grid);
128 for &theta in &grid_points {
129 match evaluate_hyperparameter(theta, model, config) {
130 Ok(point) => results.push(point),
131 Err(_) => {
132 continue;
134 }
135 }
136 }
137
138 if results.is_empty() {
139 return Err(StatsError::ConvergenceError(
140 "INLA failed to evaluate any hyperparameter grid point".to_string(),
141 ));
142 }
143
144 results.sort_by(|a, b| {
146 b.log_posterior
147 .partial_cmp(&a.log_posterior)
148 .unwrap_or(std::cmp::Ordering::Equal)
149 });
150
151 Ok(results)
152}
153
/// Build `n` evenly spaced points running from `lo` to `hi` (both endpoints included).
///
/// * `n == 0` yields an empty grid.
/// * `n == 1` yields the single midpoint `(lo + hi) / 2`.
/// * If `hi < lo`, the grid is produced in descending order.
fn create_grid(lo: f64, hi: f64, n: usize) -> Vec<f64> {
    match n {
        // Guard against `n - 1` underflowing on `usize` (a panic in debug builds).
        0 => Vec::new(),
        1 => vec![(lo + hi) / 2.0],
        _ => {
            let step = (hi - lo) / (n - 1) as f64;
            (0..n).map(|i| lo + i as f64 * step).collect()
        }
    }
}
162
163pub fn ccd_integration_points(n_hyperparams: usize) -> Result<Vec<Vec<f64>>, StatsError> {
180 if n_hyperparams == 0 {
181 return Err(StatsError::InvalidArgument(
182 "Number of hyperparameters must be positive".to_string(),
183 ));
184 }
185
186 let mut points = Vec::new();
187
188 points.push(vec![0.0; n_hyperparams]);
190
191 let alpha = (n_hyperparams as f64).sqrt();
193
194 for d in 0..n_hyperparams {
196 let mut point_pos = vec![0.0; n_hyperparams];
197 point_pos[d] = alpha;
198 points.push(point_pos);
199
200 let mut point_neg = vec![0.0; n_hyperparams];
201 point_neg[d] = -alpha;
202 points.push(point_neg);
203 }
204
205 let max_factorial = if n_hyperparams <= 6 {
208 1usize << n_hyperparams } else {
210 2 * n_hyperparams
212 };
213
214 let n_factorial = (1usize << n_hyperparams).min(max_factorial);
215 for i in 0..n_factorial {
216 let mut point = vec![0.0; n_hyperparams];
217 for d in 0..n_hyperparams {
218 point[d] = if (i >> d) & 1 == 0 { -1.0 } else { 1.0 };
219 }
220 points.push(point);
221 }
222
223 Ok(points)
224}
225
226pub fn grid_integration(
240 log_densities: &[f64],
241 grid_spacing: f64,
242) -> Result<(Vec<f64>, f64), StatsError> {
243 if log_densities.is_empty() {
244 return Err(StatsError::InvalidArgument(
245 "Log densities array is empty".to_string(),
246 ));
247 }
248
249 let max_log = log_densities
251 .iter()
252 .copied()
253 .fold(f64::NEG_INFINITY, f64::max);
254
255 if max_log.is_infinite() && max_log < 0.0 {
256 return Err(StatsError::ComputationError(
257 "All log densities are -infinity".to_string(),
258 ));
259 }
260
261 let n = log_densities.len();
263 let mut weights = Vec::with_capacity(n);
264 for i in 0..n {
265 let trap_factor = if i == 0 || i == n - 1 { 0.5 } else { 1.0 };
266 weights.push((log_densities[i] - max_log).exp() * trap_factor * grid_spacing);
267 }
268
269 let total_weight: f64 = weights.iter().sum();
270 if total_weight <= 0.0 {
271 return Err(StatsError::ComputationError(
272 "Total integration weight is non-positive".to_string(),
273 ));
274 }
275
276 let log_normalizing = max_log + total_weight.ln();
277
278 let normalized: Vec<f64> = weights.iter().map(|w| w / total_weight).collect();
280
281 Ok((normalized, log_normalizing))
282}
283
284pub fn summarize_hyperparameter_posterior(
294 grid_points: &[f64],
295 log_densities: &[f64],
296 grid_spacing: f64,
297) -> Result<HyperparameterPosterior, StatsError> {
298 if grid_points.len() != log_densities.len() {
299 return Err(StatsError::DimensionMismatch(
300 "Grid points and log densities must have the same length".to_string(),
301 ));
302 }
303
304 let (weights, _) = grid_integration(log_densities, grid_spacing)?;
305
306 let mean: f64 = weights
308 .iter()
309 .zip(grid_points.iter())
310 .map(|(w, t)| w * t)
311 .sum();
312
313 let variance: f64 = weights
315 .iter()
316 .zip(grid_points.iter())
317 .map(|(w, t)| w * (t - mean).powi(2))
318 .sum();
319
320 Ok(HyperparameterPosterior {
321 grid_points: grid_points.to_vec(),
322 log_densities: log_densities.to_vec(),
323 mean,
324 variance,
325 })
326}
327
// Unit tests for grid construction, CCD point generation, trapezoidal grid integration,
// posterior summarisation and the end-to-end grid exploration.
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::{array, Array2};

    // Five points on [-1, 1] should hit both endpoints and the midpoint exactly.
    #[test]
    fn test_create_grid() {
        let grid = create_grid(-1.0, 1.0, 5);
        assert_eq!(grid.len(), 5);
        assert!((grid[0] - (-1.0)).abs() < 1e-10);
        assert!((grid[4] - 1.0).abs() < 1e-10);
        assert!((grid[2] - 0.0).abs() < 1e-10);
    }

    // A single-point grid is placed at the midpoint of the range.
    #[test]
    fn test_create_grid_single() {
        let grid = create_grid(-1.0, 1.0, 1);
        assert_eq!(grid.len(), 1);
        assert!((grid[0] - 0.0).abs() < 1e-10);
    }

    // 1D design: 1 centre + 2 axial (at ±sqrt(1) = ±1) + 2 factorial corners = 5 points.
    #[test]
    fn test_ccd_1d() {
        let points = ccd_integration_points(1).expect("CCD should succeed for 1D");
        assert_eq!(points.len(), 5);
        // Centre point comes first.
        assert!((points[0][0]).abs() < 1e-10);
        // Positive axial point.
        assert!((points[1][0] - 1.0).abs() < 1e-10);
        // Negative axial point.
        assert!((points[2][0] - (-1.0)).abs() < 1e-10);
    }

    // 2D design: 1 centre + 4 axial + 4 factorial corners = 9 points.
    #[test]
    fn test_ccd_2d() {
        let points = ccd_integration_points(2).expect("CCD should succeed for 2D");
        assert_eq!(points.len(), 9);
        // The first point is the origin.
        assert!((points[0][0]).abs() < 1e-10);
        assert!((points[0][1]).abs() < 1e-10);
    }

    // 3D design: 1 centre + 6 axial + 8 factorial corners = 15 points.
    #[test]
    fn test_ccd_3d() {
        let points = ccd_integration_points(3).expect("CCD should succeed for 3D");
        assert_eq!(points.len(), 15);
    }

    #[test]
    fn test_ccd_zero() {
        let result = ccd_integration_points(0);
        assert!(result.is_err());
    }

    // Uniform log densities: trapezoid weights are 0.5, 1, 1, 1, 0.5 (sum 4), so the
    // normalised endpoint weight is 0.125 and an interior weight is 0.25.
    #[test]
    fn test_grid_integration_uniform() {
        let log_densities = vec![0.0, 0.0, 0.0, 0.0, 0.0];
        let (weights, _) =
            grid_integration(&log_densities, 1.0).expect("Integration should succeed");
        assert!((weights[0] - 0.125).abs() < 1e-10);
        assert!((weights[2] - 0.25).abs() < 1e-10);
    }

    // A sharply peaked density concentrates almost all mass on the peak.
    #[test]
    fn test_grid_integration_peaked() {
        let log_densities = vec![-100.0, -10.0, 0.0, -10.0, -100.0];
        let (weights, _) =
            grid_integration(&log_densities, 1.0).expect("Integration should succeed");
        assert!(weights[2] > 0.9);
    }

    #[test]
    fn test_grid_integration_empty() {
        let result = grid_integration(&[], 1.0);
        assert!(result.is_err());
    }

    // Symmetric log densities about 0 should give a posterior mean near 0.
    #[test]
    fn test_summarize_posterior() {
        let grid_points = vec![-2.0, -1.0, 0.0, 1.0, 2.0];
        let log_densities = vec![-2.0, -0.5, 0.0, -0.5, -2.0];
        let result = summarize_hyperparameter_posterior(&grid_points, &log_densities, 1.0)
            .expect("Summary should succeed");
        assert!(
            result.mean.abs() < 0.1,
            "Mean should be near 0, got {}",
            result.mean
        );
        assert!(result.variance > 0.0, "Variance should be positive");
    }

    // End-to-end: Gaussian likelihood with identity design and prior precision.
    // Checks that at least one grid point is evaluated and results are sorted by
    // descending log posterior.
    #[test]
    fn test_explore_grid_gaussian() {
        let n = 3;
        let y = array![1.0, 2.0, 3.0];
        let design = Array2::eye(n);
        let precision = Array2::eye(n);

        let model = LatentGaussianModel::new(y, design, precision, LikelihoodFamily::Gaussian)
            .with_observation_precision(1.0);

        let config = INLAConfig {
            n_hyperparameter_grid: 5,
            hyperparameter_range: Some((-1.0, 1.0)),
            max_newton_iter: 50,
            ..INLAConfig::default()
        };

        let results =
            explore_hyperparameter_grid(&model, &config).expect("Grid exploration should succeed");

        assert!(!results.is_empty(), "Should have some valid grid points");
        // Each adjacent pair must be in non-increasing log-posterior order.
        for i in 1..results.len() {
            assert!(
                results[i - 1].log_posterior >= results[i].log_posterior,
                "Results should be sorted descending"
            );
        }
    }

    #[test]
    fn test_dimension_mismatch_summary() {
        let grid = vec![1.0, 2.0];
        let densities = vec![0.0, 0.0, 0.0];
        let result = summarize_hyperparameter_posterior(&grid, &densities, 1.0);
        assert!(result.is_err());
    }
}