ruvector_attention/info_bottleneck/kl_divergence.rs

//! KL Divergence Computations
//!
//! Efficient KL divergence for various distributions used in attention.

/// Diagonal Gaussian parameters
#[derive(Debug, Clone)]
pub struct DiagonalGaussian {
    /// Mean vector
    pub mean: Vec<f32>,
    /// Log variance vector
    pub log_var: Vec<f32>,
}

impl DiagonalGaussian {
    /// Create from mean and log variance
    pub fn new(mean: Vec<f32>, log_var: Vec<f32>) -> Self {
        Self { mean, log_var }
    }

    /// Create unit Gaussian (mean=0, var=1)
    pub fn unit(dim: usize) -> Self {
        Self {
            mean: vec![0.0; dim],
            log_var: vec![0.0; dim],
        }
    }

    /// Sample using the reparameterization trick:
    /// z = mean + std * epsilon, where epsilon ~ N(0, 1)
    pub fn sample(&self, epsilon: &[f32]) -> Vec<f32> {
        let n = self.mean.len();
        let mut z = vec![0.0f32; n];

        for i in 0..n {
            let std = (0.5 * self.log_var[i]).exp();
            z[i] = self.mean[i] + std * epsilon[i];
        }

        z
    }

    /// Get variance
    pub fn variance(&self) -> Vec<f32> {
        self.log_var.iter().map(|&lv| lv.exp()).collect()
    }

    /// Get standard deviation
    pub fn std(&self) -> Vec<f32> {
        self.log_var.iter().map(|&lv| (0.5 * lv).exp()).collect()
    }
}

/// KL Divergence computations
#[derive(Debug, Clone)]
pub struct KLDivergence;

impl KLDivergence {
    /// KL(N(mu, sigma^2) || N(0, 1))
    /// = 0.5 * sum(exp(log_var) + mu^2 - 1 - log_var)
    pub fn gaussian_to_unit(gaussian: &DiagonalGaussian) -> f32 {
        let n = gaussian.mean.len();
        let mut kl = 0.0f32;

        for i in 0..n {
            let mu = gaussian.mean[i];
            let lv = gaussian.log_var[i];
            let var = lv.exp();
            kl += var + mu * mu - 1.0 - lv;
        }

        0.5 * kl
    }

    /// KL(N(mu, sigma^2) || N(0, 1)) from separate arrays
    pub fn gaussian_to_unit_arrays(mean: &[f32], log_var: &[f32]) -> f32 {
        let n = mean.len().min(log_var.len());
        let mut kl = 0.0f32;

        for i in 0..n {
            let mu = mean[i];
            let lv = log_var[i];
            let var = lv.exp();
            kl += var + mu * mu - 1.0 - lv;
        }

        0.5 * kl
    }

    /// KL(N(mu1, sigma1^2) || N(mu2, sigma2^2))
    /// = 0.5 * sum(log(var2/var1) + (var1 + (mu1-mu2)^2)/var2 - 1)
    pub fn gaussian_to_gaussian(q: &DiagonalGaussian, p: &DiagonalGaussian) -> f32 {
        let n = q.mean.len().min(p.mean.len());
        let mut kl = 0.0f32;

        for i in 0..n {
            let mu_q = q.mean[i];
            let mu_p = p.mean[i];
            let lv_q = q.log_var[i];
            let lv_p = p.log_var[i];

            let var_q = lv_q.exp();
            let var_p = lv_p.exp().max(1e-8);

            let log_ratio = lv_p - lv_q;
            let diff = mu_q - mu_p;

            kl += log_ratio + (var_q + diff * diff) / var_p - 1.0;
        }

        0.5 * kl
    }

    /// KL divergence between categorical distributions
    /// KL(p || q) = sum(p * log(p/q))
    pub fn categorical(p: &[f32], q: &[f32]) -> f32 {
        let n = p.len().min(q.len());
        let mut kl = 0.0f32;
        let eps = 1e-10;

        for i in 0..n {
            let pi = p[i].max(eps);
            let qi = q[i].max(eps);
            if pi > eps {
                kl += pi * (pi / qi).ln();
            }
        }

        kl.max(0.0)
    }

    /// Jensen-Shannon divergence (a symmetrized, bounded alternative to KL)
    /// JS(p, q) = 0.5 * (KL(p || m) + KL(q || m)) where m = (p+q)/2
    pub fn jensen_shannon(p: &[f32], q: &[f32]) -> f32 {
        let n = p.len().min(q.len());
        let mut m = vec![0.0f32; n];

        for i in 0..n {
            m[i] = 0.5 * (p[i] + q[i]);
        }

        0.5 * (Self::categorical(p, &m) + Self::categorical(q, &m))
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_kl_to_unit() {
        // Unit Gaussian should have KL = 0
        let unit = DiagonalGaussian::unit(4);
        let kl = KLDivergence::gaussian_to_unit(&unit);
        assert!(kl.abs() < 1e-5);
    }
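
    // Worked value: for mean = 1 and log_var = 0 in one dimension,
    // KL(N(1, 1) || N(0, 1)) = 0.5 * (1 + 1 - 1 - 0) = 0.5.
    #[test]
    fn test_kl_to_unit_closed_form() {
        let g = DiagonalGaussian::new(vec![1.0], vec![0.0]);
        let kl = KLDivergence::gaussian_to_unit(&g);
        assert!((kl - 0.5).abs() < 1e-5);
    }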

    #[test]
    fn test_kl_nonzero() {
        let g = DiagonalGaussian::new(
            vec![1.0, 0.5, -0.5],
            vec![0.5, 0.0, -0.5],
        );
        let kl = KLDivergence::gaussian_to_unit(&g);
        assert!(kl > 0.0);
    }
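
    // Consistency sketch: measuring against an explicit unit-Gaussian prior via
    // gaussian_to_gaussian should agree with the specialized gaussian_to_unit.
    #[test]
    fn test_kl_general_matches_unit() {
        let g = DiagonalGaussian::new(vec![1.0, 0.5, -0.5], vec![0.5, 0.0, -0.5]);
        let unit = DiagonalGaussian::unit(3);
        let specialized = KLDivergence::gaussian_to_unit(&g);
        let general = KLDivergence::gaussian_to_gaussian(&g, &unit);
        assert!((specialized - general).abs() < 1e-4);
    }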

    #[test]
    fn test_kl_arrays() {
        let mean = vec![0.0, 0.0];
        let log_var = vec![0.0, 0.0];

        let kl = KLDivergence::gaussian_to_unit_arrays(&mean, &log_var);
        assert!(kl.abs() < 1e-5);
    }
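
    // Consistency sketch: the array-based entry point should match the
    // struct-based one for the same parameters.
    #[test]
    fn test_kl_arrays_matches_struct() {
        let mean = vec![0.3, -0.7, 1.2];
        let log_var = vec![0.1, -0.2, 0.4];
        let g = DiagonalGaussian::new(mean.clone(), log_var.clone());
        let from_arrays = KLDivergence::gaussian_to_unit_arrays(&mean, &log_var);
        let from_struct = KLDivergence::gaussian_to_unit(&g);
        assert!((from_arrays - from_struct).abs() < 1e-5);
    }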

    #[test]
    fn test_categorical_kl() {
        let p = vec![0.5, 0.5];
        let q = vec![0.5, 0.5];

        let kl = KLDivergence::categorical(&p, &q);
        assert!(kl.abs() < 1e-5);

        let q2 = vec![0.9, 0.1];
        let kl2 = KLDivergence::categorical(&p, &q2);
        assert!(kl2 > 0.0);
    }
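
    // Worked value: KL([1, 0] || [0.5, 0.5]) = 1 * ln(1 / 0.5) = ln 2.
    #[test]
    fn test_categorical_kl_closed_form() {
        let p = vec![1.0, 0.0];
        let q = vec![0.5, 0.5];
        let kl = KLDivergence::categorical(&p, &q);
        assert!((kl - std::f32::consts::LN_2).abs() < 1e-5);
    }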

    #[test]
    fn test_jensen_shannon() {
        let p = vec![0.5, 0.5];
        let q = vec![0.5, 0.5];

        let js = KLDivergence::jensen_shannon(&p, &q);
        assert!(js.abs() < 1e-5);
    }
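
    // Property sketch: unlike KL, the Jensen-Shannon divergence is symmetric
    // in its arguments and positive for distinct distributions.
    #[test]
    fn test_jensen_shannon_symmetric() {
        let p = vec![0.8, 0.2];
        let q = vec![0.3, 0.7];
        let js_pq = KLDivergence::jensen_shannon(&p, &q);
        let js_qp = KLDivergence::jensen_shannon(&q, &p);
        assert!(js_pq > 0.0);
        assert!((js_pq - js_qp).abs() < 1e-6);
    }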

    #[test]
    fn test_sample() {
        let g = DiagonalGaussian::new(
            vec![0.0, 1.0],
            vec![0.0, 0.0],
        );
        let epsilon = vec![0.0, 0.0];

        let z = g.sample(&epsilon);
        assert!((z[0] - 0.0).abs() < 1e-5);
        assert!((z[1] - 1.0).abs() < 1e-5);
    }
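
    // Sanity sketch: the KL divergence of a Gaussian from itself should be zero.
    #[test]
    fn test_kl_self_is_zero() {
        let g = DiagonalGaussian::new(vec![0.4, -1.1], vec![0.2, -0.3]);
        let kl = KLDivergence::gaussian_to_gaussian(&g, &g);
        assert!(kl.abs() < 1e-5);
    }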
}