// ruvector_math/information_geometry/fisher.rs

use crate::error::{MathError, Result};
use crate::utils::EPS;
/// Estimator configuration for Fisher information matrices (FIM).
///
/// Built via [`FisherInformation::new`] and the `with_*` builder methods.
#[derive(Debug, Clone)]
pub struct FisherInformation {
    // Tikhonov damping added to the FIM diagonal to keep it positive definite.
    damping: f64,
    // Sample budget for sampling-based estimators; stored for future use,
    // not read by the closed-form methods in this file.
    num_samples: usize,
}
30
31impl FisherInformation {
32 pub fn new() -> Self {
34 Self {
35 damping: 1e-4,
36 num_samples: 100,
37 }
38 }
39
40 pub fn with_damping(mut self, damping: f64) -> Self {
42 self.damping = damping.max(EPS);
43 self
44 }
45
46 pub fn with_samples(mut self, num_samples: usize) -> Self {
48 self.num_samples = num_samples.max(1);
49 self
50 }
51
52 pub fn empirical_fim(&self, gradients: &[Vec<f64>]) -> Result<Vec<Vec<f64>>> {
59 if gradients.is_empty() {
60 return Err(MathError::empty_input("gradients"));
61 }
62
63 let d = gradients[0].len();
64 if d == 0 {
65 return Err(MathError::empty_input("gradient dimension"));
66 }
67
68 let n = gradients.len() as f64;
69
70 let mut fim = vec![vec![0.0; d]; d];
72
73 for grad in gradients {
74 if grad.len() != d {
75 return Err(MathError::dimension_mismatch(d, grad.len()));
76 }
77
78 for i in 0..d {
79 for j in 0..d {
80 fim[i][j] += grad[i] * grad[j] / n;
81 }
82 }
83 }
84
85 for i in 0..d {
87 fim[i][i] += self.damping;
88 }
89
90 Ok(fim)
91 }
92
93 pub fn diagonal_fim(&self, gradients: &[Vec<f64>]) -> Result<Vec<f64>> {
97 if gradients.is_empty() {
98 return Err(MathError::empty_input("gradients"));
99 }
100
101 let d = gradients[0].len();
102 let n = gradients.len() as f64;
103
104 let mut diag = vec![0.0; d];
105
106 for grad in gradients {
107 if grad.len() != d {
108 return Err(MathError::dimension_mismatch(d, grad.len()));
109 }
110
111 for (i, &g) in grad.iter().enumerate() {
112 diag[i] += g * g / n;
113 }
114 }
115
116 for d_i in &mut diag {
118 *d_i += self.damping;
119 }
120
121 Ok(diag)
122 }
123
124 pub fn gaussian_fim(&self, dim: usize, variance: f64) -> Vec<Vec<f64>> {
128 let scale = 1.0 / (variance + self.damping);
129 let mut fim = vec![vec![0.0; dim]; dim];
130 for i in 0..dim {
131 fim[i][i] = scale;
132 }
133 fim
134 }
135
136 pub fn categorical_fim(&self, probabilities: &[f64]) -> Result<Vec<Vec<f64>>> {
140 let k = probabilities.len();
141 if k == 0 {
142 return Err(MathError::empty_input("probabilities"));
143 }
144
145 let mut fim = vec![vec![-1.0; k]; k]; for (i, &pi) in probabilities.iter().enumerate() {
148 let safe_pi = pi.max(EPS);
149 fim[i][i] = 1.0 / safe_pi - 1.0 + self.damping;
150 }
151
152 Ok(fim)
153 }
154
155 pub fn invert_fim(&self, fim: &[Vec<f64>]) -> Result<Vec<Vec<f64>>> {
159 let n = fim.len();
160 if n == 0 {
161 return Err(MathError::empty_input("FIM"));
162 }
163
164 let mut l = vec![vec![0.0; n]; n];
166
167 for i in 0..n {
168 for j in 0..=i {
169 let mut sum = fim[i][j];
170
171 for k in 0..j {
172 sum -= l[i][k] * l[j][k];
173 }
174
175 if i == j {
176 if sum <= 0.0 {
177 return Err(MathError::numerical_instability(
179 "FIM not positive definite",
180 ));
181 }
182 l[i][j] = sum.sqrt();
183 } else {
184 l[i][j] = sum / l[j][j];
185 }
186 }
187 }
188
189 let mut l_inv = vec![vec![0.0; n]; n];
191 for i in 0..n {
192 l_inv[i][i] = 1.0 / l[i][i];
193 for j in (i + 1)..n {
194 let mut sum = 0.0;
195 for k in i..j {
196 sum -= l[j][k] * l_inv[k][i];
197 }
198 l_inv[j][i] = sum / l[j][j];
199 }
200 }
201
202 let mut fim_inv = vec![vec![0.0; n]; n];
204 for i in 0..n {
205 for j in 0..n {
206 for k in 0..n {
207 fim_inv[i][j] += l_inv[k][i] * l_inv[k][j];
208 }
209 }
210 }
211
212 Ok(fim_inv)
213 }
214
215 pub fn natural_gradient(
217 &self,
218 fim: &[Vec<f64>],
219 gradient: &[f64],
220 ) -> Result<Vec<f64>> {
221 let fim_inv = self.invert_fim(fim)?;
222 let n = gradient.len();
223
224 if fim_inv.len() != n {
225 return Err(MathError::dimension_mismatch(n, fim_inv.len()));
226 }
227
228 let mut nat_grad = vec![0.0; n];
229 for i in 0..n {
230 for j in 0..n {
231 nat_grad[i] += fim_inv[i][j] * gradient[j];
232 }
233 }
234
235 Ok(nat_grad)
236 }
237}
238
239impl Default for FisherInformation {
240 fn default() -> Self {
241 Self::new()
242 }
243}
244
#[cfg(test)]
mod tests {
    use super::*;

    const TOL: f64 = 1e-6;

    #[test]
    fn test_empirical_fim() {
        // Undamped so the entries match the plain (1/n) sum of outer products.
        let fisher = FisherInformation::new().with_damping(0.0);
        let samples = vec![vec![1.0, 0.0], vec![0.0, 1.0], vec![1.0, 1.0]];

        let fim = fisher.empirical_fim(&samples).unwrap();

        // Diagonal: (1 + 0 + 1)/3; off-diagonal: (0 + 0 + 1)/3.
        assert!((fim[0][0] - 2.0 / 3.0).abs() < TOL);
        assert!((fim[1][1] - 2.0 / 3.0).abs() < TOL);
        assert!((fim[0][1] - 1.0 / 3.0).abs() < TOL);
    }

    #[test]
    fn test_gaussian_fim() {
        let fim = FisherInformation::new()
            .with_damping(0.0)
            .gaussian_fim(3, 0.5);

        // Isotropic Gaussian with variance 0.5 => diagonal of 1/0.5 = 2.
        for i in 0..2 {
            assert!((fim[i][i] - 2.0).abs() < TOL);
        }
        assert!(fim[0][1].abs() < TOL);
    }

    #[test]
    fn test_fim_inversion() {
        let fisher = FisherInformation::new();
        let identity = vec![vec![1.0, 0.0], vec![0.0, 1.0]];

        // The identity is its own inverse.
        let inv = fisher.invert_fim(&identity).unwrap();
        assert!((inv[0][0] - 1.0).abs() < TOL);
        assert!((inv[1][1] - 1.0).abs() < TOL);
    }

    #[test]
    fn test_natural_gradient() {
        let fisher = FisherInformation::new().with_damping(0.0);

        // F = 2I, so the natural gradient is half the Euclidean gradient.
        let fim = vec![vec![2.0, 0.0], vec![0.0, 2.0]];
        let euclidean = vec![4.0, 6.0];

        let natural = fisher.natural_gradient(&fim, &euclidean).unwrap();
        assert!((natural[0] - 2.0).abs() < TOL);
        assert!((natural[1] - 3.0).abs() < TOL);
    }
}