ruvector_attention/info_geometry/
fisher.rs

1//! Fisher Information Metric
2//!
3//! The Fisher metric on the probability simplex:
4//! F = diag(p) - p*p^T
5//!
6//! This gives the natural geometry for probability distributions.
7
8use serde::{Deserialize, Serialize};
9
10/// Fisher metric configuration
11#[derive(Debug, Clone, Serialize, Deserialize)]
12pub struct FisherConfig {
13    /// Regularization epsilon for numerical stability
14    pub eps: f32,
15    /// Maximum CG iterations
16    pub max_iters: usize,
17    /// Convergence threshold
18    pub tol: f32,
19}
20
21impl Default for FisherConfig {
22    fn default() -> Self {
23        Self {
24            eps: 1e-8,
25            max_iters: 10,
26            tol: 1e-6,
27        }
28    }
29}
30
31/// Fisher metric operations
32#[derive(Debug, Clone)]
33pub struct FisherMetric {
34    config: FisherConfig,
35}
36
37impl FisherMetric {
38    /// Create new Fisher metric
39    pub fn new(config: FisherConfig) -> Self {
40        Self { config }
41    }
42
43    /// Apply Fisher matrix to vector: F*v = diag(p)*v - p*(p^T*v)
44    /// This is O(n) instead of O(n^2)
45    #[inline]
46    pub fn apply(&self, probs: &[f32], v: &[f32]) -> Vec<f32> {
47        let n = probs.len().min(v.len()); // Security: bounds check
48
49        if n == 0 {
50            return vec![];
51        }
52
53        // Compute p^T * v
54        let pv = Self::dot_simd(probs, v);
55
56        // F*v = diag(p)*v - p*(p^T*v)
57        let mut result = vec![0.0f32; n];
58        for i in 0..n {
59            result[i] = probs[i] * v[i] - probs[i] * pv;
60        }
61
62        result
63    }
64
65    /// Apply inverse Fisher (approximately) using diagonal preconditioning
66    /// F^{-1} ≈ diag(1/p) for small perturbations
67    #[inline]
68    pub fn apply_inverse_approx(&self, probs: &[f32], v: &[f32]) -> Vec<f32> {
69        let n = probs.len().min(v.len()); // Security: bounds check
70
71        if n == 0 {
72            return vec![];
73        }
74
75        let mut result = vec![0.0f32; n];
76
77        for i in 0..n {
78            let p = probs[i].max(self.config.eps);
79            result[i] = v[i] / p;
80        }
81
82        // Project to sum-zero (tangent space of simplex)
83        let mean: f32 = result.iter().sum::<f32>() / n as f32;
84        for i in 0..n {
85            result[i] -= mean;
86        }
87
88        result
89    }
90
91    /// Solve F*x = b using conjugate gradient
92    /// Returns x such that probs[i]*x[i] - probs[i]*sum(probs[j]*x[j]) ≈ b[i]
93    pub fn solve_cg(&self, probs: &[f32], b: &[f32]) -> Vec<f32> {
94        let n = probs.len().min(b.len()); // Security: bounds check
95
96        if n == 0 {
97            return vec![];
98        }
99
100        // Project b to sum-zero (must be in tangent space)
101        let mut b_proj = b[..n].to_vec();
102        let b_mean: f32 = b_proj.iter().sum::<f32>() / n as f32;
103        for i in 0..n {
104            b_proj[i] -= b_mean;
105        }
106
107        // CG iteration
108        let mut x = vec![0.0f32; n];
109        let mut r = b_proj.clone();
110        let mut d = r.clone();
111
112        let mut rtr = Self::dot_simd(&r, &r);
113        if rtr < self.config.tol {
114            return x;
115        }
116
117        for _ in 0..self.config.max_iters {
118            let fd = self.apply(probs, &d);
119            let dfd = Self::dot_simd(&d, &fd).max(self.config.eps);
120            let alpha = rtr / dfd;
121
122            for i in 0..n {
123                x[i] += alpha * d[i];
124                r[i] -= alpha * fd[i];
125            }
126
127            let rtr_new = Self::dot_simd(&r, &r);
128            if rtr_new < self.config.tol {
129                break;
130            }
131
132            let beta = rtr_new / rtr.max(self.config.eps); // Security: prevent division by zero
133            for i in 0..n {
134                d[i] = r[i] + beta * d[i];
135            }
136
137            rtr = rtr_new;
138        }
139
140        x
141    }
142
143    /// Compute Fisher-Rao distance between two probability distributions
144    /// d_FR(p, q) = 2 * arccos(sum(sqrt(p_i * q_i)))
145    pub fn fisher_rao_distance(&self, p: &[f32], q: &[f32]) -> f32 {
146        let n = p.len().min(q.len());
147        let mut bhattacharyya = 0.0f32;
148
149        for i in 0..n {
150            let pi = p[i].max(self.config.eps);
151            let qi = q[i].max(self.config.eps);
152            bhattacharyya += (pi * qi).sqrt();
153        }
154
155        // Clamp for numerical stability
156        let cos_half = bhattacharyya.clamp(0.0, 1.0);
157        2.0 * cos_half.acos()
158    }
159
160    /// SIMD-friendly dot product
161    #[inline(always)]
162    fn dot_simd(a: &[f32], b: &[f32]) -> f32 {
163        let len = a.len().min(b.len());
164        let chunks = len / 4;
165        let remainder = len % 4;
166
167        let mut sum0 = 0.0f32;
168        let mut sum1 = 0.0f32;
169        let mut sum2 = 0.0f32;
170        let mut sum3 = 0.0f32;
171
172        for i in 0..chunks {
173            let base = i * 4;
174            sum0 += a[base] * b[base];
175            sum1 += a[base + 1] * b[base + 1];
176            sum2 += a[base + 2] * b[base + 2];
177            sum3 += a[base + 3] * b[base + 3];
178        }
179
180        let base = chunks * 4;
181        for i in 0..remainder {
182            sum0 += a[base + i] * b[base + i];
183        }
184
185        sum0 + sum1 + sum2 + sum3
186    }
187}
188
#[cfg(test)]
mod tests {
    use super::*;

    /// `F v` must match the closed form and lie in the simplex tangent space.
    #[test]
    fn test_fisher_apply() {
        let fisher = FisherMetric::new(FisherConfig::default());

        // Uniform distribution
        let p = vec![0.25, 0.25, 0.25, 0.25];
        let v = vec![1.0, 0.0, 0.0, -1.0];

        let fv = fisher.apply(&p, &v);

        // p·v = 0 here, so F*v reduces to p ∘ v.
        assert_eq!(fv, vec![0.25, 0.0, 0.0, -0.25]);

        // F*v should be in tangent space (sum to ~0)
        let sum: f32 = fv.iter().sum();
        assert!(sum.abs() < 1e-5);
    }

    /// CG must reproduce a sum-zero right-hand side to well below its scale.
    #[test]
    fn test_fisher_cg_solve() {
        let fisher = FisherMetric::new(FisherConfig::default());

        let p = vec![0.4, 0.3, 0.2, 0.1];
        let b = vec![0.1, -0.05, -0.02, -0.03]; // sum-zero

        let x = fisher.solve_cg(&p, &b);
        assert_eq!(x.len(), b.len());

        // F*x should approximately equal b. The tolerance is far below the
        // smallest |b_i| (0.02) so that a poor solution cannot pass.
        let fx = fisher.apply(&p, &x);

        for (fxi, bi) in fx.iter().zip(&b) {
            assert!((fxi - bi).abs() < 1e-3);
        }
    }

    /// Distance is zero for identical inputs, symmetric, and matches a known value.
    #[test]
    fn test_fisher_rao_distance() {
        let fisher = FisherMetric::new(FisherConfig::default());

        let p = vec![0.5, 0.5];
        let q = vec![0.5, 0.5];

        // Same distribution = 0 distance
        let d = fisher.fisher_rao_distance(&p, &q);
        assert!(d.abs() < 1e-5);

        // Different distributions
        let q2 = vec![0.9, 0.1];
        let d2 = fisher.fisher_rao_distance(&p, &q2);
        assert!(d2 > 0.0);

        // Σ sqrt(p_i q_i) = 2/√5 here, so the distance is 2·atan(1/2).
        assert!((d2 - 0.927_295).abs() < 1e-3);

        // Symmetric in its arguments.
        assert_eq!(d2, fisher.fisher_rao_distance(&q2, &p));
    }
}
241}