// ruvector_math/information_geometry/fisher.rs
//! Fisher Information Matrix
//!
//! The Fisher Information Matrix (FIM) captures the curvature of the log-likelihood
//! surface and defines the natural metric on statistical manifolds.
//!
//! ## Definition
//!
//! F(θ) = E[∇log p(x|θ) ∇log p(x|θ)^T]
//!
//! For Gaussian distributions with fixed variance:
//! F(μ) = I/σ² (identity scaled by inverse variance)
//!
//! ## Use Cases
//!
//! - Natural gradient computation
//! - Information-theoretic regularization
//! - Model uncertainty quantification

use crate::error::{MathError, Result};
use crate::utils::EPS;

/// Fisher Information Matrix calculator.
///
/// Construct with [`FisherInformation::new`] (or `Default`), then tune via the
/// builder-style `with_damping` / `with_samples` methods.
#[derive(Debug, Clone)]
pub struct FisherInformation {
    /// Damping factor added for numerical stability (kept >= EPS by `with_damping`)
    damping: f64,
    /// Number of samples for empirical estimation
    // NOTE(review): stored but not read by any method in this file — presumably
    // reserved for sampling-based estimators; confirm before removing.
    num_samples: usize,
}

31impl FisherInformation {
32    /// Create a new FIM calculator
33    pub fn new() -> Self {
34        Self {
35            damping: 1e-4,
36            num_samples: 100,
37        }
38    }
39
40    /// Set damping factor (for matrix inversion stability)
41    pub fn with_damping(mut self, damping: f64) -> Self {
42        self.damping = damping.max(EPS);
43        self
44    }
45
46    /// Set number of samples for empirical FIM
47    pub fn with_samples(mut self, num_samples: usize) -> Self {
48        self.num_samples = num_samples.max(1);
49        self
50    }
51
52    /// Compute empirical FIM from gradient samples
53    ///
54    /// F ≈ (1/N) Σᵢ ∇log p(xᵢ|θ) ∇log p(xᵢ|θ)^T
55    ///
56    /// # Arguments
57    /// * `gradients` - Sample gradients, each of length d
58    pub fn empirical_fim(&self, gradients: &[Vec<f64>]) -> Result<Vec<Vec<f64>>> {
59        if gradients.is_empty() {
60            return Err(MathError::empty_input("gradients"));
61        }
62
63        let d = gradients[0].len();
64        if d == 0 {
65            return Err(MathError::empty_input("gradient dimension"));
66        }
67
68        let n = gradients.len() as f64;
69
70        // F = (1/n) Σ g gᵀ
71        let mut fim = vec![vec![0.0; d]; d];
72
73        for grad in gradients {
74            if grad.len() != d {
75                return Err(MathError::dimension_mismatch(d, grad.len()));
76            }
77
78            for i in 0..d {
79                for j in 0..d {
80                    fim[i][j] += grad[i] * grad[j] / n;
81                }
82            }
83        }
84
85        // Add damping for stability
86        for i in 0..d {
87            fim[i][i] += self.damping;
88        }
89
90        Ok(fim)
91    }
92
93    /// Compute diagonal FIM approximation (much faster)
94    ///
95    /// Only computes diagonal: F_ii ≈ (1/N) Σₙ (∂log p / ∂θᵢ)²
96    pub fn diagonal_fim(&self, gradients: &[Vec<f64>]) -> Result<Vec<f64>> {
97        if gradients.is_empty() {
98            return Err(MathError::empty_input("gradients"));
99        }
100
101        let d = gradients[0].len();
102        let n = gradients.len() as f64;
103
104        let mut diag = vec![0.0; d];
105
106        for grad in gradients {
107            if grad.len() != d {
108                return Err(MathError::dimension_mismatch(d, grad.len()));
109            }
110
111            for (i, &g) in grad.iter().enumerate() {
112                diag[i] += g * g / n;
113            }
114        }
115
116        // Add damping
117        for d_i in &mut diag {
118            *d_i += self.damping;
119        }
120
121        Ok(diag)
122    }
123
124    /// Compute FIM for Gaussian distribution with known variance
125    ///
126    /// For N(μ, σ²I): F(μ) = I/σ²
127    pub fn gaussian_fim(&self, dim: usize, variance: f64) -> Vec<Vec<f64>> {
128        let scale = 1.0 / (variance + self.damping);
129        let mut fim = vec![vec![0.0; dim]; dim];
130        for i in 0..dim {
131            fim[i][i] = scale;
132        }
133        fim
134    }
135
136    /// Compute FIM for categorical distribution
137    ///
138    /// For categorical p = (p₁, ..., pₖ): F_ij = δᵢⱼ/pᵢ - 1
139    pub fn categorical_fim(&self, probabilities: &[f64]) -> Result<Vec<Vec<f64>>> {
140        let k = probabilities.len();
141        if k == 0 {
142            return Err(MathError::empty_input("probabilities"));
143        }
144
145        let mut fim = vec![vec![-1.0; k]; k]; // Off-diagonal = -1
146
147        for (i, &pi) in probabilities.iter().enumerate() {
148            let safe_pi = pi.max(EPS);
149            fim[i][i] = 1.0 / safe_pi - 1.0 + self.damping;
150        }
151
152        Ok(fim)
153    }
154
155    /// Invert FIM using Cholesky decomposition
156    ///
157    /// Returns F⁻¹ for natural gradient computation
158    pub fn invert_fim(&self, fim: &[Vec<f64>]) -> Result<Vec<Vec<f64>>> {
159        let n = fim.len();
160        if n == 0 {
161            return Err(MathError::empty_input("FIM"));
162        }
163
164        // Cholesky decomposition: F = LLᵀ
165        let mut l = vec![vec![0.0; n]; n];
166
167        for i in 0..n {
168            for j in 0..=i {
169                let mut sum = fim[i][j];
170
171                for k in 0..j {
172                    sum -= l[i][k] * l[j][k];
173                }
174
175                if i == j {
176                    if sum <= 0.0 {
177                        // Matrix not positive definite
178                        return Err(MathError::numerical_instability(
179                            "FIM not positive definite",
180                        ));
181                    }
182                    l[i][j] = sum.sqrt();
183                } else {
184                    l[i][j] = sum / l[j][j];
185                }
186            }
187        }
188
189        // Forward substitution to get L⁻¹
190        let mut l_inv = vec![vec![0.0; n]; n];
191        for i in 0..n {
192            l_inv[i][i] = 1.0 / l[i][i];
193            for j in (i + 1)..n {
194                let mut sum = 0.0;
195                for k in i..j {
196                    sum -= l[j][k] * l_inv[k][i];
197                }
198                l_inv[j][i] = sum / l[j][j];
199            }
200        }
201
202        // F⁻¹ = (LLᵀ)⁻¹ = L⁻ᵀ L⁻¹
203        let mut fim_inv = vec![vec![0.0; n]; n];
204        for i in 0..n {
205            for j in 0..n {
206                for k in 0..n {
207                    fim_inv[i][j] += l_inv[k][i] * l_inv[k][j];
208                }
209            }
210        }
211
212        Ok(fim_inv)
213    }
214
215    /// Compute natural gradient: F⁻¹ ∇L
216    pub fn natural_gradient(&self, fim: &[Vec<f64>], gradient: &[f64]) -> Result<Vec<f64>> {
217        let fim_inv = self.invert_fim(fim)?;
218        let n = gradient.len();
219
220        if fim_inv.len() != n {
221            return Err(MathError::dimension_mismatch(n, fim_inv.len()));
222        }
223
224        let mut nat_grad = vec![0.0; n];
225        for i in 0..n {
226            for j in 0..n {
227                nat_grad[i] += fim_inv[i][j] * gradient[j];
228            }
229        }
230
231        Ok(nat_grad)
232    }
233}
234
235impl Default for FisherInformation {
236    fn default() -> Self {
237        Self::new()
238    }
239}
240
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance shared by all float comparisons below.
    const TOL: f64 = 1e-6;

    #[test]
    fn test_empirical_fim() {
        let calc = FisherInformation::new().with_damping(0.0);

        // Three gradient samples in 2-D.
        let samples = vec![vec![1.0, 0.0], vec![0.0, 1.0], vec![1.0, 1.0]];

        let fim = calc.empirical_fim(&samples).unwrap();

        // (1/3) Σ g gᵀ = [[2/3, 1/3], [1/3, 2/3]] (damping clamps to EPS, negligible).
        let expected = [[2.0 / 3.0, 1.0 / 3.0], [1.0 / 3.0, 2.0 / 3.0]];
        for (row, want) in fim.iter().zip(expected.iter()) {
            for (&got, &exp) in row.iter().zip(want.iter()) {
                assert!((got - exp).abs() < TOL);
            }
        }
    }

    #[test]
    fn test_gaussian_fim() {
        let calc = FisherInformation::new().with_damping(0.0);
        let fim = calc.gaussian_fim(3, 0.5);

        // F = I / 0.5 = 2I (EPS-sized damping in the denominator is negligible).
        for (i, row) in fim.iter().enumerate() {
            for (j, &val) in row.iter().enumerate() {
                let want = if i == j { 2.0 } else { 0.0 };
                assert!((val - want).abs() < TOL);
            }
        }
    }

    #[test]
    fn test_fim_inversion() {
        let calc = FisherInformation::new();

        // The identity matrix is its own inverse.
        let eye = vec![vec![1.0, 0.0], vec![0.0, 1.0]];
        let inv = calc.invert_fim(&eye).unwrap();

        assert!((inv[0][0] - 1.0).abs() < TOL);
        assert!((inv[1][1] - 1.0).abs() < TOL);
    }

    #[test]
    fn test_natural_gradient() {
        let calc = FisherInformation::new().with_damping(0.0);

        // With F = 2I, the natural gradient is half the raw gradient.
        let fim = vec![vec![2.0, 0.0], vec![0.0, 2.0]];
        let raw = vec![4.0, 6.0];

        let nat = calc.natural_gradient(&fim, &raw).unwrap();

        for (&n_i, &g_i) in nat.iter().zip(raw.iter()) {
            assert!((n_i - g_i / 2.0).abs() < TOL);
        }
    }
}
299}