// ruvector_attention/info_bottleneck/kl_divergence.rs

/// Diagonal Gaussian distribution parameterized by a mean vector and a
/// per-dimension log-variance vector; storing the log-variance keeps the
/// variance positive without constrained optimization.
#[derive(Debug, Clone)]
pub struct DiagonalGaussian {
    pub mean: Vec<f32>,
    pub log_var: Vec<f32>,
}

impl DiagonalGaussian {
    pub fn new(mean: Vec<f32>, log_var: Vec<f32>) -> Self {
        debug_assert_eq!(mean.len(), log_var.len());
        Self { mean, log_var }
    }

    /// Standard normal N(0, I): zero mean and unit variance (log-variance 0).
    pub fn unit(dim: usize) -> Self {
        Self {
            mean: vec![0.0; dim],
            log_var: vec![0.0; dim],
        }
    }

    /// Reparameterized sample: z_i = mu_i + sigma_i * epsilon_i, where
    /// `epsilon` is caller-supplied standard-normal noise. Keeping the noise
    /// external makes the sample differentiable in `mean` and `log_var`.
    pub fn sample(&self, epsilon: &[f32]) -> Vec<f32> {
        let n = self.mean.len();
        debug_assert!(epsilon.len() >= n, "epsilon must cover every dimension");
        let mut z = vec![0.0f32; n];

        for i in 0..n {
            // sigma_i = exp(log_var_i / 2)
            let std = (0.5 * self.log_var[i]).exp();
            z[i] = self.mean[i] + std * epsilon[i];
        }

        z
    }

    /// Per-dimension variance: exp(log_var).
    pub fn variance(&self) -> Vec<f32> {
        self.log_var.iter().map(|&lv| lv.exp()).collect()
    }

    /// Per-dimension standard deviation: exp(log_var / 2).
    pub fn std(&self) -> Vec<f32> {
        self.log_var.iter().map(|&lv| (0.5 * lv).exp()).collect()
    }
}

#[derive(Debug, Clone)]
pub struct KLDivergence;

impl KLDivergence {
    /// KL(q || N(0, I)) for a diagonal Gaussian q, in nats:
    /// 0.5 * sum_i (var_i + mu_i^2 - 1 - log var_i).
    pub fn gaussian_to_unit(gaussian: &DiagonalGaussian) -> f32 {
        let n = gaussian.mean.len();
        let mut kl = 0.0f32;

        for i in 0..n {
            let mu = gaussian.mean[i];
            let lv = gaussian.log_var[i];
            let var = lv.exp();
            kl += var + mu * mu - 1.0 - lv;
        }

        0.5 * kl
    }

    /// Same as `gaussian_to_unit`, but over raw mean / log-variance slices.
    /// Truncates to the shorter slice instead of panicking on a length
    /// mismatch.
    pub fn gaussian_to_unit_arrays(mean: &[f32], log_var: &[f32]) -> f32 {
        let n = mean.len().min(log_var.len());
        let mut kl = 0.0f32;

        for i in 0..n {
            let mu = mean[i];
            let lv = log_var[i];
            let var = lv.exp();
            kl += var + mu * mu - 1.0 - lv;
        }

        0.5 * kl
    }

    /// KL(q || p) between two diagonal Gaussians, in nats:
    /// 0.5 * sum_i (log(var_p_i / var_q_i) + (var_q_i + (mu_q_i - mu_p_i)^2) / var_p_i - 1).
    pub fn gaussian_to_gaussian(q: &DiagonalGaussian, p: &DiagonalGaussian) -> f32 {
        let n = q.mean.len().min(p.mean.len());
        let mut kl = 0.0f32;

        for i in 0..n {
            let mu_q = q.mean[i];
            let mu_p = p.mean[i];
            let lv_q = q.log_var[i];
            let lv_p = p.log_var[i];

            let var_q = lv_q.exp();
            // Clamp the denominator to guard against a vanishing variance.
            let var_p = lv_p.exp().max(1e-8);

            let log_ratio = lv_p - lv_q;
            let diff = mu_q - mu_p;

            kl += log_ratio + (var_q + diff * diff) / var_p - 1.0;
        }

        0.5 * kl
    }

    /// KL(p || q) between two categorical distributions, in nats.
    /// Probabilities are floored at `eps` for numerical stability, and terms
    /// with effectively zero p_i are skipped (0 * ln 0 = 0 by convention).
    pub fn categorical(p: &[f32], q: &[f32]) -> f32 {
        let n = p.len().min(q.len());
        let mut kl = 0.0f32;
        let eps = 1e-10;

        for i in 0..n {
            let pi = p[i].max(eps);
            let qi = q[i].max(eps);
            if pi > eps {
                kl += pi * (pi / qi).ln();
            }
        }

        kl.max(0.0)
    }

    /// Jensen-Shannon divergence: 0.5 * KL(p || m) + 0.5 * KL(q || m) with
    /// m = (p + q) / 2. Symmetric in p and q and bounded above by ln 2.
    pub fn jensen_shannon(p: &[f32], q: &[f32]) -> f32 {
        let n = p.len().min(q.len());
        let mut m = vec![0.0f32; n];

        for i in 0..n {
            m[i] = 0.5 * (p[i] + q[i]);
        }

        0.5 * (Self::categorical(p, &m) + Self::categorical(q, &m))
    }
}
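
// Illustrative sketch (hypothetical helper, not part of this module's API):
// how the pieces above typically combine into a variational information
// bottleneck rate term. The caller supplies encoder outputs, standard-normal
// noise, and a trade-off weight `beta`; all names here are assumptions.
#[allow(dead_code)]
fn ib_rate_term_sketch(
    mean: Vec<f32>,
    log_var: Vec<f32>,
    epsilon: &[f32],
    beta: f32,
) -> (Vec<f32>, f32) {
    let q = DiagonalGaussian::new(mean, log_var);
    let z = q.sample(epsilon); // reparameterized latent for the decoder
    let rate = KLDivergence::gaussian_to_unit(&q); // compression cost in nats
    (z, beta * rate)
}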

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_kl_to_unit() {
        let unit = DiagonalGaussian::unit(4);
        let kl = KLDivergence::gaussian_to_unit(&unit);
        assert!(kl.abs() < 1e-5);
    }

    #[test]
    fn test_kl_nonzero() {
        let g = DiagonalGaussian::new(vec![1.0, 0.5, -0.5], vec![0.5, 0.0, -0.5]);
        let kl = KLDivergence::gaussian_to_unit(&g);
        assert!(kl > 0.0);
    }

    #[test]
    fn test_kl_arrays() {
        let mean = vec![0.0, 0.0];
        let log_var = vec![0.0, 0.0];

        let kl = KLDivergence::gaussian_to_unit_arrays(&mean, &log_var);
        assert!(kl.abs() < 1e-5);
    }

    #[test]
    fn test_categorical_kl() {
        let p = vec![0.5, 0.5];
        let q = vec![0.5, 0.5];

        let kl = KLDivergence::categorical(&p, &q);
        assert!(kl.abs() < 1e-5);

        let q2 = vec![0.9, 0.1];
        let kl2 = KLDivergence::categorical(&p, &q2);
        assert!(kl2 > 0.0);
    }

    #[test]
    fn test_jensen_shannon() {
        let p = vec![0.5, 0.5];
        let q = vec![0.5, 0.5];

        let js = KLDivergence::jensen_shannon(&p, &q);
        assert!(js.abs() < 1e-5);
    }
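
    // KL(q || q) must vanish; for unit variances, KL(N(0,I) || N(mu,I))
    // reduces to 0.5 * ||mu||^2, which is exactly 1.0 for mu = (1, 1).
    #[test]
    fn test_gaussian_to_gaussian() {
        let q = DiagonalGaussian::new(vec![0.3, -0.2], vec![0.1, -0.4]);
        assert!(KLDivergence::gaussian_to_gaussian(&q, &q).abs() < 1e-5);

        let unit = DiagonalGaussian::unit(2);
        let p = DiagonalGaussian::new(vec![1.0, 1.0], vec![0.0, 0.0]);
        let kl = KLDivergence::gaussian_to_gaussian(&unit, &p);
        assert!((kl - 1.0).abs() < 1e-5);
    }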

    #[test]
    fn test_sample() {
        let g = DiagonalGaussian::new(vec![0.0, 1.0], vec![0.0, 0.0]);
        let epsilon = vec![0.0, 0.0];

        let z = g.sample(&epsilon);
        assert!((z[0] - 0.0).abs() < 1e-5);
        assert!((z[1] - 1.0).abs() < 1e-5);
    }
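
    // std() and variance() should agree as sqrt/square pairs: log_var = ln 2
    // gives variance 2 and standard deviation sqrt(2).
    #[test]
    fn test_std_variance_consistency() {
        let g = DiagonalGaussian::new(vec![0.0], vec![2.0f32.ln()]);
        assert!((g.variance()[0] - 2.0).abs() < 1e-5);
        assert!((g.std()[0] - 2.0f32.sqrt()).abs() < 1e-5);
    }

    // Jensen-Shannon divergence attains its upper bound of ln 2 for
    // distributions with disjoint support.
    #[test]
    fn test_jensen_shannon_bound() {
        let p = vec![1.0, 0.0];
        let q = vec![0.0, 1.0];
        let js = KLDivergence::jensen_shannon(&p, &q);
        assert!((js - std::f32::consts::LN_2).abs() < 1e-4);
    }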
}