Skip to main content

scirs2_stats/variational/
mod.rs

1//! Variational Inference Methods
2//!
3//! This module provides modern variational inference algorithms for approximate
4//! Bayesian posterior computation:
5//!
6//! - **ADVI**: Automatic Differentiation Variational Inference (Kucukelbir et al. 2017)
7//!   with mean-field and full-rank Gaussian approximations, automatic parameter
8//!   transformations, ELBO optimization via reparameterization trick + Adam optimizer.
9//!
10//! - **SVGD**: Stein Variational Gradient Descent (Liu & Wang 2016) — a particle-based
11//!   method that transports a set of particles to approximate the posterior using
12//!   kernelized Stein discrepancy with RBF kernel and median bandwidth heuristic.
13//!
14//! - **Normalizing Flows**: Flexible posterior approximations via invertible
15//!   transformations (planar and radial flows) with log-determinant Jacobian tracking.
16
17mod advi;
18pub mod bbvi;
19mod normalizing_flow;
20mod svgd;
21
22pub use advi::*;
23pub use normalizing_flow::*;
24pub use svgd::*;
25
26use crate::error::StatsResult;
27use scirs2_core::ndarray::Array1;
28
29// ============================================================================
30// Common Trait
31// ============================================================================
32
33/// Result of variational inference
34#[derive(Debug, Clone)]
35pub struct PosteriorResult {
36    /// Posterior means (in constrained space)
37    pub means: Array1<f64>,
38    /// Posterior standard deviations (in constrained space)
39    pub std_devs: Array1<f64>,
40    /// ELBO history over iterations
41    pub elbo_history: Vec<f64>,
42    /// Number of iterations performed
43    pub iterations: usize,
44    /// Whether the algorithm converged
45    pub converged: bool,
46    /// Optional: posterior samples (for particle-based methods like SVGD)
47    pub samples: Option<Vec<Array1<f64>>>,
48}
49
50/// Common trait for variational inference methods
51pub trait VariationalInference {
52    /// Fit the variational approximation to a target log-joint distribution.
53    ///
54    /// # Arguments
55    /// * `log_joint` - Function computing `(log p(x, theta), grad_theta log p(x, theta))`
56    /// * `dim` - Dimensionality of the parameter space
57    ///
58    /// # Returns
59    /// A `PosteriorResult` with posterior statistics and convergence info
60    fn fit<F>(&mut self, log_joint: F, dim: usize) -> StatsResult<PosteriorResult>
61    where
62        F: Fn(&Array1<f64>) -> StatsResult<(f64, Array1<f64>)>;
63}