ruvector_attention/info_bottleneck/kl_divergence.rs

/// Diagonal-covariance Gaussian parameterized per dimension by a mean and a
/// log-variance (`log_var[i] = ln(sigma_i^2)`).
#[derive(Debug, Clone)]
pub struct DiagonalGaussian {
    pub mean: Vec<f32>,
    pub log_var: Vec<f32>,
}

impl DiagonalGaussian {
    /// Build a Gaussian from matching-length mean and log-variance vectors.
    pub fn new(mean: Vec<f32>, log_var: Vec<f32>) -> Self {
        debug_assert_eq!(mean.len(), log_var.len());
        Self { mean, log_var }
    }

    /// Standard normal N(0, I): zero mean, unit variance (log-variance 0).
    pub fn unit(dim: usize) -> Self {
        Self {
            mean: vec![0.0; dim],
            log_var: vec![0.0; dim],
        }
    }

    /// Reparameterized sample z = mu + sigma * epsilon, where `epsilon` is
    /// caller-supplied N(0, 1) noise with at least `mean.len()` entries.
    pub fn sample(&self, epsilon: &[f32]) -> Vec<f32> {
        let n = self.mean.len();
        let mut z = vec![0.0f32; n];

        for i in 0..n {
            // sigma_i = exp(log_var[i] / 2)
            let std = (0.5 * self.log_var[i]).exp();
            z[i] = self.mean[i] + std * epsilon[i];
        }

        z
    }

    /// Per-dimension variance sigma_i^2 = exp(log_var[i]).
    pub fn variance(&self) -> Vec<f32> {
        self.log_var.iter().map(|&lv| lv.exp()).collect()
    }

    /// Per-dimension standard deviation sigma_i = exp(log_var[i] / 2).
    pub fn std(&self) -> Vec<f32> {
        self.log_var.iter().map(|&lv| (0.5 * lv).exp()).collect()
    }
}

/// Closed-form divergences used in the information-bottleneck objective.
/// All results are in nats.
#[derive(Debug, Clone)]
pub struct KLDivergence;

impl KLDivergence {
    /// KL(N(mu, diag(sigma^2)) || N(0, I)) in closed form:
    ///   0.5 * sum_i (sigma_i^2 + mu_i^2 - 1 - ln sigma_i^2).
    pub fn gaussian_to_unit(gaussian: &DiagonalGaussian) -> f32 {
        let n = gaussian.mean.len().min(gaussian.log_var.len());
        let mut kl = 0.0f32;

        for i in 0..n {
            let mu = gaussian.mean[i];
            let lv = gaussian.log_var[i];
            let var = lv.exp();
            kl += var + mu * mu - 1.0 - lv;
        }

        0.5 * kl
    }

    /// Same closed form as `gaussian_to_unit`, taken over raw slices so
    /// callers can avoid constructing a `DiagonalGaussian`.
    pub fn gaussian_to_unit_arrays(mean: &[f32], log_var: &[f32]) -> f32 {
        let n = mean.len().min(log_var.len());
        let mut kl = 0.0f32;

        for i in 0..n {
            let mu = mean[i];
            let lv = log_var[i];
            let var = lv.exp();
            kl += var + mu * mu - 1.0 - lv;
        }

        0.5 * kl
    }

    /// KL(q || p) between two diagonal Gaussians:
    ///   0.5 * sum_i (ln(var_p / var_q) + (var_q + (mu_q - mu_p)^2) / var_p - 1).
    pub fn gaussian_to_gaussian(q: &DiagonalGaussian, p: &DiagonalGaussian) -> f32 {
        let n = q.mean.len().min(p.mean.len());
        let mut kl = 0.0f32;

        for i in 0..n {
            let mu_q = q.mean[i];
            let mu_p = p.mean[i];
            let lv_q = q.log_var[i];
            let lv_p = p.log_var[i];

            let var_q = lv_q.exp();
            // Clamp the reference variance away from zero so the ratio stays finite.
            let var_p = lv_p.exp().max(1e-8);

            let log_ratio = lv_p - lv_q;
            let diff = mu_q - mu_p;

            kl += log_ratio + (var_q + diff * diff) / var_p - 1.0;
        }

        0.5 * kl
    }

    /// KL(p || q) = sum_i p_i ln(p_i / q_i) for categorical distributions.
    /// Inputs are assumed (approximately) normalized; probabilities are
    /// clamped to `eps` so the log stays finite, and terms with p_i ~ 0 are
    /// skipped since they vanish in the limit p_i -> 0.
    pub fn categorical(p: &[f32], q: &[f32]) -> f32 {
        let n = p.len().min(q.len());
        let mut kl = 0.0f32;
        let eps = 1e-10;

        for i in 0..n {
            let pi = p[i].max(eps);
            let qi = q[i].max(eps);
            if pi > eps {
                kl += pi * (pi / qi).ln();
            }
        }

        // Guard against tiny negative results from floating-point rounding.
        kl.max(0.0)
    }

    /// Jensen-Shannon divergence: 0.5 * KL(p || m) + 0.5 * KL(q || m) with
    /// m = (p + q) / 2. Symmetric in its arguments and bounded by ln 2 nats.
    pub fn jensen_shannon(p: &[f32], q: &[f32]) -> f32 {
        let n = p.len().min(q.len());
        let mut m = vec![0.0f32; n];

        for i in 0..n {
            m[i] = 0.5 * (p[i] + q[i]);
        }

        0.5 * (Self::categorical(p, &m) + Self::categorical(q, &m))
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_kl_to_unit() {
        let unit = DiagonalGaussian::unit(4);
        let kl = KLDivergence::gaussian_to_unit(&unit);
        assert!(kl.abs() < 1e-5);
    }
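
    // A hand-checkable value: for mu = 1 and log_var = 0 in one dimension,
    // KL to the unit Gaussian is 0.5 * (1 + 1 - 1 - 0) = 0.5 exactly.
    #[test]
    fn test_kl_to_unit_analytic() {
        let g = DiagonalGaussian::new(vec![1.0], vec![0.0]);
        let kl = KLDivergence::gaussian_to_unit(&g);
        assert!((kl - 0.5).abs() < 1e-5);
    }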

    #[test]
    fn test_kl_nonzero() {
        let g = DiagonalGaussian::new(
            vec![1.0, 0.5, -0.5],
            vec![0.5, 0.0, -0.5],
        );
        let kl = KLDivergence::gaussian_to_unit(&g);
        assert!(kl > 0.0);
    }
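
    // Consistency check: measuring against an explicit unit Gaussian through
    // `gaussian_to_gaussian` should reproduce the specialized closed form.
    #[test]
    fn test_kl_to_gaussian_matches_unit_form() {
        let g = DiagonalGaussian::new(vec![1.0, 0.5, -0.5], vec![0.5, 0.0, -0.5]);
        let unit = DiagonalGaussian::unit(3);

        let direct = KLDivergence::gaussian_to_unit(&g);
        let general = KLDivergence::gaussian_to_gaussian(&g, &unit);
        assert!((direct - general).abs() < 1e-4);
    }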

    #[test]
    fn test_kl_arrays() {
        let mean = vec![0.0, 0.0];
        let log_var = vec![0.0, 0.0];

        let kl = KLDivergence::gaussian_to_unit_arrays(&mean, &log_var);
        assert!(kl.abs() < 1e-5);
    }
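
    // The slice-based helper should agree with the struct-based entry point
    // on identical parameters.
    #[test]
    fn test_arrays_match_struct() {
        let mean = vec![0.3, -1.2];
        let log_var = vec![0.2, -0.4];
        let g = DiagonalGaussian::new(mean.clone(), log_var.clone());

        let a = KLDivergence::gaussian_to_unit_arrays(&mean, &log_var);
        let b = KLDivergence::gaussian_to_unit(&g);
        assert!((a - b).abs() < 1e-6);
    }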

    #[test]
    fn test_categorical_kl() {
        let p = vec![0.5, 0.5];
        let q = vec![0.5, 0.5];

        let kl = KLDivergence::categorical(&p, &q);
        assert!(kl.abs() < 1e-5);

        let q2 = vec![0.9, 0.1];
        let kl2 = KLDivergence::categorical(&p, &q2);
        assert!(kl2 > 0.0);
    }
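
    // KL divergence is asymmetric in general: KL(p || q) and KL(q || p)
    // differ whenever q is skewed relative to p.
    #[test]
    fn test_categorical_asymmetry() {
        let p = vec![0.5, 0.5];
        let q = vec![0.9, 0.1];

        let forward = KLDivergence::categorical(&p, &q);
        let reverse = KLDivergence::categorical(&q, &p);
        assert!((forward - reverse).abs() > 1e-3);
    }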

    #[test]
    fn test_jensen_shannon() {
        let p = vec![0.5, 0.5];
        let q = vec![0.5, 0.5];

        let js = KLDivergence::jensen_shannon(&p, &q);
        assert!(js.abs() < 1e-5);
    }
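
    // For fully disjoint distributions the Jensen-Shannon divergence reaches
    // its maximum of ln 2 nats.
    #[test]
    fn test_jensen_shannon_disjoint() {
        let p = vec![1.0, 0.0];
        let q = vec![0.0, 1.0];

        let js = KLDivergence::jensen_shannon(&p, &q);
        assert!((js - std::f32::consts::LN_2).abs() < 1e-4);
    }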

    #[test]
    fn test_sample() {
        let g = DiagonalGaussian::new(
            vec![0.0, 1.0],
            vec![0.0, 0.0],
        );
        let epsilon = vec![0.0, 0.0];

        let z = g.sample(&epsilon);
        assert!((z[0] - 0.0).abs() < 1e-5);
        assert!((z[1] - 1.0).abs() < 1e-5);
    }
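
    // Reparameterization scaling: with log_var = ln(4) the standard deviation
    // is 2, so epsilon = 1 should land two units above the mean.
    #[test]
    fn test_sample_scales_with_std() {
        let g = DiagonalGaussian::new(vec![1.0], vec![4.0f32.ln()]);
        let z = g.sample(&[1.0]);
        assert!((z[0] - 3.0).abs() < 1e-5);
    }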
}