pub struct BayesianLinear {
    pub in_features: usize,
    pub out_features: usize,
    pub w_mu: Vec<f64>,
    pub w_log_sigma: Vec<f64>,
    pub b_mu: Vec<f64>,
    pub b_log_sigma: Vec<f64>,
    pub prior_std: f64,
}
A single variational Bayesian linear layer.
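Weights: w_{ij} ~ N(w_mu_{ij}, exp(w_log_sigma_{ij})^2)
Biases: b_j ~ N(b_mu_j, exp(b_log_sigma_j)^2)
The weight parameters (w_mu, w_log_sigma) are stored as flat row-major vectors of length out_features * in_features; the bias parameters (b_mu, b_log_sigma) have length out_features.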
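A minimal sketch of the flat storage, assuming the row-major matrix has shape out_features × in_features so that row o holds the weights feeding output o. The helpers `weight_index`, `rows`, and `sigma` are hypothetical illustrations, not part of the crate's API.

```rust
/// Hypothetical helper: flat index of the weight from input `i` to output `o`,
/// assuming a row-major (out_features x in_features) layout.
fn weight_index(o: usize, i: usize, in_features: usize) -> usize {
    o * in_features + i
}

/// Hypothetical helper: view a flat weight vector as one slice per output row.
/// Assumes `in_features > 0`, since `chunks(0)` panics.
fn rows(flat: &[f64], in_features: usize) -> Vec<&[f64]> {
    flat.chunks(in_features).collect()
}

/// Standard deviation recovered from a stored log-std: sigma = exp(log_sigma).
fn sigma(log_sigma: f64) -> f64 {
    log_sigma.exp()
}
```

Storing log-sigma rather than sigma keeps every standard deviation positive without constraining the optimizer.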
Fields
in_features: usize - Number of input features
out_features: usize - Number of output features
w_mu: Vec<f64> - Weight posterior means, length out_features * in_features
w_log_sigma: Vec<f64> - Weight posterior log-standard-deviations, length out_features * in_features
b_mu: Vec<f64> - Bias posterior means, length out_features
b_log_sigma: Vec<f64> - Bias posterior log-standard-deviations, length out_features
prior_std: f64 - Prior standard deviation
Implementations
impl BayesianLinear
pub fn new(
    in_features: usize,
    out_features: usize,
    prior_std: f64,
) -> Result<Self, StatsError>
Create a new BayesianLinear layer.
Weight means (w_mu) are initialized from N(0, 0.1), and the log-sigma parameters are initialized to -3.0, i.e. sigma = exp(-3.0) ≈ 0.05, a tight initial posterior.
Arguments
in_features - Input dimensionality
out_features - Output dimensionality
prior_std - Standard deviation of the weight prior
Errors
Returns an error if in_features or out_features is zero.
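The initialization described above might look roughly like this. It is not the crate's constructor: `LayerSketch` and `init_layer` are hypothetical, the 0.1 in N(0, 0.1) is assumed to be a standard deviation, bias means are assumed to start at 0, and the standard-normal source is passed in as a closure to mirror the `rng` argument of forward_sample.

```rust
/// Hypothetical stand-in with the same fields as BayesianLinear.
#[derive(Clone, Debug)]
struct LayerSketch {
    in_features: usize,
    out_features: usize,
    w_mu: Vec<f64>,
    w_log_sigma: Vec<f64>,
    b_mu: Vec<f64>,
    b_log_sigma: Vec<f64>,
    prior_std: f64,
}

/// Sketch of the documented initialization.
/// Assumptions: 0.1 is a standard deviation; bias means start at zero.
fn init_layer(
    in_features: usize,
    out_features: usize,
    prior_std: f64,
    normal: &mut impl FnMut() -> f64,
) -> Result<LayerSketch, String> {
    if in_features == 0 || out_features == 0 {
        return Err("in_features and out_features must be non-zero".to_string());
    }
    let n = out_features * in_features;
    Ok(LayerSketch {
        in_features,
        out_features,
        // w_mu ~ N(0, 0.1^2): scale a standard-normal draw by 0.1
        w_mu: (0..n).map(|_| 0.1 * normal()).collect(),
        // log-sigma = -3.0, so sigma = exp(-3.0) ≈ 0.05
        w_log_sigma: vec![-3.0; n],
        b_mu: vec![0.0; out_features],
        b_log_sigma: vec![-3.0; out_features],
        prior_std,
    })
}
```

Starting with a small sigma keeps early forward passes close to a deterministic network, which tends to make the first optimization steps less noisy.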
pub fn forward_sample(
    &self,
    x: &[f64],
    rng: &mut impl FnMut() -> f64,
) -> Result<Vec<f64>, StatsError>
Forward pass with sampled weights (reparameterization trick).
Samples w = w_mu + eps * exp(w_log_sigma) for each weight and b = b_mu + eps * exp(b_log_sigma) for each bias, with an independent eps ~ N(0, 1) drawn from rng per parameter, then computes the matrix-vector product y = W x + b.
Arguments
x - Input vector of length in_features
rng - Closure producing standard normal samples N(0, 1)
Returns
Output vector of length out_features
Errors
Returns an error if x.len() != in_features.
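A self-contained sketch of the reparameterized pass. It works on bare slices instead of `&self` and returns `String` errors in place of `StatsError`; it assumes the row-major layout w_mu[o * in_features + i] and one fresh eps per parameter. `make_normal` is a hypothetical xorshift64 plus Box-Muller source for the `rng` closure, adequate for illustration but not a vetted random number generator.

```rust
use std::f64::consts::PI;

/// Hypothetical N(0, 1) source: xorshift64 uniforms fed through Box-Muller.
fn make_normal(seed: u64) -> impl FnMut() -> f64 {
    let mut state = seed.max(1); // xorshift state must be non-zero
    move || {
        let mut uniform = || {
            state ^= state << 13;
            state ^= state >> 7;
            state ^= state << 17;
            // Top 53 bits mapped into (0, 1]; never 0, so ln() stays finite.
            ((state >> 11) as f64 + 1.0) / (1u64 << 53) as f64
        };
        let u1 = uniform();
        let u2 = uniform();
        (-2.0 * u1.ln()).sqrt() * (2.0 * PI * u2).cos()
    }
}

/// Sketch of the reparameterized forward pass on flat row-major parameters.
/// Parameter slice lengths are assumed consistent and are not re-checked.
fn forward_sample(
    w_mu: &[f64],
    w_log_sigma: &[f64],
    b_mu: &[f64],
    b_log_sigma: &[f64],
    in_features: usize,
    x: &[f64],
    rng: &mut impl FnMut() -> f64,
) -> Result<Vec<f64>, String> {
    if x.len() != in_features {
        return Err(format!("expected input of length {in_features}, got {}", x.len()));
    }
    let out_features = b_mu.len();
    let mut y = Vec::with_capacity(out_features);
    for o in 0..out_features {
        // b = b_mu + eps * exp(b_log_sigma)
        let mut acc = b_mu[o] + rng() * b_log_sigma[o].exp();
        for (i, &xi) in x.iter().enumerate() {
            let k = o * in_features + i;
            // w = w_mu + eps * exp(w_log_sigma), fresh eps for every weight
            acc += (w_mu[k] + rng() * w_log_sigma[k].exp()) * xi;
        }
        y.push(acc);
    }
    Ok(y)
}
```

Because the noise enters only through eps, the sampled output is a differentiable function of w_mu and w_log_sigma, which is what lets gradients flow to the variational parameters. Passing a closure that always returns 0.0 reduces the pass to the mean-weight computation, which is convenient for testing.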
pub fn forward_mean(&self, x: &[f64]) -> Result<Vec<f64>, StatsError>
Forward pass using the posterior means (w_mu, b_mu) instead of sampled weights.
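Assuming forward_mean evaluates the layer at the posterior means (the name and the neighbouring forward_sample suggest this, but it is an inference), a slice-based sketch could look like this. The row-major layout and the `String` error type are again assumptions.

```rust
/// Sketch: deterministic forward pass y = W_mu x + b_mu, assuming the
/// row-major layout w_mu[o * in_features + i].
fn forward_mean(
    w_mu: &[f64],
    b_mu: &[f64],
    in_features: usize,
    x: &[f64],
) -> Result<Vec<f64>, String> {
    if in_features == 0 || x.len() != in_features {
        return Err(format!("expected input of length {in_features}, got {}", x.len()));
    }
    if w_mu.len() != b_mu.len() * in_features {
        return Err("w_mu length must be out_features * in_features".to_string());
    }
    Ok(w_mu
        .chunks(in_features) // one row of weights per output feature
        .zip(b_mu)
        .map(|(row, &b)| row.iter().zip(x).map(|(w, xi)| w * xi).sum::<f64>() + b)
        .collect())
}
```

For a single linear layer, this output equals the expected output of the sampled pass, because y = W x + b is linear in W and b. Once nonlinear layers are stacked on top, the mean pass is only an approximation to the predictive mean.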
pub fn kl_divergence(&self, prior_std: f64) -> f64
Compute the KL divergence KL[q(w) || p(w)] for all weights and biases.
For q = N(mu, sigma^2) with sigma = exp(log_sigma), and p = N(0, prior_std^2):
KL = -0.5 * sum(1 + 2*log_sigma - log(prior_std^2) - mu^2/prior_std^2 - sigma^2/prior_std^2)
where the sum runs over every weight and bias.
Arguments
prior_std - Prior standard deviation; may differ from the prior_std passed to new
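The closed form above translates directly into a per-vector helper. `kl_gaussian` is a hypothetical name; under this sketch the layer total would be its value on (w_mu, w_log_sigma) plus its value on (b_mu, b_log_sigma). A quick sanity check is that the divergence vanishes when q equals the prior.

```rust
/// KL[N(mu, sigma^2) || N(0, prior_std^2)] summed over all entries,
/// with sigma = exp(log_sigma), following the documented formula.
fn kl_gaussian(mu: &[f64], log_sigma: &[f64], prior_std: f64) -> f64 {
    assert_eq!(mu.len(), log_sigma.len(), "mu and log_sigma must have equal length");
    let p2 = prior_std * prior_std;
    -0.5 * mu
        .iter()
        .zip(log_sigma)
        .map(|(&m, &ls)| {
            let s2 = (2.0 * ls).exp(); // sigma^2 = exp(2 * log_sigma)
            1.0 + 2.0 * ls - p2.ln() - m * m / p2 - s2 / p2
        })
        .sum::<f64>()
}
```

Each term equals ln(prior_std / sigma) + (sigma^2 + mu^2) / (2 * prior_std^2) - 0.5, the usual closed form for two univariate Gaussians, so it is always non-negative and is zero only when mu = 0 and sigma = prior_std.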
pub fn update(
    &mut self,
    grad_w_mu: &[f64],
    grad_w_log_sigma: &[f64],
    grad_b_mu: &[f64],
    grad_b_log_sigma: &[f64],
    lr: f64,
) -> Result<(), StatsError>
Apply one plain SGD step, param -= lr * grad, to each of the variational parameters.
Arguments
grad_w_mu - Gradient of the loss w.r.t. w_mu, length out_features * in_features
grad_w_log_sigma - Gradient of the loss w.r.t. w_log_sigma, length out_features * in_features
grad_b_mu - Gradient of the loss w.r.t. b_mu, length out_features
grad_b_log_sigma - Gradient of the loss w.r.t. b_log_sigma, length out_features
lr - Learning rate
Errors
Returns an error if any gradient's length does not match the length of its parameter vector.
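A sketch of the step, with all lengths validated before any parameter is touched so that a rejected call leaves the parameters unchanged. Whether the crate's update validates up front in this way is not stated, and `VariationalParams` is a hypothetical stand-in for the layer's four parameter vectors.

```rust
/// Hypothetical stand-in for the four variational parameter vectors.
struct VariationalParams {
    w_mu: Vec<f64>,
    w_log_sigma: Vec<f64>,
    b_mu: Vec<f64>,
    b_log_sigma: Vec<f64>,
}

impl VariationalParams {
    /// One plain SGD step: param -= lr * grad for every parameter vector.
    fn update(
        &mut self,
        grad_w_mu: &[f64],
        grad_w_log_sigma: &[f64],
        grad_b_mu: &[f64],
        grad_b_log_sigma: &[f64],
        lr: f64,
    ) -> Result<(), String> {
        // Check every length first so a bad call mutates nothing.
        let checks = [
            ("w_mu", self.w_mu.len(), grad_w_mu.len()),
            ("w_log_sigma", self.w_log_sigma.len(), grad_w_log_sigma.len()),
            ("b_mu", self.b_mu.len(), grad_b_mu.len()),
            ("b_log_sigma", self.b_log_sigma.len(), grad_b_log_sigma.len()),
        ];
        for (name, want, got) in checks {
            if want != got {
                return Err(format!("grad_{name}: expected length {want}, got {got}"));
            }
        }
        // Disjoint field borrows let all four vectors be updated in one loop.
        for (param, grad) in [
            (&mut self.w_mu, grad_w_mu),
            (&mut self.w_log_sigma, grad_w_log_sigma),
            (&mut self.b_mu, grad_b_mu),
            (&mut self.b_log_sigma, grad_b_log_sigma),
        ] {
            for (p, g) in param.iter_mut().zip(grad) {
                *p -= lr * g;
            }
        }
        Ok(())
    }
}
```

Updating log-sigma rather than sigma means an SGD step can never drive a standard deviation negative.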
Trait Implementations
impl Clone for BayesianLinear
fn clone(&self) -> BayesianLinear
fn clone_from(&mut self, source: &Self)
Performs copy-assignment from source.
Auto Trait Implementations
impl Freeze for BayesianLinear
impl RefUnwindSafe for BayesianLinear
impl Send for BayesianLinear
impl Sync for BayesianLinear
impl Unpin for BayesianLinear
impl UnsafeUnpin for BayesianLinear
impl UnwindSafe for BayesianLinear
Blanket Implementations
impl<T> BorrowMut<T> for T where T: ?Sized
fn borrow_mut(&mut self) -> &mut T
impl<ST, DT> CastableFrom<ST, Initialized, Initialized> for DT
impl<ST, DT> CastableFrom<ST, Uninit, Uninit> for DT
impl<T> CloneToUninit for T where T: Clone
impl<T> Instrument for T
fn instrument(self, span: Span) -> Instrumented<Self>
fn in_current_span(self) -> Instrumented<Self>
impl<T> IntoEither for T
fn into_either(self, into_left: bool) -> Either<Self, Self>
Converts self into a Left variant of Either<Self, Self> if into_left is true. Converts self into a Right variant of Either<Self, Self> otherwise.
fn into_either_with<F>(self, into_left: F) -> Either<Self, Self>
Converts self into a Left variant of Either<Self, Self> if into_left(&self) returns true. Converts self into a Right variant of Either<Self, Self> otherwise.
impl<T> Pointable for T
impl<T> Read<Exclusive, BecauseExclusive> for T where T: ?Sized
impl<SS, SP> SupersetOf<SS> for SP where SS: SubsetOf<SP>
fn to_subset(&self) -> Option<SS>
Attempts to construct self from the equivalent element of its superset.
fn is_in_subset(&self) -> bool
Checks if self is actually part of its subset T (and can be converted to it).
fn to_subset_unchecked(&self) -> SS
Same as self.to_subset but without any property checks. Always succeeds.
fn from_subset(element: &SS) -> SP
Converts self to the equivalent element of its superset.