scirs2_stats/bayesian_approx/types.rs
1//! Types for Bayesian neural network posterior approximations.
2//!
3//! Provides configuration and result types for Laplace approximation
4//! and SWAG (Stochastic Weight Averaging Gaussian).
5
6// ============================================================================
7// Hessian computation methods for Laplace approximation
8// ============================================================================
9
/// Curvature approximation used for the Hessian of the loss in a Laplace approximation.
///
/// - [`HessianMethod::GGN`]: Generalized Gauss-Newton, used as a Fisher-information proxy
///   built from squared gradients (the default).
/// - [`HessianMethod::Diagonal`]: diagonal curvature only.
/// - [`HessianMethod::KFAC`]: Kronecker-factored approximate curvature.
///
/// The enum is `#[non_exhaustive]`: `match`es on it outside this crate need a wildcard arm.
/// It derives `Hash` alongside `Eq`, so it can be used as a `HashMap` / `HashSet` key.
#[non_exhaustive]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub enum HessianMethod {
    /// Generalized Gauss-Newton / squared-gradient Fisher approximation (default).
    #[default]
    GGN,
    /// Diagonal curvature only.
    Diagonal,
    /// Kronecker-factored approximate curvature.
    KFAC,
}
26
27// ============================================================================
28// LaplaceConfig
29// ============================================================================
30
31/// Configuration for Laplace approximation of a BNN posterior.
32#[derive(Debug, Clone)]
33pub struct LaplaceConfig {
34 /// Method for computing the approximate curvature (default: GGN)
35 pub hessian_method: HessianMethod,
36 /// Tikhonov / prior regularization damping added to the diagonal of H.
37 /// Corresponds to the prior precision λ in N(0, λ⁻¹ I). (default: 1.0)
38 pub damping: f64,
39 /// Finite-difference step for gradient computation (default: 1e-5)
40 pub fd_step: f64,
41}
42
43impl Default for LaplaceConfig {
44 fn default() -> Self {
45 Self {
46 hessian_method: HessianMethod::GGN,
47 damping: 1.0,
48 fd_step: 1e-5,
49 }
50 }
51}
52
53// ============================================================================
54// SwagConfig
55// ============================================================================
56
/// Configuration for the SWAG posterior estimator.
///
/// Start from [`SwagConfig::default`] and override individual fields with
/// struct-update syntax, e.g. `SwagConfig { c: 10, ..Default::default() }`.
/// `PartialEq` is derived so configurations can be compared directly (e.g. in tests).
#[derive(Debug, Clone, PartialEq)]
pub struct SwagConfig {
    /// Number of SGD snapshot epochs to collect (default: 20).
    pub n_epochs: usize,
    /// Maximum number of low-rank deviation columns C, i.e. the SWAG rank (default: 20).
    pub c: usize,
    /// Learning rate for the parameter updates provided to `SwagCollector` (default: 0.01).
    pub lr: f64,
}

impl Default for SwagConfig {
    /// 20 snapshot epochs, rank 20 and learning rate 0.01.
    fn default() -> Self {
        Self {
            n_epochs: 20,
            c: 20,
            lr: 0.01,
        }
    }
}
77
78// ============================================================================
79// BnnApproxResult
80// ============================================================================
81
/// Summary result of a Bayesian NN approximation (Laplace or SWAG).
///
/// Holds the mean weight vector and a per-parameter uncertainty (variance).
/// `mean_weights` and `uncertainty` are expected to have the same length; this is not
/// enforced by the type.
#[derive(Debug, Clone, PartialEq)]
pub struct BnnApproxResult {
    /// Mean weights θ* (MAP estimate for Laplace; SWA solution for SWAG).
    pub mean_weights: Vec<f64>,
    /// Per-parameter posterior variance (diagonal of the covariance).
    pub uncertainty: Vec<f64>,
    /// Label describing the approximation method used (e.g. `"Laplace"`).
    pub method: String,
}

impl BnnApproxResult {
    /// Return per-parameter posterior standard deviations (`sqrt` of each variance).
    ///
    /// Variance estimates formed by subtracting moments can come out marginally
    /// negative through floating-point cancellation, and `f64::sqrt` would turn those
    /// into NaN. Negative variances are therefore clamped to a standard deviation of
    /// `0.0`. A NaN variance is still propagated as NaN (`NaN < 0.0` is false), so
    /// genuinely broken inputs remain visible.
    #[must_use]
    pub fn std_devs(&self) -> Vec<f64> {
        self.uncertainty
            .iter()
            .map(|&v| if v < 0.0 { 0.0 } else { v.sqrt() })
            .collect()
    }

    /// Return the number of parameters, taken from the length of `mean_weights`.
    #[must_use]
    pub fn n_params(&self) -> usize {
        self.mean_weights.len()
    }
}
106
107// ============================================================================
108// Tests
109// ============================================================================
110
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance for comparing float defaults that are written as literals.
    const TOL: f64 = 1e-12;

    #[test]
    fn test_bayesian_approx_config_default() {
        let lap = LaplaceConfig::default();
        assert_eq!(lap.hessian_method, HessianMethod::GGN);
        assert!((lap.damping - 1.0).abs() < TOL);
        // The finite-difference step is part of the documented defaults too.
        assert!((lap.fd_step - 1e-5).abs() < TOL);

        let swag = SwagConfig::default();
        assert_eq!(swag.n_epochs, 20);
        assert_eq!(swag.c, 20);
        assert!((swag.lr - 0.01).abs() < TOL);
    }

    #[test]
    fn test_hessian_method_default_is_ggn() {
        let m = HessianMethod::default();
        assert_eq!(m, HessianMethod::GGN);
    }

    #[test]
    fn test_hessian_method_variants_distinct() {
        assert_ne!(HessianMethod::GGN, HessianMethod::Diagonal);
        assert_ne!(HessianMethod::GGN, HessianMethod::KFAC);
        assert_ne!(HessianMethod::Diagonal, HessianMethod::KFAC);
    }

    #[test]
    fn test_bnn_approx_result_std_devs() {
        let result = BnnApproxResult {
            mean_weights: vec![1.0, 2.0, 3.0],
            uncertainty: vec![4.0, 9.0, 16.0],
            method: "Laplace".to_string(),
        };
        let stds = result.std_devs();
        assert!((stds[0] - 2.0).abs() < TOL);
        assert!((stds[1] - 3.0).abs() < TOL);
        assert!((stds[2] - 4.0).abs() < TOL);
        assert_eq!(result.n_params(), 3);
    }

    #[test]
    fn test_bnn_approx_result_empty() {
        let result = BnnApproxResult {
            mean_weights: vec![],
            uncertainty: vec![],
            method: String::new(),
        };
        assert!(result.std_devs().is_empty());
        assert_eq!(result.n_params(), 0);
    }
}