scirs2_stats/bayesian_nn/
layers.rs1use crate::error::StatsError;
/// Hyper-parameters for a Bayesian neural network.
///
/// NOTE(review): no code in this file reads these fields; the per-field notes
/// below are inferred from the names and should be confirmed against the caller.
#[derive(Debug, Clone)]
pub struct BnnConfig {
    /// Standard deviation of the Gaussian weight prior (presumably what gets
    /// passed to `BayesianLinear::new` / `kl_divergence` — TODO confirm).
    pub prior_std: f64,
    /// Presumably the multiplier on the KL term of the training loss — TODO confirm.
    pub kl_weight: f64,
    /// Presumably the number of Monte-Carlo forward samples — TODO confirm.
    pub n_samples_mc: usize,
}

impl Default for BnnConfig {
    /// Unit prior, unit KL weight and 10 Monte-Carlo samples.
    fn default() -> Self {
        BnnConfig {
            n_samples_mc: 10,
            kl_weight: 1.0,
            prior_std: 1.0,
        }
    }
}

36#[derive(Debug, Clone)]
43pub struct BayesianLinear {
44 pub in_features: usize,
46 pub out_features: usize,
48 pub w_mu: Vec<f64>,
50 pub w_log_sigma: Vec<f64>,
52 pub b_mu: Vec<f64>,
54 pub b_log_sigma: Vec<f64>,
56 pub prior_std: f64,
58}
59
60impl BayesianLinear {
61 pub fn new(
74 in_features: usize,
75 out_features: usize,
76 prior_std: f64,
77 ) -> Result<Self, StatsError> {
78 if in_features == 0 {
79 return Err(StatsError::InvalidArgument(
80 "in_features must be > 0".to_string(),
81 ));
82 }
83 if out_features == 0 {
84 return Err(StatsError::InvalidArgument(
85 "out_features must be > 0".to_string(),
86 ));
87 }
88 if prior_std <= 0.0 {
89 return Err(StatsError::InvalidArgument(
90 "prior_std must be positive".to_string(),
91 ));
92 }
93
94 let n_weights = out_features * in_features;
95
96 let mut w_mu = vec![0.0f64; n_weights];
99 let mut state: u64 = (n_weights as u64)
100 .wrapping_mul(6364136223846793005)
101 .wrapping_add(1442695040888963407);
102 for wm in w_mu.iter_mut() {
103 state = state
104 .wrapping_mul(6364136223846793005)
105 .wrapping_add(1442695040888963407);
106 let u = (state >> 11) as f64 / (1u64 << 53) as f64; *wm = (u - 0.5) * 0.2; }
110
111 let w_log_sigma = vec![-3.0f64; n_weights];
112 let b_mu = vec![0.0f64; out_features];
113 let b_log_sigma = vec![-3.0f64; out_features];
114
115 Ok(Self {
116 in_features,
117 out_features,
118 w_mu,
119 w_log_sigma,
120 b_mu,
121 b_log_sigma,
122 prior_std,
123 })
124 }
125
126 pub fn forward_sample(
141 &self,
142 x: &[f64],
143 rng: &mut impl FnMut() -> f64,
144 ) -> Result<Vec<f64>, StatsError> {
145 if x.len() != self.in_features {
146 return Err(StatsError::DimensionMismatch(format!(
147 "input length {} != in_features {}",
148 x.len(),
149 self.in_features
150 )));
151 }
152
153 let mut out = vec![0.0f64; self.out_features];
154 for o in 0..self.out_features {
155 let eps_b = rng();
157 let b_sigma = self.b_log_sigma[o].exp();
158 let b_sample = self.b_mu[o] + eps_b * b_sigma;
159
160 let mut acc = b_sample;
161 for i in 0..self.in_features {
162 let idx = o * self.in_features + i;
163 let eps_w = rng();
164 let w_sigma = self.w_log_sigma[idx].exp();
165 let w_sample = self.w_mu[idx] + eps_w * w_sigma;
166 acc += w_sample * x[i];
167 }
168 out[o] = acc;
169 }
170 Ok(out)
171 }
172
173 pub fn forward_mean(&self, x: &[f64]) -> Result<Vec<f64>, StatsError> {
183 if x.len() != self.in_features {
184 return Err(StatsError::DimensionMismatch(format!(
185 "input length {} != in_features {}",
186 x.len(),
187 self.in_features
188 )));
189 }
190
191 let mut out = vec![0.0f64; self.out_features];
192 for o in 0..self.out_features {
193 let mut acc = self.b_mu[o];
194 for i in 0..self.in_features {
195 acc += self.w_mu[o * self.in_features + i] * x[i];
196 }
197 out[o] = acc;
198 }
199 Ok(out)
200 }
201
202 pub fn kl_divergence(&self, prior_std: f64) -> f64 {
211 let log_prior_var = (prior_std * prior_std).ln();
212 let prior_var = prior_std * prior_std;
213 let mut kl = 0.0;
214
215 for i in 0..(self.out_features * self.in_features) {
217 let mu = self.w_mu[i];
218 let log_sigma = self.w_log_sigma[i];
219 let sigma_sq = (2.0 * log_sigma).exp();
220 kl += -0.5
221 * (1.0 + 2.0 * log_sigma
222 - log_prior_var
223 - mu * mu / prior_var
224 - sigma_sq / prior_var);
225 }
226
227 for o in 0..self.out_features {
229 let mu = self.b_mu[o];
230 let log_sigma = self.b_log_sigma[o];
231 let sigma_sq = (2.0 * log_sigma).exp();
232 kl += -0.5
233 * (1.0 + 2.0 * log_sigma
234 - log_prior_var
235 - mu * mu / prior_var
236 - sigma_sq / prior_var);
237 }
238
239 kl
240 }
241
242 pub fn update(
254 &mut self,
255 grad_w_mu: &[f64],
256 grad_w_log_sigma: &[f64],
257 grad_b_mu: &[f64],
258 grad_b_log_sigma: &[f64],
259 lr: f64,
260 ) -> Result<(), StatsError> {
261 let n_weights = self.out_features * self.in_features;
262 if grad_w_mu.len() != n_weights {
263 return Err(StatsError::DimensionMismatch(format!(
264 "grad_w_mu length {} != {}",
265 grad_w_mu.len(),
266 n_weights
267 )));
268 }
269 if grad_w_log_sigma.len() != n_weights {
270 return Err(StatsError::DimensionMismatch(format!(
271 "grad_w_log_sigma length {} != {}",
272 grad_w_log_sigma.len(),
273 n_weights
274 )));
275 }
276 if grad_b_mu.len() != self.out_features {
277 return Err(StatsError::DimensionMismatch(format!(
278 "grad_b_mu length {} != {}",
279 grad_b_mu.len(),
280 self.out_features
281 )));
282 }
283 if grad_b_log_sigma.len() != self.out_features {
284 return Err(StatsError::DimensionMismatch(format!(
285 "grad_b_log_sigma length {} != {}",
286 grad_b_log_sigma.len(),
287 self.out_features
288 )));
289 }
290
291 for i in 0..n_weights {
292 self.w_mu[i] -= lr * grad_w_mu[i];
293 self.w_log_sigma[i] -= lr * grad_w_log_sigma[i];
294 }
295 for o in 0..self.out_features {
296 self.b_mu[o] -= lr * grad_b_mu[o];
297 self.b_log_sigma[o] -= lr * grad_b_log_sigma[o];
298 }
299 Ok(())
300 }
301
302 pub fn n_params(&self) -> usize {
304 2 * (self.out_features * self.in_features + self.out_features)
305 }
306}
307
#[cfg(test)]
mod tests {
    use super::*;

    /// Deterministic standard-normal generator for tests. A 64-bit LCG feeds
    /// Box–Muller, and the second normal of each pair is cached for the next call.
    fn make_normal_rng() -> impl FnMut() -> f64 {
        let mut state: u64 = 12345678901234567;
        // Second Box–Muller output, returned on the following call.
        let mut cached: Option<f64> = None;
        move || {
            if let Some(v) = cached.take() {
                return v;
            }
            state = state
                .wrapping_mul(6364136223846793005)
                .wrapping_add(1442695040888963407);
            // The small offset keeps u1 away from 0 so ln(u1) stays finite.
            let u1 = (state >> 11) as f64 / (1u64 << 53) as f64 + 1e-15;
            state = state
                .wrapping_mul(6364136223846793005)
                .wrapping_add(1442695040888963407);
            let u2 = (state >> 11) as f64 / (1u64 << 53) as f64;
            let r = (-2.0 * u1.ln()).sqrt();
            let theta = 2.0 * std::f64::consts::PI * u2;
            cached = Some(r * theta.sin());
            r * theta.cos()
        }
    }

    #[test]
    fn test_bayesian_linear_new() {
        let layer = BayesianLinear::new(3, 4, 1.0).expect("creation should succeed");
        assert_eq!(layer.in_features, 3);
        assert_eq!(layer.out_features, 4);
        assert_eq!(layer.w_mu.len(), 12);
        assert_eq!(layer.w_log_sigma.len(), 12);
        assert_eq!(layer.b_mu.len(), 4);
        assert_eq!(layer.b_log_sigma.len(), 4);
        assert_eq!(layer.n_params(), 2 * (12 + 4));
        // Every log-sigma starts at -3, bias means at 0, weight means in [-0.1, 0.1).
        for &ls in layer.w_log_sigma.iter().chain(&layer.b_log_sigma) {
            assert!((ls - (-3.0)).abs() < 1e-12);
        }
        assert!(layer.b_mu.iter().all(|&b| b == 0.0));
        assert!(layer.w_mu.iter().all(|&w| (-0.1..0.1).contains(&w)));
    }

    #[test]
    fn test_initialisation_is_deterministic() {
        let a = BayesianLinear::new(3, 4, 1.0).expect("creation");
        let b = BayesianLinear::new(3, 4, 1.0).expect("creation");
        assert_eq!(a.w_mu, b.w_mu);
    }

    #[test]
    fn test_forward_mean_shape() {
        let layer = BayesianLinear::new(5, 3, 1.0).expect("creation");
        let x = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let out = layer.forward_mean(&x).expect("forward_mean");
        assert_eq!(out.len(), 3);
    }

    #[test]
    fn test_forward_sample_shape() {
        let layer = BayesianLinear::new(4, 2, 1.0).expect("creation");
        let x = vec![1.0, 0.0, -1.0, 0.5];
        let mut rng = make_normal_rng();
        let out = layer.forward_sample(&x, &mut rng).expect("forward_sample");
        assert_eq!(out.len(), 2);
    }

    #[test]
    fn test_forward_sample_reproducible_with_same_seed() {
        let layer = BayesianLinear::new(4, 2, 1.0).expect("creation");
        let x = vec![1.0, 0.0, -1.0, 0.5];
        let first = layer
            .forward_sample(&x, &mut make_normal_rng())
            .expect("forward_sample");
        let second = layer
            .forward_sample(&x, &mut make_normal_rng())
            .expect("forward_sample");
        assert_eq!(first, second);
    }

    #[test]
    fn test_forward_values_with_known_parameters() {
        let mut layer = BayesianLinear::new(2, 2, 1.0).expect("creation");
        layer.w_mu = vec![1.0, 2.0, -1.0, 0.5];
        layer.b_mu = vec![0.5, -0.5];
        let x = [3.0, 4.0];
        let mean = layer.forward_mean(&x).expect("forward_mean");
        assert_eq!(mean, vec![11.5, -1.5]);

        // With eps == 0 every sampled parameter equals its mean.
        let sample = layer
            .forward_sample(&x, &mut || 0.0)
            .expect("forward_sample");
        assert_eq!(sample, mean);
    }

    #[test]
    fn test_kl_divergence_positive() {
        let mut layer = BayesianLinear::new(2, 2, 1.0).expect("creation");
        layer.w_mu[0] = 1.0;
        layer.w_mu[1] = -0.5;
        let kl = layer.kl_divergence(1.0);
        assert!(
            kl > 0.0,
            "KL divergence should be positive with non-zero means, got {}",
            kl
        );
    }

    #[test]
    fn test_kl_zero_with_prior_params() {
        for &prior_std in &[1.0f64, 2.0] {
            let mut layer = BayesianLinear::new(2, 1, prior_std).expect("creation");
            // Make q identical to the prior N(0, prior_std^2).
            layer.w_mu.iter_mut().for_each(|w| *w = 0.0);
            layer.w_log_sigma.iter_mut().for_each(|ls| *ls = prior_std.ln());
            layer.b_mu.iter_mut().for_each(|b| *b = 0.0);
            layer.b_log_sigma.iter_mut().for_each(|ls| *ls = prior_std.ln());
            let kl = layer.kl_divergence(prior_std);
            assert!(kl.abs() < 1e-10, "KL should be ~0 when q=p, got {}", kl);
        }
    }

    #[test]
    fn test_update_step() {
        let mut layer = BayesianLinear::new(2, 2, 1.0).expect("creation");
        let before = layer.clone();
        let grad_w_mu = vec![1.0, 0.0, -1.0, 0.5];
        let grad_w_ls = vec![0.1, 0.2, 0.3, 0.4];
        let grad_b_mu = vec![0.5, -0.5];
        let grad_b_ls = vec![0.1, 0.1];
        let lr = 0.01;
        layer
            .update(&grad_w_mu, &grad_w_ls, &grad_b_mu, &grad_b_ls, lr)
            .expect("update");

        // Every parameter must move by exactly -lr * grad.
        let check = |after: &[f64], prev: &[f64], grad: &[f64]| {
            assert_eq!(after.len(), grad.len());
            for ((&a, &p), &g) in after.iter().zip(prev).zip(grad) {
                let expected = p - lr * g;
                assert!((a - expected).abs() < 1e-12, "expected {} got {}", expected, a);
            }
        };
        check(&layer.w_mu, &before.w_mu, &grad_w_mu);
        check(&layer.w_log_sigma, &before.w_log_sigma, &grad_w_ls);
        check(&layer.b_mu, &before.b_mu, &grad_b_mu);
        check(&layer.b_log_sigma, &before.b_log_sigma, &grad_b_ls);
    }

    #[test]
    fn test_dimension_errors() {
        assert!(BayesianLinear::new(0, 3, 1.0).is_err());
        assert!(BayesianLinear::new(3, 0, 1.0).is_err());
        assert!(BayesianLinear::new(3, 3, -1.0).is_err());
        assert!(BayesianLinear::new(3, 3, 0.0).is_err());

        let mut layer = BayesianLinear::new(3, 2, 1.0).expect("creation");
        assert!(layer.forward_mean(&[1.0, 2.0]).is_err());
        let mut rng = make_normal_rng();
        assert!(layer.forward_sample(&[1.0, 2.0], &mut rng).is_err());

        // 3 inputs x 2 outputs: 6 weights, 2 biases.
        let ok_w = vec![0.0; 6];
        let ok_b = vec![0.0; 2];
        assert!(layer.update(&[0.0; 5], &ok_w, &ok_b, &ok_b, 0.1).is_err());
        assert!(layer.update(&ok_w, &[0.0; 5], &ok_b, &ok_b, 0.1).is_err());
        assert!(layer.update(&ok_w, &ok_w, &[0.0; 3], &ok_b, 0.1).is_err());
        assert!(layer.update(&ok_w, &ok_w, &ok_b, &[0.0; 1], 0.1).is_err());
        assert!(layer.update(&ok_w, &ok_w, &ok_b, &ok_b, 0.1).is_ok());
    }
}