// ruvector_math/information_geometry/kfac.rs
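
//! Kronecker-Factored Approximate Curvature (K-FAC).
//!
//! For a fully connected layer with input activations `a` and output
//! gradients `g`, K-FAC approximates the layer's Fisher information block
//! as a Kronecker product `F ≈ A ⊗ G`, where `A = E[a aᵀ]` and
//! `G = E[g gᵀ]`. The natural gradient then factors as
//! `F⁻¹ vec(∇W) = vec(G⁻¹ ∇W A⁻¹)`, so only the two small factors need to
//! be inverted instead of the full Fisher matrix.
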
use crate::error::{MathError, Result};
use crate::utils::EPS;

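/// Kronecker factors for a single fully connected layer.
///
/// Usage sketch; the crate path is assumed from the file location and
/// `activations`, `gradients`, and `weight_grad` are placeholders, so the
/// example is marked `ignore`:
///
/// ```ignore
/// use ruvector_math::information_geometry::kfac::KFACLayer;
///
/// let mut layer = KFACLayer::new(3, 2).with_damping(1e-2).with_ema(0.9);
/// layer.update(&activations, &gradients)?;         // accumulate A and G
/// let nat = layer.natural_gradient(&weight_grad)?; // G⁻¹ ∇W A⁻¹
/// ```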
#[derive(Debug, Clone)]
pub struct KFACLayer {
    /// Input-side factor `A = E[a aᵀ]` (input_dim x input_dim).
    pub a_factor: Vec<Vec<f64>>,
    /// Output-side factor `G = E[g gᵀ]` (output_dim x output_dim).
    pub g_factor: Vec<Vec<f64>>,
    /// Tikhonov damping strength used when inverting the factors.
    damping: f64,
    /// Exponential moving average coefficient for factor updates.
    ema_factor: f64,
    /// Number of `update` calls seen so far.
    num_updates: usize,
}

impl KFACLayer {
    /// Creates a layer with zeroed factors, default damping (1e-3), and
    /// default EMA coefficient (0.95).
    pub fn new(input_dim: usize, output_dim: usize) -> Self {
        Self {
            a_factor: vec![vec![0.0; input_dim]; input_dim],
            g_factor: vec![vec![0.0; output_dim]; output_dim],
            damping: 1e-3,
            ema_factor: 0.95,
            num_updates: 0,
        }
    }

    /// Sets the damping strength, clamped below by `EPS`.
    pub fn with_damping(mut self, damping: f64) -> Self {
        self.damping = damping.max(EPS);
        self
    }

    /// Sets the EMA coefficient, clamped to `[0.0, 1.0]`.
    pub fn with_ema(mut self, ema: f64) -> Self {
        self.ema_factor = ema.clamp(0.0, 1.0);
        self
    }

    /// Updates the Kronecker factors from a batch of input activations and
    /// output gradients, using batch means of `a aᵀ` and `g gᵀ`.
    pub fn update(
        &mut self,
        activations: &[Vec<f64>],
        gradients: &[Vec<f64>],
    ) -> Result<()> {
        if activations.is_empty() || gradients.is_empty() {
            return Err(MathError::empty_input("batch"));
        }

        let batch_size = activations.len();
        if gradients.len() != batch_size {
            return Err(MathError::dimension_mismatch(batch_size, gradients.len()));
        }

        let input_dim = self.a_factor.len();
        let output_dim = self.g_factor.len();

        // Batch estimate of the input-side factor: A = (1/B) Σ a aᵀ.
        let mut new_a = vec![vec![0.0; input_dim]; input_dim];
        for act in activations {
            if act.len() != input_dim {
                return Err(MathError::dimension_mismatch(input_dim, act.len()));
            }
            for i in 0..input_dim {
                for j in 0..input_dim {
                    new_a[i][j] += act[i] * act[j] / batch_size as f64;
                }
            }
        }

        // Batch estimate of the output-side factor: G = (1/B) Σ g gᵀ.
        let mut new_g = vec![vec![0.0; output_dim]; output_dim];
        for grad in gradients {
            if grad.len() != output_dim {
                return Err(MathError::dimension_mismatch(output_dim, grad.len()));
            }
            for i in 0..output_dim {
                for j in 0..output_dim {
                    new_g[i][j] += grad[i] * grad[j] / batch_size as f64;
                }
            }
        }

        // The first update seeds the factors directly; later updates blend
        // with an exponential moving average.
        if self.num_updates == 0 {
            self.a_factor = new_a;
            self.g_factor = new_g;
        } else {
            for i in 0..input_dim {
                for j in 0..input_dim {
                    self.a_factor[i][j] = self.ema_factor * self.a_factor[i][j]
                        + (1.0 - self.ema_factor) * new_a[i][j];
                }
            }
            for i in 0..output_dim {
                for j in 0..output_dim {
                    self.g_factor[i][j] = self.ema_factor * self.g_factor[i][j]
                        + (1.0 - self.ema_factor) * new_g[i][j];
                }
            }
        }

        self.num_updates += 1;
        Ok(())
    }

    /// Computes the natural gradient `G⁻¹ ∇W A⁻¹` for a weight gradient of
    /// shape (output_dim x input_dim), inverting damped copies of the factors.
    pub fn natural_gradient(&self, weight_grad: &[Vec<f64>]) -> Result<Vec<Vec<f64>>> {
        let output_dim = self.g_factor.len();
        let input_dim = self.a_factor.len();

        if weight_grad.len() != output_dim {
            return Err(MathError::dimension_mismatch(output_dim, weight_grad.len()));
        }
        for row in weight_grad {
            if row.len() != input_dim {
                return Err(MathError::dimension_mismatch(input_dim, row.len()));
            }
        }

        let a_damped = self.add_damping(&self.a_factor);
        let g_damped = self.add_damping(&self.g_factor);

        let a_inv = self.invert_matrix(&a_damped)?;
        let g_inv = self.invert_matrix(&g_damped)?;

        // Right-multiply by the input-side inverse: grad_a_inv = ∇W · A⁻¹.
        let mut grad_a_inv = vec![vec![0.0; input_dim]; output_dim];
        for i in 0..output_dim {
            for j in 0..input_dim {
                for k in 0..input_dim {
                    grad_a_inv[i][j] += weight_grad[i][k] * a_inv[k][j];
                }
            }
        }

        // Left-multiply by the output-side inverse: nat_grad = G⁻¹ · (∇W · A⁻¹).
        let mut nat_grad = vec![vec![0.0; input_dim]; output_dim];
        for i in 0..output_dim {
            for j in 0..input_dim {
                for k in 0..output_dim {
                    nat_grad[i][j] += g_inv[i][k] * grad_a_inv[k][j];
                }
            }
        }

        Ok(nat_grad)
    }

    /// Adds trace-scaled Tikhonov damping to the diagonal:
    /// `pi = max(damping * tr(M) / n, EPS)`.
    fn add_damping(&self, matrix: &[Vec<f64>]) -> Vec<Vec<f64>> {
        let n = matrix.len();
        let mut damped = matrix.to_vec();

        let trace: f64 = (0..n).map(|i| matrix[i][i]).sum();
        let pi_damping = (self.damping * trace / n as f64).max(EPS);

        for i in 0..n {
            damped[i][i] += pi_damping;
        }

        damped
    }

    /// Inverts a symmetric positive definite matrix via Cholesky
    /// decomposition: M = L Lᵀ, so M⁻¹ = L⁻ᵀ L⁻¹.
    fn invert_matrix(&self, matrix: &[Vec<f64>]) -> Result<Vec<Vec<f64>>> {
        let n = matrix.len();

        // Cholesky factorization: matrix = l * lᵀ with l lower triangular.
        let mut l = vec![vec![0.0; n]; n];
        for i in 0..n {
            for j in 0..=i {
                let mut sum = matrix[i][j];
                for k in 0..j {
                    sum -= l[i][k] * l[j][k];
                }

                if i == j {
                    if sum <= 0.0 {
                        return Err(MathError::numerical_instability(
                            "Matrix not positive definite in K-FAC",
                        ));
                    }
                    l[i][j] = sum.sqrt();
                } else {
                    l[i][j] = sum / l[j][j];
                }
            }
        }

        // Invert the triangular factor by forward substitution.
        let mut l_inv = vec![vec![0.0; n]; n];
        for i in 0..n {
            l_inv[i][i] = 1.0 / l[i][i];
            for j in (i + 1)..n {
                let mut sum = 0.0;
                for k in i..j {
                    sum -= l[j][k] * l_inv[k][i];
                }
                l_inv[j][i] = sum / l[j][j];
            }
        }

        // Assemble the full inverse: inv = l_invᵀ * l_inv.
        let mut inv = vec![vec![0.0; n]; n];
        for i in 0..n {
            for j in 0..n {
                for k in 0..n {
                    inv[i][j] += l_inv[k][i] * l_inv[k][j];
                }
            }
        }

        Ok(inv)
    }

    /// Zeroes both factors and the update counter.
    pub fn reset(&mut self) {
        let input_dim = self.a_factor.len();
        let output_dim = self.g_factor.len();

        self.a_factor = vec![vec![0.0; input_dim]; input_dim];
        self.g_factor = vec![vec![0.0; output_dim]; output_dim];
        self.num_updates = 0;
    }
}

/// K-FAC state for a whole network: one set of Kronecker factors per layer.
#[derive(Debug, Clone)]
pub struct KFACApproximation {
    layers: Vec<KFACLayer>,
    /// Step size applied in `natural_gradient_layer`.
    learning_rate: f64,
    /// Shared damping pushed down to every layer by `with_damping`.
    damping: f64,
}

impl KFACApproximation {
    /// Builds one `KFACLayer` per `(input_dim, output_dim)` pair.
    pub fn new(layer_dims: &[(usize, usize)]) -> Self {
        let layers = layer_dims
            .iter()
            .map(|&(input, output)| KFACLayer::new(input, output))
            .collect();

        Self {
            layers,
            learning_rate: 0.01,
            damping: 1e-3,
        }
    }

    /// Sets the learning rate, clamped below by `EPS`.
    pub fn with_learning_rate(mut self, lr: f64) -> Self {
        self.learning_rate = lr.max(EPS);
        self
    }

    /// Sets the damping for the optimizer and all of its layers.
    pub fn with_damping(mut self, damping: f64) -> Self {
        self.damping = damping.max(EPS);
        for layer in &mut self.layers {
            layer.damping = damping;
        }
        self
    }

    /// Updates the Kronecker factors of the layer at `layer_idx`.
    pub fn update_layer(
        &mut self,
        layer_idx: usize,
        activations: &[Vec<f64>],
        gradients: &[Vec<f64>],
    ) -> Result<()> {
        if layer_idx >= self.layers.len() {
            return Err(MathError::invalid_parameter(
                "layer_idx",
                "index out of bounds",
            ));
        }

        self.layers[layer_idx].update(activations, gradients)
    }

    /// Computes the natural gradient for one layer and scales it by
    /// `-learning_rate`, returning a ready-to-apply update step.
    pub fn natural_gradient_layer(
        &self,
        layer_idx: usize,
        weight_grad: &[Vec<f64>],
    ) -> Result<Vec<Vec<f64>>> {
        if layer_idx >= self.layers.len() {
            return Err(MathError::invalid_parameter(
                "layer_idx",
                "index out of bounds",
            ));
        }

        let mut nat_grad = self.layers[layer_idx].natural_gradient(weight_grad)?;

        // Fold the descent direction and step size into the result.
        for row in &mut nat_grad {
            for val in row {
                *val *= -self.learning_rate;
            }
        }

        Ok(nat_grad)
    }

    /// Number of layers tracked by this approximation.
    pub fn num_layers(&self) -> usize {
        self.layers.len()
    }

    /// Resets every layer's factors and update counter.
    pub fn reset(&mut self) {
        for layer in &mut self.layers {
            layer.reset();
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_kfac_layer_creation() {
        let layer = KFACLayer::new(10, 5);

        assert_eq!(layer.a_factor.len(), 10);
        assert_eq!(layer.g_factor.len(), 5);
    }

    #[test]
    fn test_kfac_layer_update() {
        let mut layer = KFACLayer::new(3, 2);

        let activations = vec![vec![1.0, 0.0, 1.0], vec![0.0, 1.0, 1.0]];

        let gradients = vec![vec![0.5, 0.5], vec![0.3, 0.7]];

        layer.update(&activations, &gradients).unwrap();

        assert!(layer.a_factor[0][0] > 0.0);
        assert!(layer.g_factor[0][0] > 0.0);
    }

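    // Added test (a sketch): `update` should report a dimension mismatch when
    // the activation and gradient batches disagree in length, per the
    // batch-size check at the top of `update`.
    #[test]
    fn test_kfac_layer_update_rejects_mismatched_batch() {
        let mut layer = KFACLayer::new(3, 2);

        let activations = vec![vec![1.0, 0.0, 1.0]];
        let gradients = vec![vec![0.5, 0.5], vec![0.3, 0.7]];

        assert!(layer.update(&activations, &gradients).is_err());
    }
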
    #[test]
    fn test_kfac_natural_gradient() {
        let mut layer = KFACLayer::new(2, 2).with_damping(0.1);

        let activations = vec![vec![1.0, 0.0], vec![0.0, 1.0]];
        let gradients = vec![vec![1.0, 0.0], vec![0.0, 1.0]];

        layer.update(&activations, &gradients).unwrap();

        let weight_grad = vec![vec![0.1, 0.2], vec![0.3, 0.4]];

        let nat_grad = layer.natural_gradient(&weight_grad).unwrap();

        assert_eq!(nat_grad.len(), 2);
        assert_eq!(nat_grad[0].len(), 2);
    }

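    // Added test (a sketch): with the identity-like batch above, both factors
    // are 0.5 * I, and the trace-scaled damping adds 0.1 * 1.0 / 2 = 0.05 to
    // each diagonal. The damped factors are then 0.55 * I, so the natural
    // gradient should be the raw gradient scaled by 1 / 0.55^2.
    #[test]
    fn test_kfac_identity_factors_scale_gradient() {
        let mut layer = KFACLayer::new(2, 2).with_damping(0.1);

        let activations = vec![vec![1.0, 0.0], vec![0.0, 1.0]];
        let gradients = vec![vec![1.0, 0.0], vec![0.0, 1.0]];
        layer.update(&activations, &gradients).unwrap();

        let weight_grad = vec![vec![0.1, 0.2], vec![0.3, 0.4]];
        let nat_grad = layer.natural_gradient(&weight_grad).unwrap();

        let scale = 1.0 / (0.55 * 0.55);
        for i in 0..2 {
            for j in 0..2 {
                assert!((nat_grad[i][j] - scale * weight_grad[i][j]).abs() < 1e-9);
            }
        }
    }
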
    #[test]
    fn test_kfac_full_network() {
        let kfac = KFACApproximation::new(&[(10, 20), (20, 5)])
            .with_learning_rate(0.01)
            .with_damping(0.001);

        assert_eq!(kfac.num_layers(), 2);
    }
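
    // Added test (a sketch): `natural_gradient_layer` should equal the layer's
    // own natural gradient scaled by -learning_rate; the private `layers`
    // field is reachable here because `tests` is a child module.
    #[test]
    fn test_kfac_natural_gradient_layer_scaling() {
        let mut kfac = KFACApproximation::new(&[(2, 2)])
            .with_learning_rate(0.5)
            .with_damping(0.1);

        let activations = vec![vec![1.0, 0.0], vec![0.0, 1.0]];
        let gradients = vec![vec![1.0, 0.0], vec![0.0, 1.0]];
        kfac.update_layer(0, &activations, &gradients).unwrap();

        let weight_grad = vec![vec![0.1, 0.2], vec![0.3, 0.4]];
        let step = kfac.natural_gradient_layer(0, &weight_grad).unwrap();
        let raw = kfac.layers[0].natural_gradient(&weight_grad).unwrap();

        for i in 0..2 {
            for j in 0..2 {
                assert!((step[i][j] + 0.5 * raw[i][j]).abs() < 1e-12);
            }
        }
    }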
}