// scirs2_stats/bayesian_nn/types.rs
use scirs2_core::ndarray::{Array1, Array2};
7
/// Selects a category of predictive uncertainty.
///
/// The variant names follow the usual aleatoric / epistemic split. How each
/// category is actually computed is decided by the code that consumes this
/// enum, which is not part of this file.
///
/// The enum is `#[non_exhaustive]`: code in other crates must keep a wildcard
/// arm when matching on it.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[non_exhaustive]
pub enum UncertaintyType {
    /// Aleatoric uncertainty (conventionally: noise inherent in the data).
    Aleatoric,
    /// Epistemic uncertainty (conventionally: uncertainty about the model itself).
    Epistemic,
    /// Total uncertainty.
    ///
    /// NOTE(review): presumably the combination of the two categories above;
    /// confirm against the consumer.
    Total,
}
19
/// Selects the posterior approximation scheme named in [`BNNConfig::method`].
///
/// Only the selector lives here; the algorithms themselves are implemented
/// elsewhere. The per-variant notes below expand the names and should be
/// confirmed against those implementations.
///
/// The enum is `#[non_exhaustive]`: code in other crates must keep a wildcard
/// arm when matching on it.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[non_exhaustive]
pub enum ApproximationMethod {
    /// Laplace approximation restricted to the last layer (per the name).
    LastLayerLaplace,
    /// Laplace approximation over the full parameter set (per the name).
    FullLaplace,
    /// SWAG — presumably Stochastic Weight Averaging-Gaussian; TODO confirm.
    SWAG,
    /// SWAG variant with a diagonal covariance (per the name); TODO confirm.
    SWAGDiag,
}
33
34#[derive(Debug, Clone)]
36pub struct BNNConfig {
37 pub method: ApproximationMethod,
39 pub n_samples: usize,
41 pub prior_precision: f64,
43 pub swag_rank: usize,
45 pub swag_collection_start: usize,
47 pub swag_collection_freq: usize,
49}
50
51impl Default for BNNConfig {
52 fn default() -> Self {
53 Self {
54 method: ApproximationMethod::FullLaplace,
55 n_samples: 30,
56 prior_precision: 1.0,
57 swag_rank: 20,
58 swag_collection_start: 0,
59 swag_collection_freq: 1,
60 }
61 }
62}
63
64#[derive(Debug, Clone)]
66#[non_exhaustive]
67pub enum CovarianceType {
68 Full(Array2<f64>),
70 Diagonal(Array1<f64>),
72 LowRankPlusDiagonal {
74 d_diag: Array1<f64>,
76 deviation: Array2<f64>,
78 },
79 KroneckerFactored {
81 a_factor: Array2<f64>,
83 b_factor: Array2<f64>,
85 },
86}
87
88#[derive(Debug, Clone)]
90pub struct BNNPosterior {
91 pub mean: Array1<f64>,
93 pub covariance_type: CovarianceType,
95 pub log_marginal_likelihood: f64,
97}
98
99#[derive(Debug, Clone)]
101pub struct PredictiveDistribution {
102 pub mean: Array1<f64>,
104 pub variance: Array1<f64>,
106 pub samples: Option<Array2<f64>>,
108}
109
/// One bin of a reliability (calibration) summary.
///
/// Stores two per-bin averages and the number of items in the bin. The
/// binning itself is done elsewhere.
#[derive(Debug, Clone)]
pub struct ReliabilityBin {
    /// Mean of the predicted values that fell into this bin.
    pub mean_predicted: f64,
    /// Mean of the observed values for the same items.
    pub mean_observed: f64,
    /// Number of items in the bin.
    pub count: usize,
}
120
#[cfg(test)]
mod tests {
    use super::*;

    /// `BNNConfig::default()` must yield every documented baseline value.
    #[test]
    fn test_default_config() {
        let cfg = BNNConfig::default();
        assert_eq!(cfg.n_samples, 30);
        assert!((cfg.prior_precision - 1.0).abs() < 1e-12);
        assert_eq!(cfg.swag_rank, 20);
        assert_eq!(cfg.method, ApproximationMethod::FullLaplace);
        // Also pin the SWAG collection schedule so a changed default is caught.
        assert_eq!(cfg.swag_collection_start, 0);
        assert_eq!(cfg.swag_collection_freq, 1);
    }

    /// Derived equality distinguishes the variants.
    #[test]
    fn test_uncertainty_type_variants() {
        let u = UncertaintyType::Total;
        assert_eq!(u, UncertaintyType::Total);
        assert_ne!(UncertaintyType::Aleatoric, UncertaintyType::Epistemic);
    }

    /// The `Diagonal` variant keeps the array it was built from.
    #[test]
    fn test_covariance_type_diagonal() {
        let diag = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        // `diag` is not used afterwards, so move it instead of cloning.
        let cov = CovarianceType::Diagonal(diag);
        match &cov {
            CovarianceType::Diagonal(d) => assert_eq!(d.len(), 3),
            _ => panic!("Expected Diagonal variant"),
        }
    }

    /// A distribution can be built without samples.
    #[test]
    fn test_predictive_distribution_creation() {
        let pd = PredictiveDistribution {
            mean: Array1::from_vec(vec![1.0, 2.0]),
            variance: Array1::from_vec(vec![0.1, 0.2]),
            samples: None,
        };
        assert_eq!(pd.mean.len(), 2);
        assert!(pd.samples.is_none());
    }

    /// Plain field round-trip for `ReliabilityBin`.
    #[test]
    fn test_reliability_bin() {
        let bin = ReliabilityBin {
            mean_predicted: 0.5,
            mean_observed: 0.48,
            count: 100,
        };
        assert_eq!(bin.count, 100);
        assert!((bin.mean_predicted - 0.5).abs() < 1e-12);
    }
}