ruvector_attention/info_geometry/
fisher.rs1use serde::{Deserialize, Serialize};
9
10#[derive(Debug, Clone, Serialize, Deserialize)]
12pub struct FisherConfig {
13 pub eps: f32,
15 pub max_iters: usize,
17 pub tol: f32,
19}
20
21impl Default for FisherConfig {
22 fn default() -> Self {
23 Self {
24 eps: 1e-8,
25 max_iters: 10,
26 tol: 1e-6,
27 }
28 }
29}
30
31#[derive(Debug, Clone)]
33pub struct FisherMetric {
34 config: FisherConfig,
35}
36
37impl FisherMetric {
38 pub fn new(config: FisherConfig) -> Self {
40 Self { config }
41 }
42
43 #[inline]
46 pub fn apply(&self, probs: &[f32], v: &[f32]) -> Vec<f32> {
47 let n = probs.len().min(v.len()); if n == 0 {
50 return vec![];
51 }
52
53 let pv = Self::dot_simd(probs, v);
55
56 let mut result = vec![0.0f32; n];
58 for i in 0..n {
59 result[i] = probs[i] * v[i] - probs[i] * pv;
60 }
61
62 result
63 }
64
65 #[inline]
68 pub fn apply_inverse_approx(&self, probs: &[f32], v: &[f32]) -> Vec<f32> {
69 let n = probs.len().min(v.len()); if n == 0 {
72 return vec![];
73 }
74
75 let mut result = vec![0.0f32; n];
76
77 for i in 0..n {
78 let p = probs[i].max(self.config.eps);
79 result[i] = v[i] / p;
80 }
81
82 let mean: f32 = result.iter().sum::<f32>() / n as f32;
84 for i in 0..n {
85 result[i] -= mean;
86 }
87
88 result
89 }
90
91 pub fn solve_cg(&self, probs: &[f32], b: &[f32]) -> Vec<f32> {
94 let n = probs.len().min(b.len()); if n == 0 {
97 return vec![];
98 }
99
100 let mut b_proj = b[..n].to_vec();
102 let b_mean: f32 = b_proj.iter().sum::<f32>() / n as f32;
103 for i in 0..n {
104 b_proj[i] -= b_mean;
105 }
106
107 let mut x = vec![0.0f32; n];
109 let mut r = b_proj.clone();
110 let mut d = r.clone();
111
112 let mut rtr = Self::dot_simd(&r, &r);
113 if rtr < self.config.tol {
114 return x;
115 }
116
117 for _ in 0..self.config.max_iters {
118 let fd = self.apply(probs, &d);
119 let dfd = Self::dot_simd(&d, &fd).max(self.config.eps);
120 let alpha = rtr / dfd;
121
122 for i in 0..n {
123 x[i] += alpha * d[i];
124 r[i] -= alpha * fd[i];
125 }
126
127 let rtr_new = Self::dot_simd(&r, &r);
128 if rtr_new < self.config.tol {
129 break;
130 }
131
132 let beta = rtr_new / rtr.max(self.config.eps); for i in 0..n {
134 d[i] = r[i] + beta * d[i];
135 }
136
137 rtr = rtr_new;
138 }
139
140 x
141 }
142
143 pub fn fisher_rao_distance(&self, p: &[f32], q: &[f32]) -> f32 {
146 let n = p.len().min(q.len());
147 let mut bhattacharyya = 0.0f32;
148
149 for i in 0..n {
150 let pi = p[i].max(self.config.eps);
151 let qi = q[i].max(self.config.eps);
152 bhattacharyya += (pi * qi).sqrt();
153 }
154
155 let cos_half = bhattacharyya.clamp(0.0, 1.0);
157 2.0 * cos_half.acos()
158 }
159
160 #[inline(always)]
162 fn dot_simd(a: &[f32], b: &[f32]) -> f32 {
163 let len = a.len().min(b.len());
164 let chunks = len / 4;
165 let remainder = len % 4;
166
167 let mut sum0 = 0.0f32;
168 let mut sum1 = 0.0f32;
169 let mut sum2 = 0.0f32;
170 let mut sum3 = 0.0f32;
171
172 for i in 0..chunks {
173 let base = i * 4;
174 sum0 += a[base] * b[base];
175 sum1 += a[base + 1] * b[base + 1];
176 sum2 += a[base + 2] * b[base + 2];
177 sum3 += a[base + 3] * b[base + 3];
178 }
179
180 let base = chunks * 4;
181 for i in 0..remainder {
182 sum0 += a[base + i] * b[base + i];
183 }
184
185 sum0 + sum1 + sum2 + sum3
186 }
187}
188
#[cfg(test)]
mod tests {
    use super::*;

    /// `F·v` has one entry per input entry and its entries sum to zero.
    #[test]
    fn test_fisher_apply() {
        let fisher = FisherMetric::new(FisherConfig::default());

        let p = vec![0.25, 0.25, 0.25, 0.25];
        let v = vec![1.0, 0.0, 0.0, -1.0];

        let fv = fisher.apply(&p, &v);
        assert_eq!(fv.len(), p.len());

        let sum: f32 = fv.iter().sum();
        assert!(sum.abs() < 1e-5);

        // For uniform p and zero-mean v, F·v reduces to p_i * v_i.
        let expected = [0.25f32, 0.0, 0.0, -0.25];
        for (got, want) in fv.iter().zip(expected.iter()) {
            assert!((got - want).abs() < 1e-6);
        }
    }

    /// The CG solution must reproduce a zero-mean right-hand side.
    #[test]
    fn test_fisher_cg_solve() {
        let fisher = FisherMetric::new(FisherConfig::default());

        let p = vec![0.4, 0.3, 0.2, 0.1];
        // Sums to zero, so centring inside `solve_cg` leaves it unchanged.
        let b = vec![0.1, -0.05, -0.02, -0.03];

        let x = fisher.solve_cg(&p, &b);
        let fx = fisher.apply(&p, &x);
        assert_eq!(fx.len(), b.len());

        // The tolerance is well below the magnitude of the entries of `b`,
        // so a solver that returns (nearly) zero cannot pass.
        for (got, want) in fx.iter().zip(b.iter()) {
            assert!((got - want).abs() < 1e-2);
        }
    }

    /// Identical distributions are at distance zero; distinct ones are not,
    /// and the distance does not depend on argument order.
    #[test]
    fn test_fisher_rao_distance() {
        let fisher = FisherMetric::new(FisherConfig::default());

        let p = vec![0.5, 0.5];
        let q = vec![0.5, 0.5];

        let d = fisher.fisher_rao_distance(&p, &q);
        assert!(d.abs() < 1e-5);

        let q2 = vec![0.9, 0.1];
        let d2 = fisher.fisher_rao_distance(&p, &q2);
        assert!(d2 > 0.0);

        let d2_rev = fisher.fisher_rao_distance(&q2, &p);
        assert!((d2 - d2_rev).abs() < 1e-6);
    }
}