// ruvector_math/information_geometry/kfac.rs
use crate::error::{MathError, Result};
use crate::utils::EPS;

/// Per-layer state for the K-FAC curvature approximation: the two
/// Kronecker factors of the layer's Fisher block, plus the damping and
/// averaging hyperparameters used when updating/inverting them.
#[derive(Debug, Clone)]
pub struct KFACLayer {
    /// Input-side factor A (input_dim × input_dim): batch average of
    /// activation outer products a·aᵀ, EMA-smoothed across updates.
    pub a_factor: Vec<Vec<f64>>,
    /// Output-side factor G (output_dim × output_dim): batch average of
    /// output-gradient outer products g·gᵀ, EMA-smoothed across updates.
    pub g_factor: Vec<Vec<f64>>,
    // Tikhonov damping strength added (trace-scaled) to each factor's
    // diagonal before inversion.
    damping: f64,
    // Decay for the exponential moving average of the factors, in [0, 1].
    ema_factor: f64,
    // Number of completed `update` calls; 0 means factors are still the
    // zero initialization.
    num_updates: usize,
}
41
42impl KFACLayer {
43 pub fn new(input_dim: usize, output_dim: usize) -> Self {
49 Self {
50 a_factor: vec![vec![0.0; input_dim]; input_dim],
51 g_factor: vec![vec![0.0; output_dim]; output_dim],
52 damping: 1e-3,
53 ema_factor: 0.95,
54 num_updates: 0,
55 }
56 }
57
58 pub fn with_damping(mut self, damping: f64) -> Self {
60 self.damping = damping.max(EPS);
61 self
62 }
63
64 pub fn with_ema(mut self, ema: f64) -> Self {
66 self.ema_factor = ema.clamp(0.0, 1.0);
67 self
68 }
69
70 pub fn update(&mut self, activations: &[Vec<f64>], gradients: &[Vec<f64>]) -> Result<()> {
76 if activations.is_empty() || gradients.is_empty() {
77 return Err(MathError::empty_input("batch"));
78 }
79
80 let batch_size = activations.len();
81 if gradients.len() != batch_size {
82 return Err(MathError::dimension_mismatch(batch_size, gradients.len()));
83 }
84
85 let input_dim = self.a_factor.len();
86 let output_dim = self.g_factor.len();
87
88 let mut new_a = vec![vec![0.0; input_dim]; input_dim];
90 for act in activations {
91 if act.len() != input_dim {
92 return Err(MathError::dimension_mismatch(input_dim, act.len()));
93 }
94 for i in 0..input_dim {
95 for j in 0..input_dim {
96 new_a[i][j] += act[i] * act[j] / batch_size as f64;
97 }
98 }
99 }
100
101 let mut new_g = vec![vec![0.0; output_dim]; output_dim];
103 for grad in gradients {
104 if grad.len() != output_dim {
105 return Err(MathError::dimension_mismatch(output_dim, grad.len()));
106 }
107 for i in 0..output_dim {
108 for j in 0..output_dim {
109 new_g[i][j] += grad[i] * grad[j] / batch_size as f64;
110 }
111 }
112 }
113
114 if self.num_updates == 0 {
116 self.a_factor = new_a;
117 self.g_factor = new_g;
118 } else {
119 for i in 0..input_dim {
120 for j in 0..input_dim {
121 self.a_factor[i][j] = self.ema_factor * self.a_factor[i][j]
122 + (1.0 - self.ema_factor) * new_a[i][j];
123 }
124 }
125 for i in 0..output_dim {
126 for j in 0..output_dim {
127 self.g_factor[i][j] = self.ema_factor * self.g_factor[i][j]
128 + (1.0 - self.ema_factor) * new_g[i][j];
129 }
130 }
131 }
132
133 self.num_updates += 1;
134 Ok(())
135 }
136
137 pub fn natural_gradient(&self, weight_grad: &[Vec<f64>]) -> Result<Vec<Vec<f64>>> {
144 let output_dim = self.g_factor.len();
145 let input_dim = self.a_factor.len();
146
147 if weight_grad.len() != output_dim {
148 return Err(MathError::dimension_mismatch(output_dim, weight_grad.len()));
149 }
150
151 let a_damped = self.add_damping(&self.a_factor);
153 let g_damped = self.add_damping(&self.g_factor);
154
155 let a_inv = self.invert_matrix(&a_damped)?;
157 let g_inv = self.invert_matrix(&g_damped)?;
158
159 let mut grad_a_inv = vec![vec![0.0; input_dim]; output_dim];
162 for i in 0..output_dim {
163 for j in 0..input_dim {
164 for k in 0..input_dim {
165 grad_a_inv[i][j] += weight_grad[i][k] * a_inv[k][j];
166 }
167 }
168 }
169
170 let mut nat_grad = vec![vec![0.0; input_dim]; output_dim];
172 for i in 0..output_dim {
173 for j in 0..input_dim {
174 for k in 0..output_dim {
175 nat_grad[i][j] += g_inv[i][k] * grad_a_inv[k][j];
176 }
177 }
178 }
179
180 Ok(nat_grad)
181 }
182
183 fn add_damping(&self, matrix: &[Vec<f64>]) -> Vec<Vec<f64>> {
185 let n = matrix.len();
186 let mut damped = matrix.to_vec();
187
188 let trace: f64 = (0..n).map(|i| matrix[i][i]).sum();
190 let pi_damping = (self.damping * trace / n as f64).max(EPS);
191
192 for i in 0..n {
193 damped[i][i] += pi_damping;
194 }
195
196 damped
197 }
198
199 fn invert_matrix(&self, matrix: &[Vec<f64>]) -> Result<Vec<Vec<f64>>> {
201 let n = matrix.len();
202
203 let mut l = vec![vec![0.0; n]; n];
205
206 for i in 0..n {
207 for j in 0..=i {
208 let mut sum = matrix[i][j];
209 for k in 0..j {
210 sum -= l[i][k] * l[j][k];
211 }
212
213 if i == j {
214 if sum <= 0.0 {
215 return Err(MathError::numerical_instability(
216 "Matrix not positive definite in K-FAC",
217 ));
218 }
219 l[i][j] = sum.sqrt();
220 } else {
221 l[i][j] = sum / l[j][j];
222 }
223 }
224 }
225
226 let mut l_inv = vec![vec![0.0; n]; n];
228 for i in 0..n {
229 l_inv[i][i] = 1.0 / l[i][i];
230 for j in (i + 1)..n {
231 let mut sum = 0.0;
232 for k in i..j {
233 sum -= l[j][k] * l_inv[k][i];
234 }
235 l_inv[j][i] = sum / l[j][j];
236 }
237 }
238
239 let mut inv = vec![vec![0.0; n]; n];
241 for i in 0..n {
242 for j in 0..n {
243 for k in 0..n {
244 inv[i][j] += l_inv[k][i] * l_inv[k][j];
245 }
246 }
247 }
248
249 Ok(inv)
250 }
251
252 pub fn reset(&mut self) {
254 let input_dim = self.a_factor.len();
255 let output_dim = self.g_factor.len();
256
257 self.a_factor = vec![vec![0.0; input_dim]; input_dim];
258 self.g_factor = vec![vec![0.0; output_dim]; output_dim];
259 self.num_updates = 0;
260 }
261}
262
/// Network-wide K-FAC state: one `KFACLayer` per weight matrix, plus the
/// step size and damping applied uniformly across layers.
#[derive(Debug, Clone)]
pub struct KFACApproximation {
    // Per-layer Kronecker-factor state, in network order.
    layers: Vec<KFACLayer>,
    // Step size; natural gradients are scaled by `-learning_rate` before
    // being returned as update steps.
    learning_rate: f64,
    // Damping value mirrored into each layer by `with_damping`.
    damping: f64,
}
273
274impl KFACApproximation {
275 pub fn new(layer_dims: &[(usize, usize)]) -> Self {
280 let layers = layer_dims
281 .iter()
282 .map(|&(input, output)| KFACLayer::new(input, output))
283 .collect();
284
285 Self {
286 layers,
287 learning_rate: 0.01,
288 damping: 1e-3,
289 }
290 }
291
292 pub fn with_learning_rate(mut self, lr: f64) -> Self {
294 self.learning_rate = lr.max(EPS);
295 self
296 }
297
298 pub fn with_damping(mut self, damping: f64) -> Self {
300 self.damping = damping.max(EPS);
301 for layer in &mut self.layers {
302 layer.damping = damping;
303 }
304 self
305 }
306
307 pub fn update_layer(
309 &mut self,
310 layer_idx: usize,
311 activations: &[Vec<f64>],
312 gradients: &[Vec<f64>],
313 ) -> Result<()> {
314 if layer_idx >= self.layers.len() {
315 return Err(MathError::invalid_parameter(
316 "layer_idx",
317 "index out of bounds",
318 ));
319 }
320
321 self.layers[layer_idx].update(activations, gradients)
322 }
323
324 pub fn natural_gradient_layer(
326 &self,
327 layer_idx: usize,
328 weight_grad: &[Vec<f64>],
329 ) -> Result<Vec<Vec<f64>>> {
330 if layer_idx >= self.layers.len() {
331 return Err(MathError::invalid_parameter(
332 "layer_idx",
333 "index out of bounds",
334 ));
335 }
336
337 let mut nat_grad = self.layers[layer_idx].natural_gradient(weight_grad)?;
338
339 for row in &mut nat_grad {
341 for val in row {
342 *val *= -self.learning_rate;
343 }
344 }
345
346 Ok(nat_grad)
347 }
348
349 pub fn num_layers(&self) -> usize {
351 self.layers.len()
352 }
353
354 pub fn reset(&mut self) {
356 for layer in &mut self.layers {
357 layer.reset();
358 }
359 }
360}
361
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_kfac_layer_creation() {
        // Factor matrices must match the declared layer dimensions.
        let layer = KFACLayer::new(10, 5);
        assert_eq!(layer.a_factor.len(), 10);
        assert_eq!(layer.g_factor.len(), 5);
    }

    #[test]
    fn test_kfac_layer_update() {
        let mut layer = KFACLayer::new(3, 2);
        let acts = vec![vec![1.0, 0.0, 1.0], vec![0.0, 1.0, 1.0]];
        let grads = vec![vec![0.5, 0.5], vec![0.3, 0.7]];

        layer.update(&acts, &grads).unwrap();

        // After one batch the leading diagonal entries are positive
        // second moments.
        assert!(layer.a_factor[0][0] > 0.0);
        assert!(layer.g_factor[0][0] > 0.0);
    }

    #[test]
    fn test_kfac_natural_gradient() {
        let mut layer = KFACLayer::new(2, 2).with_damping(0.1);

        layer
            .update(
                &[vec![1.0, 0.0], vec![0.0, 1.0]],
                &[vec![1.0, 0.0], vec![0.0, 1.0]],
            )
            .unwrap();

        let grad = vec![vec![0.1, 0.2], vec![0.3, 0.4]];
        let nat = layer.natural_gradient(&grad).unwrap();

        // Output shape matches the weight-gradient shape.
        assert_eq!(nat.len(), 2);
        assert_eq!(nat[0].len(), 2);
    }

    #[test]
    fn test_kfac_full_network() {
        let kfac = KFACApproximation::new(&[(10, 20), (20, 5)])
            .with_learning_rate(0.01)
            .with_damping(0.001);

        assert_eq!(kfac.num_layers(), 2);
    }
}