// ruvector_math/information_geometry/fisher.rs

//! Fisher Information Matrix
//!
//! The Fisher Information Matrix (FIM) captures the curvature of the log-likelihood
//! surface and defines the natural metric on statistical manifolds.
//!
//! ## Definition
//!
//! F(θ) = E[∇log p(x|θ) ∇log p(x|θ)^T]
//!
//! For Gaussian distributions with fixed variance:
//! F(μ) = I/σ² (identity scaled by inverse variance)
//!
//! ## Use Cases
//!
//! - Natural gradient computation
//! - Information-theoretic regularization
//! - Model uncertainty quantification

use crate::error::{MathError, Result};
use crate::utils::EPS;

/// Fisher Information Matrix calculator.
///
/// Configured via builder-style setters (`with_damping`, `with_samples`);
/// all FIM computations live in the `impl` block below.
#[derive(Debug, Clone)]
pub struct FisherInformation {
    /// Damping factor for numerical stability; added to matrix diagonals
    /// before inversion. Clamped to at least `EPS` by `with_damping`.
    damping: f64,
    /// Number of samples for empirical estimation.
    /// NOTE(review): stored but never read by any method in this file —
    /// TODO confirm whether external callers rely on it.
    num_samples: usize,
}

31impl FisherInformation {
32    /// Create a new FIM calculator
33    pub fn new() -> Self {
34        Self {
35            damping: 1e-4,
36            num_samples: 100,
37        }
38    }
39
40    /// Set damping factor (for matrix inversion stability)
41    pub fn with_damping(mut self, damping: f64) -> Self {
42        self.damping = damping.max(EPS);
43        self
44    }
45
46    /// Set number of samples for empirical FIM
47    pub fn with_samples(mut self, num_samples: usize) -> Self {
48        self.num_samples = num_samples.max(1);
49        self
50    }
51
52    /// Compute empirical FIM from gradient samples
53    ///
54    /// F ≈ (1/N) Σᵢ ∇log p(xᵢ|θ) ∇log p(xᵢ|θ)^T
55    ///
56    /// # Arguments
57    /// * `gradients` - Sample gradients, each of length d
58    pub fn empirical_fim(&self, gradients: &[Vec<f64>]) -> Result<Vec<Vec<f64>>> {
59        if gradients.is_empty() {
60            return Err(MathError::empty_input("gradients"));
61        }
62
63        let d = gradients[0].len();
64        if d == 0 {
65            return Err(MathError::empty_input("gradient dimension"));
66        }
67
68        let n = gradients.len() as f64;
69
70        // F = (1/n) Σ g gᵀ
71        let mut fim = vec![vec![0.0; d]; d];
72
73        for grad in gradients {
74            if grad.len() != d {
75                return Err(MathError::dimension_mismatch(d, grad.len()));
76            }
77
78            for i in 0..d {
79                for j in 0..d {
80                    fim[i][j] += grad[i] * grad[j] / n;
81                }
82            }
83        }
84
85        // Add damping for stability
86        for i in 0..d {
87            fim[i][i] += self.damping;
88        }
89
90        Ok(fim)
91    }
92
93    /// Compute diagonal FIM approximation (much faster)
94    ///
95    /// Only computes diagonal: F_ii ≈ (1/N) Σₙ (∂log p / ∂θᵢ)²
96    pub fn diagonal_fim(&self, gradients: &[Vec<f64>]) -> Result<Vec<f64>> {
97        if gradients.is_empty() {
98            return Err(MathError::empty_input("gradients"));
99        }
100
101        let d = gradients[0].len();
102        let n = gradients.len() as f64;
103
104        let mut diag = vec![0.0; d];
105
106        for grad in gradients {
107            if grad.len() != d {
108                return Err(MathError::dimension_mismatch(d, grad.len()));
109            }
110
111            for (i, &g) in grad.iter().enumerate() {
112                diag[i] += g * g / n;
113            }
114        }
115
116        // Add damping
117        for d_i in &mut diag {
118            *d_i += self.damping;
119        }
120
121        Ok(diag)
122    }
123
124    /// Compute FIM for Gaussian distribution with known variance
125    ///
126    /// For N(μ, σ²I): F(μ) = I/σ²
127    pub fn gaussian_fim(&self, dim: usize, variance: f64) -> Vec<Vec<f64>> {
128        let scale = 1.0 / (variance + self.damping);
129        let mut fim = vec![vec![0.0; dim]; dim];
130        for i in 0..dim {
131            fim[i][i] = scale;
132        }
133        fim
134    }
135
136    /// Compute FIM for categorical distribution
137    ///
138    /// For categorical p = (p₁, ..., pₖ): F_ij = δᵢⱼ/pᵢ - 1
139    pub fn categorical_fim(&self, probabilities: &[f64]) -> Result<Vec<Vec<f64>>> {
140        let k = probabilities.len();
141        if k == 0 {
142            return Err(MathError::empty_input("probabilities"));
143        }
144
145        let mut fim = vec![vec![-1.0; k]; k]; // Off-diagonal = -1
146
147        for (i, &pi) in probabilities.iter().enumerate() {
148            let safe_pi = pi.max(EPS);
149            fim[i][i] = 1.0 / safe_pi - 1.0 + self.damping;
150        }
151
152        Ok(fim)
153    }
154
155    /// Invert FIM using Cholesky decomposition
156    ///
157    /// Returns F⁻¹ for natural gradient computation
158    pub fn invert_fim(&self, fim: &[Vec<f64>]) -> Result<Vec<Vec<f64>>> {
159        let n = fim.len();
160        if n == 0 {
161            return Err(MathError::empty_input("FIM"));
162        }
163
164        // Cholesky decomposition: F = LLᵀ
165        let mut l = vec![vec![0.0; n]; n];
166
167        for i in 0..n {
168            for j in 0..=i {
169                let mut sum = fim[i][j];
170
171                for k in 0..j {
172                    sum -= l[i][k] * l[j][k];
173                }
174
175                if i == j {
176                    if sum <= 0.0 {
177                        // Matrix not positive definite
178                        return Err(MathError::numerical_instability(
179                            "FIM not positive definite",
180                        ));
181                    }
182                    l[i][j] = sum.sqrt();
183                } else {
184                    l[i][j] = sum / l[j][j];
185                }
186            }
187        }
188
189        // Forward substitution to get L⁻¹
190        let mut l_inv = vec![vec![0.0; n]; n];
191        for i in 0..n {
192            l_inv[i][i] = 1.0 / l[i][i];
193            for j in (i + 1)..n {
194                let mut sum = 0.0;
195                for k in i..j {
196                    sum -= l[j][k] * l_inv[k][i];
197                }
198                l_inv[j][i] = sum / l[j][j];
199            }
200        }
201
202        // F⁻¹ = (LLᵀ)⁻¹ = L⁻ᵀ L⁻¹
203        let mut fim_inv = vec![vec![0.0; n]; n];
204        for i in 0..n {
205            for j in 0..n {
206                for k in 0..n {
207                    fim_inv[i][j] += l_inv[k][i] * l_inv[k][j];
208                }
209            }
210        }
211
212        Ok(fim_inv)
213    }
214
215    /// Compute natural gradient: F⁻¹ ∇L
216    pub fn natural_gradient(
217        &self,
218        fim: &[Vec<f64>],
219        gradient: &[f64],
220    ) -> Result<Vec<f64>> {
221        let fim_inv = self.invert_fim(fim)?;
222        let n = gradient.len();
223
224        if fim_inv.len() != n {
225            return Err(MathError::dimension_mismatch(n, fim_inv.len()));
226        }
227
228        let mut nat_grad = vec![0.0; n];
229        for i in 0..n {
230            for j in 0..n {
231                nat_grad[i] += fim_inv[i][j] * gradient[j];
232            }
233        }
234
235        Ok(nat_grad)
236    }
237}
238
239impl Default for FisherInformation {
240    fn default() -> Self {
241        Self::new()
242    }
243}
244
#[cfg(test)]
mod tests {
    use super::*;

    /// Assert two floats agree to within 1e-6.
    fn assert_close(actual: f64, expected: f64) {
        assert!(
            (actual - expected).abs() < 1e-6,
            "expected {expected}, got {actual}"
        );
    }

    #[test]
    fn test_empirical_fim() {
        let fisher = FisherInformation::new().with_damping(0.0);

        // Three two-dimensional sample gradients.
        let grads = vec![vec![1.0, 0.0], vec![0.0, 1.0], vec![1.0, 1.0]];
        let fim = fisher.empirical_fim(&grads).unwrap();

        // (1/3) Σ g gᵀ = [[2/3, 1/3], [1/3, 2/3]] plus EPS-sized damping.
        assert_close(fim[0][0], 2.0 / 3.0);
        assert_close(fim[1][1], 2.0 / 3.0);
        assert_close(fim[0][1], 1.0 / 3.0);
    }

    #[test]
    fn test_gaussian_fim() {
        let fisher = FisherInformation::new().with_damping(0.0);
        let fim = fisher.gaussian_fim(3, 0.5);

        // F = I / 0.5 = 2I (EPS-sized damping folded into the variance).
        assert_close(fim[0][0], 2.0);
        assert_close(fim[1][1], 2.0);
        assert_close(fim[0][1], 0.0);
    }

    #[test]
    fn test_fim_inversion() {
        let fisher = FisherInformation::new();

        // The identity matrix is its own inverse.
        let identity = vec![vec![1.0, 0.0], vec![0.0, 1.0]];
        let inv = fisher.invert_fim(&identity).unwrap();

        assert_close(inv[0][0], 1.0);
        assert_close(inv[1][1], 1.0);
    }

    #[test]
    fn test_natural_gradient() {
        let fisher = FisherInformation::new().with_damping(0.0);

        // With F = 2I, the natural gradient is half the raw gradient.
        let fim = vec![vec![2.0, 0.0], vec![0.0, 2.0]];
        let grad = vec![4.0, 6.0];
        let nat_grad = fisher.natural_gradient(&fim, &grad).unwrap();

        assert_close(nat_grad[0], 2.0);
        assert_close(nat_grad[1], 3.0);
    }
}