ruvector_attention/info_bottleneck/kl_divergence.rs

//! KL Divergence Computations
//!
//! Efficient KL divergence for various distributions used in attention.
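//!
//! A minimal usage sketch (illustrative only; the crate path in this
//! doc-test is assumed from the file's location and may need adjusting):
//!
//! ```
//! # use ruvector_attention::info_bottleneck::kl_divergence::{DiagonalGaussian, KLDivergence};
//! let q = DiagonalGaussian::new(vec![0.5, -0.5], vec![0.0, 0.0]);
//! assert!(KLDivergence::gaussian_to_unit(&q) > 0.0);
//! ```
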
/// Diagonal Gaussian parameters
#[derive(Debug, Clone)]
pub struct DiagonalGaussian {
    /// Mean vector
    pub mean: Vec<f32>,
    /// Log variance vector
    pub log_var: Vec<f32>,
}

impl DiagonalGaussian {
    /// Create from mean and log variance
    pub fn new(mean: Vec<f32>, log_var: Vec<f32>) -> Self {
        Self { mean, log_var }
    }

    /// Create unit Gaussian (mean=0, var=1)
    pub fn unit(dim: usize) -> Self {
        Self {
            mean: vec![0.0; dim],
            log_var: vec![0.0; dim],
        }
    }

    /// Sample using the reparameterization trick:
    /// z = mean + std * epsilon, where epsilon ~ N(0, 1)
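    ///
    /// A minimal doc-test sketch (epsilon is fixed rather than drawn from
    /// N(0, 1), and the crate path is assumed from this file's location):
    /// ```
    /// # use ruvector_attention::info_bottleneck::kl_divergence::DiagonalGaussian;
    /// let g = DiagonalGaussian::new(vec![1.0], vec![0.0]); // mean = 1, std = exp(0) = 1
    /// // z = 1.0 + 1.0 * 2.0 = 3.0
    /// assert!((g.sample(&[2.0])[0] - 3.0).abs() < 1e-6);
    /// ```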
    pub fn sample(&self, epsilon: &[f32]) -> Vec<f32> {
        let n = self.mean.len();
        let mut z = vec![0.0f32; n];

        for i in 0..n {
            let std = (0.5 * self.log_var[i]).exp();
            z[i] = self.mean[i] + std * epsilon[i];
        }

        z
    }

    /// Get variance
    pub fn variance(&self) -> Vec<f32> {
        self.log_var.iter().map(|&lv| lv.exp()).collect()
    }

    /// Get standard deviation
    pub fn std(&self) -> Vec<f32> {
        self.log_var.iter().map(|&lv| (0.5 * lv).exp()).collect()
    }
}

/// KL Divergence computations
#[derive(Debug, Clone)]
pub struct KLDivergence;

impl KLDivergence {
    /// KL(N(mu, sigma^2) || N(0, 1))
    /// = 0.5 * sum(exp(log_var) + mu^2 - 1 - log_var)
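    ///
    /// A minimal doc-test sketch (crate path assumed from the file's
    /// location); with mean = 1 and log_var = 0, each dimension contributes
    /// 0.5 * (1 + 1 - 1 - 0) = 0.5:
    /// ```
    /// # use ruvector_attention::info_bottleneck::kl_divergence::{DiagonalGaussian, KLDivergence};
    /// let g = DiagonalGaussian::new(vec![1.0], vec![0.0]);
    /// assert!((KLDivergence::gaussian_to_unit(&g) - 0.5).abs() < 1e-6);
    /// ```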
    pub fn gaussian_to_unit(gaussian: &DiagonalGaussian) -> f32 {
        let n = gaussian.mean.len();
        let mut kl = 0.0f32;

        for i in 0..n {
            let mu = gaussian.mean[i];
            let lv = gaussian.log_var[i];
            let var = lv.exp();
            kl += var + mu * mu - 1.0 - lv;
        }

        0.5 * kl
    }

    /// KL(N(mu, sigma^2) || N(0, 1)) from separate arrays
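    ///
    /// A minimal doc-test sketch (crate path assumed), mirroring the case
    /// above: a standard normal has zero KL to the unit Gaussian.
    /// ```
    /// # use ruvector_attention::info_bottleneck::kl_divergence::KLDivergence;
    /// let kl = KLDivergence::gaussian_to_unit_arrays(&[0.0, 0.0], &[0.0, 0.0]);
    /// assert!(kl.abs() < 1e-6);
    /// ```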
    pub fn gaussian_to_unit_arrays(mean: &[f32], log_var: &[f32]) -> f32 {
        let n = mean.len().min(log_var.len());
        let mut kl = 0.0f32;

        for i in 0..n {
            let mu = mean[i];
            let lv = log_var[i];
            let var = lv.exp();
            kl += var + mu * mu - 1.0 - lv;
        }

        0.5 * kl
    }

    /// KL(N(mu1, sigma1^2) || N(mu2, sigma2^2))
    /// = 0.5 * sum(log(var2/var1) + (var1 + (mu1-mu2)^2)/var2 - 1)
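    ///
    /// A minimal doc-test sketch (crate path assumed); for q = N(1, 1) and
    /// p = N(0, 1) the per-dimension sum is log(1) + (1 + 1)/1 - 1 = 1,
    /// so KL = 0.5:
    /// ```
    /// # use ruvector_attention::info_bottleneck::kl_divergence::{DiagonalGaussian, KLDivergence};
    /// let q = DiagonalGaussian::new(vec![1.0], vec![0.0]);
    /// let p = DiagonalGaussian::unit(1);
    /// assert!((KLDivergence::gaussian_to_gaussian(&q, &p) - 0.5).abs() < 1e-6);
    /// ```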
    pub fn gaussian_to_gaussian(q: &DiagonalGaussian, p: &DiagonalGaussian) -> f32 {
        let n = q.mean.len().min(p.mean.len());
        let mut kl = 0.0f32;

        for i in 0..n {
            let mu_q = q.mean[i];
            let mu_p = p.mean[i];
            let lv_q = q.log_var[i];
            let lv_p = p.log_var[i];

            let var_q = lv_q.exp();
            let var_p = lv_p.exp().max(1e-8);

            let log_ratio = lv_p - lv_q;
            let diff = mu_q - mu_p;

            kl += log_ratio + (var_q + diff * diff) / var_p - 1.0;
        }

        0.5 * kl
    }

    /// KL divergence between categorical distributions
    /// KL(p || q) = sum(p * log(p/q))
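    ///
    /// A minimal doc-test sketch (crate path assumed);
    /// KL([1, 0] || [0.5, 0.5]) = 1 * ln(1 / 0.5) = ln 2 nats:
    /// ```
    /// # use ruvector_attention::info_bottleneck::kl_divergence::KLDivergence;
    /// let kl = KLDivergence::categorical(&[1.0, 0.0], &[0.5, 0.5]);
    /// assert!((kl - std::f32::consts::LN_2).abs() < 1e-5);
    /// ```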
    pub fn categorical(p: &[f32], q: &[f32]) -> f32 {
        let n = p.len().min(q.len());
        let mut kl = 0.0f32;
        let eps = 1e-10;

        for i in 0..n {
            let pi = p[i];
            // Terms with (numerically) zero p contribute 0 * log(0) -> 0.
            if pi > eps {
                let qi = q[i].max(eps);
                kl += pi * (pi / qi).ln();
            }
        }

        kl.max(0.0)
    }

    /// Jensen-Shannon divergence, a symmetrized and bounded variant of KL:
    /// JS(p, q) = 0.5 * (KL(p || m) + KL(q || m)) where m = (p+q)/2
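    ///
    /// A minimal doc-test sketch (crate path assumed); for disjoint
    /// distributions JS reaches its maximum of ln 2 nats:
    /// ```
    /// # use ruvector_attention::info_bottleneck::kl_divergence::KLDivergence;
    /// let js = KLDivergence::jensen_shannon(&[1.0, 0.0], &[0.0, 1.0]);
    /// assert!((js - std::f32::consts::LN_2).abs() < 1e-5);
    /// ```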
    pub fn jensen_shannon(p: &[f32], q: &[f32]) -> f32 {
        let n = p.len().min(q.len());
        let mut m = vec![0.0f32; n];

        for i in 0..n {
            m[i] = 0.5 * (p[i] + q[i]);
        }

        0.5 * (Self::categorical(p, &m) + Self::categorical(q, &m))
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_kl_to_unit() {
        // Unit Gaussian should have KL = 0
        let unit = DiagonalGaussian::unit(4);
        let kl = KLDivergence::gaussian_to_unit(&unit);
        assert!(kl.abs() < 1e-5);
    }

    #[test]
    fn test_kl_nonzero() {
        let g = DiagonalGaussian::new(vec![1.0, 0.5, -0.5], vec![0.5, 0.0, -0.5]);
        let kl = KLDivergence::gaussian_to_unit(&g);
        assert!(kl > 0.0);
    }

    #[test]
    fn test_kl_arrays() {
        let mean = vec![0.0, 0.0];
        let log_var = vec![0.0, 0.0];

        let kl = KLDivergence::gaussian_to_unit_arrays(&mean, &log_var);
        assert!(kl.abs() < 1e-5);
    }

    #[test]
    fn test_categorical_kl() {
        let p = vec![0.5, 0.5];
        let q = vec![0.5, 0.5];

        let kl = KLDivergence::categorical(&p, &q);
        assert!(kl.abs() < 1e-5);

        let q2 = vec![0.9, 0.1];
        let kl2 = KLDivergence::categorical(&p, &q2);
        assert!(kl2 > 0.0);
    }

    #[test]
    fn test_jensen_shannon() {
        let p = vec![0.5, 0.5];
        let q = vec![0.5, 0.5];

        let js = KLDivergence::jensen_shannon(&p, &q);
        assert!(js.abs() < 1e-5);
    }

    #[test]
    fn test_sample() {
        let g = DiagonalGaussian::new(vec![0.0, 1.0], vec![0.0, 0.0]);
        let epsilon = vec![0.0, 0.0];

        let z = g.sample(&epsilon);
        assert!((z[0] - 0.0).abs() < 1e-5);
        assert!((z[1] - 1.0).abs() < 1e-5);
    }
}