Skip to main content

scirs2_stats/bayesian_approx/
types.rs

//! Types for Bayesian neural network posterior approximations.
//!
//! Provides configuration and result types for the Laplace approximation
//! and SWAG (Stochastic Weight Averaging Gaussian).

// ============================================================================
// Hessian computation methods for Laplace approximation
// ============================================================================

/// Curvature approximation used by the Laplace approximation.
///
/// The Laplace approximation needs an approximation to the Hessian of the
/// loss; this enum selects which one is used:
///
/// - [`HessianMethod::GGN`] — generalized Gauss-Newton, a Fisher-information
///   proxy built from squared gradients. This is the default.
/// - [`HessianMethod::Diagonal`] — keeps only the diagonal of the curvature.
/// - [`HessianMethod::KFAC`] — Kronecker-factored approximate curvature.
///
/// The enum is `#[non_exhaustive]`, so `match` expressions outside the
/// defining crate must include a wildcard arm.
#[non_exhaustive]
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum HessianMethod {
    /// Generalized Gauss-Newton / squared-gradient Fisher approximation.
    #[default]
    GGN,
    /// Diagonal-only curvature.
    Diagonal,
    /// Kronecker-factored approximate curvature.
    KFAC,
}

27// ============================================================================
28// LaplaceConfig
29// ============================================================================
30
31/// Configuration for Laplace approximation of a BNN posterior.
32#[derive(Debug, Clone)]
33pub struct LaplaceConfig {
34    /// Method for computing the approximate curvature (default: GGN)
35    pub hessian_method: HessianMethod,
36    /// Tikhonov / prior regularization damping added to the diagonal of H.
37    /// Corresponds to the prior precision λ in N(0, λ⁻¹ I). (default: 1.0)
38    pub damping: f64,
39    /// Finite-difference step for gradient computation (default: 1e-5)
40    pub fd_step: f64,
41}
42
43impl Default for LaplaceConfig {
44    fn default() -> Self {
45        Self {
46            hessian_method: HessianMethod::GGN,
47            damping: 1.0,
48            fd_step: 1e-5,
49        }
50    }
51}
52
// ============================================================================
// SwagConfig
// ============================================================================

/// Settings for the SWAG (Stochastic Weight Averaging Gaussian) posterior estimator.
///
/// [`SwagConfig::default`] collects 20 snapshot epochs, keeps at most 20
/// low-rank deviation columns, and uses a learning rate of `0.01`.
#[derive(Debug, Clone)]
pub struct SwagConfig {
    /// How many SGD snapshot epochs to collect (default: `20`).
    pub n_epochs: usize,
    /// Upper bound C on the number of low-rank deviation columns, i.e. the
    /// SWAG rank (default: `20`).
    pub c: usize,
    /// Learning rate for the parameter updates provided to `SwagCollector`
    /// (default: `0.01`).
    pub lr: f64,
}

impl Default for SwagConfig {
    fn default() -> Self {
        let n_epochs = 20;
        let c = 20;
        let lr = 0.01;
        Self { n_epochs, c, lr }
    }
}

// ============================================================================
// BnnApproxResult
// ============================================================================

/// Summary result of a Bayesian NN approximation (Laplace or SWAG).
///
/// Contains the mean weight vector and a per-parameter uncertainty expressed
/// as a variance (the diagonal of the posterior covariance).
///
/// NOTE(review): `mean_weights` and `uncertainty` are presumably the same
/// length, but nothing here enforces it — confirm at construction sites.
#[derive(Debug, Clone)]
pub struct BnnApproxResult {
    /// Mean weights θ* (MAP estimate for Laplace; SWA solution for SWAG)
    pub mean_weights: Vec<f64>,
    /// Per-parameter posterior variance (diagonal of the covariance)
    pub uncertainty: Vec<f64>,
    /// Label describing the approximation method used (e.g. `"Laplace"`)
    pub method: String,
}

impl BnnApproxResult {
    /// Return per-parameter posterior standard deviations.
    ///
    /// Each entry is the square root of the corresponding entry of
    /// `uncertainty`. A variance is non-negative by definition, so a negative
    /// entry can only be a floating-point artifact (e.g. round-off when a
    /// variance is formed as a difference of moments). Such entries are
    /// clamped to zero instead of producing `NaN` from `sqrt`. A `NaN`
    /// variance still yields `NaN`, so genuinely invalid input is not hidden.
    #[must_use]
    pub fn std_devs(&self) -> Vec<f64> {
        self.uncertainty
            .iter()
            .map(|&v| if v < 0.0 { 0.0 } else { v.sqrt() })
            .collect()
    }

    /// Return the number of parameters (the length of `mean_weights`).
    #[must_use]
    pub fn n_params(&self) -> usize {
        self.mean_weights.len()
    }
}

// ============================================================================
// Tests
// ============================================================================

#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance shared by every floating-point check in this module.
    const TOL: f64 = 1e-12;

    /// True when `actual` lies strictly within `TOL` of `expected`.
    fn close(actual: f64, expected: f64) -> bool {
        (actual - expected).abs() < TOL
    }

    #[test]
    fn test_bayesian_approx_config_default() {
        let laplace = LaplaceConfig::default();
        assert_eq!(laplace.hessian_method, HessianMethod::GGN);
        assert!(close(laplace.damping, 1.0));

        let swag = SwagConfig::default();
        assert_eq!((swag.n_epochs, swag.c), (20, 20));
        assert!(close(swag.lr, 0.01));
    }

    #[test]
    fn test_hessian_method_default_is_ggn() {
        assert_eq!(HessianMethod::default(), HessianMethod::GGN);
    }

    #[test]
    fn test_bnn_approx_result_std_devs() {
        let result = BnnApproxResult {
            mean_weights: vec![1.0, 2.0, 3.0],
            uncertainty: vec![4.0, 9.0, 16.0],
            method: "Laplace".to_string(),
        };
        let stds = result.std_devs();
        // Index directly (not zip) so a too-short result still panics.
        for (i, &want) in [2.0, 3.0, 4.0].iter().enumerate() {
            assert!(close(stds[i], want), "std dev {i}: got {}, want {want}", stds[i]);
        }
        assert_eq!(result.n_params(), 3);
    }
}