Skip to main content

scirs2_stats/bayesian_nn/
types.rs

1//! Type definitions for Bayesian Neural Network approximations.
2//!
3//! Provides core data structures for Laplace approximation and SWAG
4//! posterior inference over neural network weights.
5
6use scirs2_core::ndarray::{Array1, Array2};
7
/// Type of uncertainty to quantify.
///
/// The enum is `#[non_exhaustive]`: code outside this crate must include a
/// wildcard arm when matching on it. `Hash` is derived alongside `Eq` so the
/// variants can be used directly as `HashMap` / `HashSet` keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[non_exhaustive]
pub enum UncertaintyType {
    /// Irreducible noise in the data.
    Aleatoric,
    /// Model uncertainty due to limited data.
    Epistemic,
    /// Total uncertainty (aleatoric + epistemic).
    Total,
}
19
/// Posterior approximation method.
///
/// The enum is `#[non_exhaustive]`: code outside this crate must include a
/// wildcard arm when matching on it. `Hash` is derived alongside `Eq` so the
/// variants can be used directly as `HashMap` / `HashSet` keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[non_exhaustive]
pub enum ApproximationMethod {
    /// Laplace approximation on last layer weights only.
    LastLayerLaplace,
    /// Laplace approximation on all network weights.
    FullLaplace,
    /// Full SWAG: diagonal plus low-rank covariance.
    SWAG,
    /// SWAG restricted to a diagonal covariance.
    SWAGDiag,
}
33
34/// Configuration for Bayesian neural network approximations.
35#[derive(Debug, Clone)]
36pub struct BNNConfig {
37    /// Approximation method to use
38    pub method: ApproximationMethod,
39    /// Number of Monte Carlo samples for predictive distribution (default 30)
40    pub n_samples: usize,
41    /// Prior precision: weights ~ N(0, (1/prior_precision) * I) (default 1.0)
42    pub prior_precision: f64,
43    /// Number of low-rank deviation columns for SWAG (default 20)
44    pub swag_rank: usize,
45    /// Start collecting weight snapshots after this many SGD steps (default 0)
46    pub swag_collection_start: usize,
47    /// Collect a weight snapshot every N SGD steps (default 1)
48    pub swag_collection_freq: usize,
49}
50
51impl Default for BNNConfig {
52    fn default() -> Self {
53        Self {
54            method: ApproximationMethod::FullLaplace,
55            n_samples: 30,
56            prior_precision: 1.0,
57            swag_rank: 20,
58            swag_collection_start: 0,
59            swag_collection_freq: 1,
60        }
61    }
62}
63
64/// Structure of the posterior covariance matrix.
65#[derive(Debug, Clone)]
66#[non_exhaustive]
67pub enum CovarianceType {
68    /// Full dense covariance matrix
69    Full(Array2<f64>),
70    /// Diagonal covariance (variance vector)
71    Diagonal(Array1<f64>),
72    /// Low-rank plus diagonal: Sigma = diag(d_diag) + deviation * deviation^T / (K-1)
73    LowRankPlusDiagonal {
74        /// Diagonal component
75        d_diag: Array1<f64>,
76        /// Low-rank deviation matrix, columns are (theta_i - theta_bar)
77        deviation: Array2<f64>,
78    },
79    /// Kronecker-factored covariance: Sigma approx A kron B
80    KroneckerFactored {
81        /// Activation covariance factor
82        a_factor: Array2<f64>,
83        /// Gradient covariance factor
84        b_factor: Array2<f64>,
85    },
86}
87
88/// Posterior distribution over neural network weights.
89#[derive(Debug, Clone)]
90pub struct BNNPosterior {
91    /// Posterior mean (MAP estimate)
92    pub mean: Array1<f64>,
93    /// Covariance structure
94    pub covariance_type: CovarianceType,
95    /// Log marginal likelihood estimate
96    pub log_marginal_likelihood: f64,
97}
98
99/// Predictive distribution at test points.
100#[derive(Debug, Clone)]
101pub struct PredictiveDistribution {
102    /// Predictive mean
103    pub mean: Array1<f64>,
104    /// Predictive variance
105    pub variance: Array1<f64>,
106    /// Optional matrix of prediction samples, shape \[n_samples x n_outputs\]
107    pub samples: Option<Array2<f64>>,
108}
109
/// A single bin in a reliability (calibration) diagram.
///
/// `PartialEq` is derived so bins can be compared as a whole, e.g. with
/// `assert_eq!` (`Eq` is not available because the means are `f64`).
#[derive(Debug, Clone, PartialEq)]
pub struct ReliabilityBin {
    /// Mean predicted probability in this bin.
    pub mean_predicted: f64,
    /// Mean observed frequency (fraction of positives) in this bin.
    pub mean_observed: f64,
    /// Number of predictions that fell in this bin.
    pub count: usize,
}
120
#[cfg(test)]
mod tests {
    use super::*;

    /// `BNNConfig::default()` must match the defaults documented on each field.
    #[test]
    fn test_default_config() {
        let cfg = BNNConfig::default();
        assert_eq!(cfg.method, ApproximationMethod::FullLaplace);
        assert_eq!(cfg.n_samples, 30);
        assert!((cfg.prior_precision - 1.0).abs() < 1e-12);
        assert_eq!(cfg.swag_rank, 20);
        assert_eq!(cfg.swag_collection_start, 0);
        assert_eq!(cfg.swag_collection_freq, 1);
    }

    /// Variants compare equal to themselves and unequal to each other.
    #[test]
    fn test_uncertainty_type_variants() {
        let u = UncertaintyType::Total;
        assert_eq!(u, UncertaintyType::Total);
        assert_ne!(UncertaintyType::Aleatoric, UncertaintyType::Epistemic);
    }

    /// The `Diagonal` variant keeps the variance vector it was built from.
    #[test]
    fn test_covariance_type_diagonal() {
        // `diag` is moved into the variant; it is not needed afterwards.
        let diag = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        let cov = CovarianceType::Diagonal(diag);
        match &cov {
            CovarianceType::Diagonal(d) => assert_eq!(d.len(), 3),
            other => panic!("Expected Diagonal variant, got {:?}", other),
        }
    }

    /// A predictive distribution can be built without stored samples.
    #[test]
    fn test_predictive_distribution_creation() {
        let pd = PredictiveDistribution {
            mean: Array1::from_vec(vec![1.0, 2.0]),
            variance: Array1::from_vec(vec![0.1, 0.2]),
            samples: None,
        };
        assert_eq!(pd.mean.len(), 2);
        assert_eq!(pd.variance.len(), 2);
        assert!(pd.samples.is_none());
    }

    /// A reliability bin stores exactly the values it was built with.
    #[test]
    fn test_reliability_bin() {
        let bin = ReliabilityBin {
            mean_predicted: 0.5,
            mean_observed: 0.48,
            count: 100,
        };
        assert_eq!(bin.count, 100);
        assert!((bin.mean_predicted - 0.5).abs() < 1e-12);
        assert!((bin.mean_observed - 0.48).abs() < 1e-12);
    }
}