Skip to main content

scirs2_stats/bayesian_nn/
layers.rs

//! Variational Bayesian linear layers for Bayesian Neural Networks.
//!
//! Implements Bayes-by-Backprop (Blundell et al. 2015) with the reparameterization
//! trick: weights are parameterized by means and log-standard-deviations, with MC
//! sampling of every weight at forward pass time and KL divergence for the ELBO objective.
//!
//! Each weight w ~ q(w) = N(w_mu, exp(w_log_sigma)^2).
//! The prior is p(w) = N(0, prior_std^2).
//!
//! KL[q || p] = -0.5 * sum(1 + 2*log_sigma - log(prior_std^2) - mu^2/prior_std^2
//!                          - exp(2*log_sigma)/prior_std^2)
12
13use crate::error::StatsError;
14
/// Hyper-parameters shared by the variational Bayesian neural network layers.
///
/// [`Default`] yields a unit-variance prior, an unscaled KL term and ten
/// Monte Carlo samples.
#[derive(Debug, Clone)]
pub struct BnnConfig {
    /// Standard deviation `prior_std` of the zero-mean Gaussian weight prior
    /// `N(0, prior_std^2)`.
    pub prior_std: f64,
    /// Multiplier applied to the KL term of the ELBO (default `1.0`).
    pub kl_weight: f64,
    /// How many Monte Carlo forward passes are drawn for stochastic gradient
    /// estimation (default `10`).
    pub n_samples_mc: usize,
}

impl Default for BnnConfig {
    /// Build the default configuration: `prior_std = 1.0`, `kl_weight = 1.0`,
    /// `n_samples_mc = 10`.
    fn default() -> Self {
        BnnConfig {
            prior_std: 1.0,
            kl_weight: 1.0,
            n_samples_mc: 10,
        }
    }
}
35
36/// A single variational Bayesian linear layer.
37///
38/// Weights: w_{ij} ~ N(w_mu_{ij}, exp(w_log_sigma_{ij})^2)
39/// Biases:  b_j   ~ N(b_mu_j,   exp(b_log_sigma_j)^2)
40///
41/// Stored as flat row-major vectors of length `out_features * in_features`.
42#[derive(Debug, Clone)]
43pub struct BayesianLinear {
44    /// Number of input features
45    pub in_features: usize,
46    /// Number of output features
47    pub out_features: usize,
48    /// Weight posterior means, length `out_features * in_features`
49    pub w_mu: Vec<f64>,
50    /// Weight posterior log-std, length `out_features * in_features`
51    pub w_log_sigma: Vec<f64>,
52    /// Bias posterior means, length `out_features`
53    pub b_mu: Vec<f64>,
54    /// Bias posterior log-std, length `out_features`
55    pub b_log_sigma: Vec<f64>,
56    /// Prior standard deviation
57    pub prior_std: f64,
58}
59
60impl BayesianLinear {
61    /// Create a new `BayesianLinear` layer.
62    ///
63    /// Weights are initialized from N(0, 0.1) and log-sigma initialized to -3.0
64    /// (corresponding to sigma ≈ 0.05, tight initial posterior).
65    ///
66    /// # Arguments
67    /// * `in_features` - Input dimensionality
68    /// * `out_features` - Output dimensionality
69    /// * `prior_std` - Standard deviation of the weight prior
70    ///
71    /// # Errors
72    /// Returns an error if `in_features` or `out_features` is zero.
73    pub fn new(
74        in_features: usize,
75        out_features: usize,
76        prior_std: f64,
77    ) -> Result<Self, StatsError> {
78        if in_features == 0 {
79            return Err(StatsError::InvalidArgument(
80                "in_features must be > 0".to_string(),
81            ));
82        }
83        if out_features == 0 {
84            return Err(StatsError::InvalidArgument(
85                "out_features must be > 0".to_string(),
86            ));
87        }
88        if prior_std <= 0.0 {
89            return Err(StatsError::InvalidArgument(
90                "prior_std must be positive".to_string(),
91            ));
92        }
93
94        let n_weights = out_features * in_features;
95
96        // Initialize w_mu ~ N(0, 0.1) using a deterministic pseudo-random scheme
97        // (Lehmer LCG seeded by size for reproducibility without external RNG dependency)
98        let mut w_mu = vec![0.0f64; n_weights];
99        let mut state: u64 = (n_weights as u64)
100            .wrapping_mul(6364136223846793005)
101            .wrapping_add(1442695040888963407);
102        for wm in w_mu.iter_mut() {
103            state = state
104                .wrapping_mul(6364136223846793005)
105                .wrapping_add(1442695040888963407);
106            // Map u64 to [-0.1, 0.1]
107            let u = (state >> 11) as f64 / (1u64 << 53) as f64; // [0,1)
108            *wm = (u - 0.5) * 0.2; // [-0.1, 0.1]
109        }
110
111        let w_log_sigma = vec![-3.0f64; n_weights];
112        let b_mu = vec![0.0f64; out_features];
113        let b_log_sigma = vec![-3.0f64; out_features];
114
115        Ok(Self {
116            in_features,
117            out_features,
118            w_mu,
119            w_log_sigma,
120            b_mu,
121            b_log_sigma,
122            prior_std,
123        })
124    }
125
126    /// Forward pass with sampled weights (reparameterization trick).
127    ///
128    /// Samples w = w_mu + eps * exp(w_log_sigma) for each weight and bias,
129    /// then computes the matrix-vector product.
130    ///
131    /// # Arguments
132    /// * `x` - Input vector of length `in_features`
133    /// * `rng` - Closure producing standard normal samples N(0,1)
134    ///
135    /// # Returns
136    /// Output vector of length `out_features`
137    ///
138    /// # Errors
139    /// Returns an error if `x` has incorrect length.
140    pub fn forward_sample(
141        &self,
142        x: &[f64],
143        rng: &mut impl FnMut() -> f64,
144    ) -> Result<Vec<f64>, StatsError> {
145        if x.len() != self.in_features {
146            return Err(StatsError::DimensionMismatch(format!(
147                "input length {} != in_features {}",
148                x.len(),
149                self.in_features
150            )));
151        }
152
153        let mut out = vec![0.0f64; self.out_features];
154        for o in 0..self.out_features {
155            // Sampled bias
156            let eps_b = rng();
157            let b_sigma = self.b_log_sigma[o].exp();
158            let b_sample = self.b_mu[o] + eps_b * b_sigma;
159
160            let mut acc = b_sample;
161            for i in 0..self.in_features {
162                let idx = o * self.in_features + i;
163                let eps_w = rng();
164                let w_sigma = self.w_log_sigma[idx].exp();
165                let w_sample = self.w_mu[idx] + eps_w * w_sigma;
166                acc += w_sample * x[i];
167            }
168            out[o] = acc;
169        }
170        Ok(out)
171    }
172
173    /// Deterministic forward pass using posterior means only.
174    ///
175    /// Computes output = W_mu @ x + b_mu. Useful for fast predictive mean.
176    ///
177    /// # Arguments
178    /// * `x` - Input vector of length `in_features`
179    ///
180    /// # Errors
181    /// Returns an error if `x` has incorrect length.
182    pub fn forward_mean(&self, x: &[f64]) -> Result<Vec<f64>, StatsError> {
183        if x.len() != self.in_features {
184            return Err(StatsError::DimensionMismatch(format!(
185                "input length {} != in_features {}",
186                x.len(),
187                self.in_features
188            )));
189        }
190
191        let mut out = vec![0.0f64; self.out_features];
192        for o in 0..self.out_features {
193            let mut acc = self.b_mu[o];
194            for i in 0..self.in_features {
195                acc += self.w_mu[o * self.in_features + i] * x[i];
196            }
197            out[o] = acc;
198        }
199        Ok(out)
200    }
201
202    /// Compute the KL divergence KL[q(w) || p(w)] for all weights and biases.
203    ///
204    /// For q = N(mu, sigma^2) and p = N(0, prior_std^2):
205    /// KL = -0.5 * sum(1 + 2*log_sigma - log(prior_std^2) - mu^2/prior_std^2
206    ///                   - sigma^2/prior_std^2)
207    ///
208    /// # Arguments
209    /// * `prior_std` - Prior standard deviation (can differ from initialization value)
210    pub fn kl_divergence(&self, prior_std: f64) -> f64 {
211        let log_prior_var = (prior_std * prior_std).ln();
212        let prior_var = prior_std * prior_std;
213        let mut kl = 0.0;
214
215        // Weights
216        for i in 0..(self.out_features * self.in_features) {
217            let mu = self.w_mu[i];
218            let log_sigma = self.w_log_sigma[i];
219            let sigma_sq = (2.0 * log_sigma).exp();
220            kl += -0.5
221                * (1.0 + 2.0 * log_sigma
222                    - log_prior_var
223                    - mu * mu / prior_var
224                    - sigma_sq / prior_var);
225        }
226
227        // Biases
228        for o in 0..self.out_features {
229            let mu = self.b_mu[o];
230            let log_sigma = self.b_log_sigma[o];
231            let sigma_sq = (2.0 * log_sigma).exp();
232            kl += -0.5
233                * (1.0 + 2.0 * log_sigma
234                    - log_prior_var
235                    - mu * mu / prior_var
236                    - sigma_sq / prior_var);
237        }
238
239        kl
240    }
241
242    /// Apply a gradient step (SGD) to the variational parameters.
243    ///
244    /// # Arguments
245    /// * `grad_w_mu`        - Gradient of loss w.r.t. w_mu, length `out*in`
246    /// * `grad_w_log_sigma` - Gradient of loss w.r.t. w_log_sigma, length `out*in`
247    /// * `grad_b_mu`        - Gradient of loss w.r.t. b_mu, length `out`
248    /// * `grad_b_log_sigma` - Gradient of loss w.r.t. b_log_sigma, length `out`
249    /// * `lr`               - Learning rate
250    ///
251    /// # Errors
252    /// Returns an error if gradient dimensions are inconsistent.
253    pub fn update(
254        &mut self,
255        grad_w_mu: &[f64],
256        grad_w_log_sigma: &[f64],
257        grad_b_mu: &[f64],
258        grad_b_log_sigma: &[f64],
259        lr: f64,
260    ) -> Result<(), StatsError> {
261        let n_weights = self.out_features * self.in_features;
262        if grad_w_mu.len() != n_weights {
263            return Err(StatsError::DimensionMismatch(format!(
264                "grad_w_mu length {} != {}",
265                grad_w_mu.len(),
266                n_weights
267            )));
268        }
269        if grad_w_log_sigma.len() != n_weights {
270            return Err(StatsError::DimensionMismatch(format!(
271                "grad_w_log_sigma length {} != {}",
272                grad_w_log_sigma.len(),
273                n_weights
274            )));
275        }
276        if grad_b_mu.len() != self.out_features {
277            return Err(StatsError::DimensionMismatch(format!(
278                "grad_b_mu length {} != {}",
279                grad_b_mu.len(),
280                self.out_features
281            )));
282        }
283        if grad_b_log_sigma.len() != self.out_features {
284            return Err(StatsError::DimensionMismatch(format!(
285                "grad_b_log_sigma length {} != {}",
286                grad_b_log_sigma.len(),
287                self.out_features
288            )));
289        }
290
291        for i in 0..n_weights {
292            self.w_mu[i] -= lr * grad_w_mu[i];
293            self.w_log_sigma[i] -= lr * grad_w_log_sigma[i];
294        }
295        for o in 0..self.out_features {
296            self.b_mu[o] -= lr * grad_b_mu[o];
297            self.b_log_sigma[o] -= lr * grad_b_log_sigma[o];
298        }
299        Ok(())
300    }
301
302    /// Total number of variational parameters (for KL / ELBO scaling)
303    pub fn n_params(&self) -> usize {
304        2 * (self.out_features * self.in_features + self.out_features)
305    }
306}
307
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a deterministic standard-normal sampler (Box-Muller over a 64-bit
    /// linear congruential generator) so the tests need no external RNG crate.
    ///
    /// Each Box-Muller round yields two values: the cosine branch is returned
    /// immediately and the sine branch is handed out on the next call.
    fn make_normal_rng() -> impl FnMut() -> f64 {
        /// Advance the generator and map the new state to a uniform value in [0, 1).
        fn next_uniform(state: &mut u64) -> f64 {
            *state = state
                .wrapping_mul(6364136223846793005)
                .wrapping_add(1442695040888963407);
            (*state >> 11) as f64 / (1u64 << 53) as f64
        }

        let mut lcg_state: u64 = 12345678901234567;
        let mut spare: Option<f64> = None;
        move || match spare.take() {
            Some(z) => z,
            None => {
                // The small offset keeps the argument of `ln` away from zero.
                let u1 = next_uniform(&mut lcg_state) + 1e-15;
                let u2 = next_uniform(&mut lcg_state);
                let radius = (-2.0 * u1.ln()).sqrt();
                let angle = 2.0 * std::f64::consts::PI * u2;
                spare = Some(radius * angle.sin());
                radius * angle.cos()
            }
        }
    }

    /// A fresh layer reports the requested dimensions and parameter vector lengths.
    #[test]
    fn test_bayesian_linear_new() {
        let layer = BayesianLinear::new(3, 4, 1.0).expect("creation should succeed");
        assert_eq!(layer.in_features, 3);
        assert_eq!(layer.out_features, 4);
        assert_eq!(layer.w_mu.len(), 12);
        assert_eq!(layer.w_log_sigma.len(), 12);
        assert_eq!(layer.b_mu.len(), 4);
        assert_eq!(layer.b_log_sigma.len(), 4);
        // Every weight log-sigma starts at -3.0.
        assert!(layer
            .w_log_sigma
            .iter()
            .all(|&ls| (ls - (-3.0)).abs() < 1e-12));
    }

    /// The mean forward pass yields one value per output feature.
    #[test]
    fn test_forward_mean_shape() {
        let layer = BayesianLinear::new(5, 3, 1.0).expect("creation");
        let input = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let output = layer.forward_mean(&input).expect("forward_mean");
        assert_eq!(output.len(), 3);
    }

    /// The sampled forward pass yields one value per output feature.
    #[test]
    fn test_forward_sample_shape() {
        let layer = BayesianLinear::new(4, 2, 1.0).expect("creation");
        let input = vec![1.0, 0.0, -1.0, 0.5];
        let mut normal = make_normal_rng();
        let output = layer
            .forward_sample(&input, &mut normal)
            .expect("forward_sample");
        assert_eq!(output.len(), 2);
    }

    /// Moving posterior means away from zero makes the KL strictly positive.
    #[test]
    fn test_kl_divergence_positive() {
        let mut layer = BayesianLinear::new(2, 2, 1.0).expect("creation");
        layer.w_mu[0] = 1.0;
        layer.w_mu[1] = -0.5;
        let kl = layer.kl_divergence(1.0);
        assert!(
            kl > 0.0,
            "KL divergence should be positive with non-zero means, got {}",
            kl
        );
    }

    /// With mu = 0 and sigma = prior_std the posterior equals the prior, so KL ~ 0.
    #[test]
    fn test_kl_zero_with_prior_params() {
        let prior_std = 1.0;
        let mut layer = BayesianLinear::new(2, 1, prior_std).expect("creation");
        // log_sigma = ln(prior_std) = 0.0 for every weight and bias.
        let log_prior_std = prior_std.ln();
        layer.w_mu.iter_mut().for_each(|w| *w = 0.0);
        layer
            .w_log_sigma
            .iter_mut()
            .for_each(|ls| *ls = log_prior_std);
        layer.b_mu.iter_mut().for_each(|b| *b = 0.0);
        layer
            .b_log_sigma
            .iter_mut()
            .for_each(|ls| *ls = log_prior_std);

        let kl = layer.kl_divergence(prior_std);
        assert!(kl.abs() < 1e-10, "KL should be ~0 when q=p, got {}", kl);
    }

    /// One SGD step moves the first weight mean by `-lr * grad`.
    #[test]
    fn test_update_step() {
        let mut layer = BayesianLinear::new(2, 2, 1.0).expect("creation");
        let first_before = layer.w_mu[0];
        let grad_w_mu = vec![1.0, 0.0, -1.0, 0.5];
        let grad_w_ls = vec![0.1, 0.2, 0.3, 0.4];
        let grad_b_mu = vec![0.5, -0.5];
        let grad_b_ls = vec![0.1, 0.1];
        layer
            .update(&grad_w_mu, &grad_w_ls, &grad_b_mu, &grad_b_ls, 0.01)
            .expect("update");
        let expected = first_before - 0.01 * 1.0;
        assert!((layer.w_mu[0] - expected).abs() < 1e-12);
    }

    /// Invalid constructor arguments and a wrong input length are reported as errors.
    #[test]
    fn test_dimension_errors() {
        assert!(BayesianLinear::new(0, 3, 1.0).is_err());
        assert!(BayesianLinear::new(3, 0, 1.0).is_err());
        assert!(BayesianLinear::new(3, 3, -1.0).is_err());

        let layer = BayesianLinear::new(3, 2, 1.0).expect("creation");
        // Two inputs are supplied where three are required.
        assert!(layer.forward_mean(&[1.0, 2.0]).is_err());
    }
}