Skip to main content

ruvector_math/information_geometry/
kfac.rs

1//! K-FAC: Kronecker-Factored Approximate Curvature
2//!
//! K-FAC approximates the Fisher Information Matrix of a neural network with
//! Kronecker products. For a square layer with n weights this reduces storage
//! from O(n²) to O(n) and inversion cost from O(n³) to O(n^{3/2}).
6//!
7//! ## Theory
8//!
9//! For a layer with weights W ∈ R^{m×n}:
//! - Gradient: ∇W = g aᵀ (outer product of the output-side gradient g and the layer input a)
11//! - FIM block: F_W ≈ E[gg^T] ⊗ E[aa^T] = G ⊗ A (Kronecker factorization)
12//!
13//! ## Benefits
14//!
15//! - **Memory efficient**: Store two small matrices instead of one huge one
16//! - **Fast inversion**: (G ⊗ A)⁻¹ = G⁻¹ ⊗ A⁻¹
17//! - **Practical natural gradient**: Scales to large networks
18//!
19//! ## References
20//!
21//! - Martens & Grosse (2015): "Optimizing Neural Networks with Kronecker-factored
22//!   Approximate Curvature"
23
24use crate::error::{MathError, Result};
25use crate::utils::EPS;
26
27/// K-FAC approximation for a single layer
28#[derive(Debug, Clone)]
29pub struct KFACLayer {
30    /// Input-side factor A = E[aa^T]
31    pub a_factor: Vec<Vec<f64>>,
32    /// Output-side factor G = E[gg^T]
33    pub g_factor: Vec<Vec<f64>>,
34    /// Damping factor
35    damping: f64,
36    /// EMA factor for running estimates
37    ema_factor: f64,
38    /// Number of updates
39    num_updates: usize,
40}
41
42impl KFACLayer {
43    /// Create a new K-FAC layer approximation
44    ///
45    /// # Arguments
46    /// * `input_dim` - Size of input activations
47    /// * `output_dim` - Size of output gradients
48    pub fn new(input_dim: usize, output_dim: usize) -> Self {
49        Self {
50            a_factor: vec![vec![0.0; input_dim]; input_dim],
51            g_factor: vec![vec![0.0; output_dim]; output_dim],
52            damping: 1e-3,
53            ema_factor: 0.95,
54            num_updates: 0,
55        }
56    }
57
58    /// Set damping factor
59    pub fn with_damping(mut self, damping: f64) -> Self {
60        self.damping = damping.max(EPS);
61        self
62    }
63
64    /// Set EMA factor
65    pub fn with_ema(mut self, ema: f64) -> Self {
66        self.ema_factor = ema.clamp(0.0, 1.0);
67        self
68    }
69
70    /// Update factors with new activations and gradients
71    ///
72    /// # Arguments
73    /// * `activations` - Pre-activation inputs, shape [batch, input_dim]
74    /// * `gradients` - Post-activation gradients, shape [batch, output_dim]
75    pub fn update(&mut self, activations: &[Vec<f64>], gradients: &[Vec<f64>]) -> Result<()> {
76        if activations.is_empty() || gradients.is_empty() {
77            return Err(MathError::empty_input("batch"));
78        }
79
80        let batch_size = activations.len();
81        if gradients.len() != batch_size {
82            return Err(MathError::dimension_mismatch(batch_size, gradients.len()));
83        }
84
85        let input_dim = self.a_factor.len();
86        let output_dim = self.g_factor.len();
87
88        // Compute A = E[aa^T]
89        let mut new_a = vec![vec![0.0; input_dim]; input_dim];
90        for act in activations {
91            if act.len() != input_dim {
92                return Err(MathError::dimension_mismatch(input_dim, act.len()));
93            }
94            for i in 0..input_dim {
95                for j in 0..input_dim {
96                    new_a[i][j] += act[i] * act[j] / batch_size as f64;
97                }
98            }
99        }
100
101        // Compute G = E[gg^T]
102        let mut new_g = vec![vec![0.0; output_dim]; output_dim];
103        for grad in gradients {
104            if grad.len() != output_dim {
105                return Err(MathError::dimension_mismatch(output_dim, grad.len()));
106            }
107            for i in 0..output_dim {
108                for j in 0..output_dim {
109                    new_g[i][j] += grad[i] * grad[j] / batch_size as f64;
110                }
111            }
112        }
113
114        // EMA update
115        if self.num_updates == 0 {
116            self.a_factor = new_a;
117            self.g_factor = new_g;
118        } else {
119            for i in 0..input_dim {
120                for j in 0..input_dim {
121                    self.a_factor[i][j] = self.ema_factor * self.a_factor[i][j]
122                        + (1.0 - self.ema_factor) * new_a[i][j];
123                }
124            }
125            for i in 0..output_dim {
126                for j in 0..output_dim {
127                    self.g_factor[i][j] = self.ema_factor * self.g_factor[i][j]
128                        + (1.0 - self.ema_factor) * new_g[i][j];
129                }
130            }
131        }
132
133        self.num_updates += 1;
134        Ok(())
135    }
136
137    /// Compute natural gradient for weight matrix
138    ///
139    /// nat_grad = G⁻¹ ∇W A⁻¹
140    ///
141    /// # Arguments
142    /// * `weight_grad` - Gradient w.r.t. weights, shape [output_dim, input_dim]
143    pub fn natural_gradient(&self, weight_grad: &[Vec<f64>]) -> Result<Vec<Vec<f64>>> {
144        let output_dim = self.g_factor.len();
145        let input_dim = self.a_factor.len();
146
147        if weight_grad.len() != output_dim {
148            return Err(MathError::dimension_mismatch(output_dim, weight_grad.len()));
149        }
150
151        // Add damping to factors
152        let a_damped = self.add_damping(&self.a_factor);
153        let g_damped = self.add_damping(&self.g_factor);
154
155        // Invert factors
156        let a_inv = self.invert_matrix(&a_damped)?;
157        let g_inv = self.invert_matrix(&g_damped)?;
158
159        // Compute G⁻¹ ∇W A⁻¹
160        // First: ∇W A⁻¹
161        let mut grad_a_inv = vec![vec![0.0; input_dim]; output_dim];
162        for i in 0..output_dim {
163            for j in 0..input_dim {
164                for k in 0..input_dim {
165                    grad_a_inv[i][j] += weight_grad[i][k] * a_inv[k][j];
166                }
167            }
168        }
169
170        // Then: G⁻¹ (∇W A⁻¹)
171        let mut nat_grad = vec![vec![0.0; input_dim]; output_dim];
172        for i in 0..output_dim {
173            for j in 0..input_dim {
174                for k in 0..output_dim {
175                    nat_grad[i][j] += g_inv[i][k] * grad_a_inv[k][j];
176                }
177            }
178        }
179
180        Ok(nat_grad)
181    }
182
183    /// Add damping to diagonal of matrix
184    fn add_damping(&self, matrix: &[Vec<f64>]) -> Vec<Vec<f64>> {
185        let n = matrix.len();
186        let mut damped = matrix.to_vec();
187
188        // Add π-damping (Tikhonov + trace normalization)
189        let trace: f64 = (0..n).map(|i| matrix[i][i]).sum();
190        let pi_damping = (self.damping * trace / n as f64).max(EPS);
191
192        for i in 0..n {
193            damped[i][i] += pi_damping;
194        }
195
196        damped
197    }
198
199    /// Invert matrix using Cholesky decomposition
200    fn invert_matrix(&self, matrix: &[Vec<f64>]) -> Result<Vec<Vec<f64>>> {
201        let n = matrix.len();
202
203        // Cholesky: A = LLᵀ
204        let mut l = vec![vec![0.0; n]; n];
205
206        for i in 0..n {
207            for j in 0..=i {
208                let mut sum = matrix[i][j];
209                for k in 0..j {
210                    sum -= l[i][k] * l[j][k];
211                }
212
213                if i == j {
214                    if sum <= 0.0 {
215                        return Err(MathError::numerical_instability(
216                            "Matrix not positive definite in K-FAC",
217                        ));
218                    }
219                    l[i][j] = sum.sqrt();
220                } else {
221                    l[i][j] = sum / l[j][j];
222                }
223            }
224        }
225
226        // L⁻¹ via forward substitution
227        let mut l_inv = vec![vec![0.0; n]; n];
228        for i in 0..n {
229            l_inv[i][i] = 1.0 / l[i][i];
230            for j in (i + 1)..n {
231                let mut sum = 0.0;
232                for k in i..j {
233                    sum -= l[j][k] * l_inv[k][i];
234                }
235                l_inv[j][i] = sum / l[j][j];
236            }
237        }
238
239        // A⁻¹ = L⁻ᵀL⁻¹
240        let mut inv = vec![vec![0.0; n]; n];
241        for i in 0..n {
242            for j in 0..n {
243                for k in 0..n {
244                    inv[i][j] += l_inv[k][i] * l_inv[k][j];
245                }
246            }
247        }
248
249        Ok(inv)
250    }
251
252    /// Reset factor estimates
253    pub fn reset(&mut self) {
254        let input_dim = self.a_factor.len();
255        let output_dim = self.g_factor.len();
256
257        self.a_factor = vec![vec![0.0; input_dim]; input_dim];
258        self.g_factor = vec![vec![0.0; output_dim]; output_dim];
259        self.num_updates = 0;
260    }
261}
262
263/// K-FAC approximation for full network
264#[derive(Debug, Clone)]
265pub struct KFACApproximation {
266    /// Per-layer K-FAC factors
267    layers: Vec<KFACLayer>,
268    /// Learning rate
269    learning_rate: f64,
270    /// Global damping
271    damping: f64,
272}
273
274impl KFACApproximation {
275    /// Create K-FAC optimizer for a network
276    ///
277    /// # Arguments
278    /// * `layer_dims` - List of (input_dim, output_dim) for each layer
279    pub fn new(layer_dims: &[(usize, usize)]) -> Self {
280        let layers = layer_dims
281            .iter()
282            .map(|&(input, output)| KFACLayer::new(input, output))
283            .collect();
284
285        Self {
286            layers,
287            learning_rate: 0.01,
288            damping: 1e-3,
289        }
290    }
291
292    /// Set learning rate
293    pub fn with_learning_rate(mut self, lr: f64) -> Self {
294        self.learning_rate = lr.max(EPS);
295        self
296    }
297
298    /// Set damping
299    pub fn with_damping(mut self, damping: f64) -> Self {
300        self.damping = damping.max(EPS);
301        for layer in &mut self.layers {
302            layer.damping = damping;
303        }
304        self
305    }
306
307    /// Update factors for a layer
308    pub fn update_layer(
309        &mut self,
310        layer_idx: usize,
311        activations: &[Vec<f64>],
312        gradients: &[Vec<f64>],
313    ) -> Result<()> {
314        if layer_idx >= self.layers.len() {
315            return Err(MathError::invalid_parameter(
316                "layer_idx",
317                "index out of bounds",
318            ));
319        }
320
321        self.layers[layer_idx].update(activations, gradients)
322    }
323
324    /// Compute natural gradient for a layer's weights
325    pub fn natural_gradient_layer(
326        &self,
327        layer_idx: usize,
328        weight_grad: &[Vec<f64>],
329    ) -> Result<Vec<Vec<f64>>> {
330        if layer_idx >= self.layers.len() {
331            return Err(MathError::invalid_parameter(
332                "layer_idx",
333                "index out of bounds",
334            ));
335        }
336
337        let mut nat_grad = self.layers[layer_idx].natural_gradient(weight_grad)?;
338
339        // Scale by learning rate
340        for row in &mut nat_grad {
341            for val in row {
342                *val *= -self.learning_rate;
343            }
344        }
345
346        Ok(nat_grad)
347    }
348
349    /// Get number of layers
350    pub fn num_layers(&self) -> usize {
351        self.layers.len()
352    }
353
354    /// Reset all layer estimates
355    pub fn reset(&mut self) {
356        for layer in &mut self.layers {
357            layer.reset();
358        }
359    }
360}
361
#[cfg(test)]
mod tests {
    use super::*;

    /// Assert two floats agree to well within accumulated rounding error.
    fn assert_close(actual: f64, expected: f64) {
        assert!(
            (actual - expected).abs() < 1e-12,
            "expected {expected}, got {actual}"
        );
    }

    #[test]
    fn test_kfac_layer_creation() {
        let layer = KFACLayer::new(10, 5);

        assert_eq!(layer.a_factor.len(), 10);
        assert_eq!(layer.g_factor.len(), 5);
    }

    #[test]
    fn test_kfac_layer_update() {
        let mut layer = KFACLayer::new(3, 2);

        let activations = vec![vec![1.0, 0.0, 1.0], vec![0.0, 1.0, 1.0]];
        let gradients = vec![vec![0.5, 0.5], vec![0.3, 0.7]];

        layer.update(&activations, &gradients).unwrap();

        // A = mean of a aᵀ over the batch
        assert_eq!(
            layer.a_factor,
            vec![
                vec![0.5, 0.0, 0.5],
                vec![0.0, 0.5, 0.5],
                vec![0.5, 0.5, 1.0],
            ]
        );
        // G = mean of g gᵀ over the batch, and it is symmetric
        assert_close(layer.g_factor[0][0], 0.17);
        assert_close(layer.g_factor[0][1], 0.23);
        assert_close(layer.g_factor[1][1], 0.37);
        assert_eq!(layer.g_factor[0][1], layer.g_factor[1][0]);
    }

    #[test]
    fn test_kfac_update_rejects_invalid_batches() {
        let mut layer = KFACLayer::new(2, 1);
        layer.update(&[vec![1.0, 1.0]], &[vec![1.0]]).unwrap();
        let before = layer.a_factor.clone();

        assert!(layer.update(&[], &[]).is_err());
        assert!(layer
            .update(&[vec![1.0, 0.0]], &[vec![1.0], vec![2.0]])
            .is_err());
        assert!(layer.update(&[vec![1.0]], &[vec![1.0]]).is_err());
        assert!(layer.update(&[vec![1.0, 0.0]], &[vec![1.0, 2.0]]).is_err());

        // A rejected batch must not touch the running estimates.
        assert_eq!(layer.a_factor, before);
    }

    #[test]
    fn test_kfac_ema_update() {
        let mut layer = KFACLayer::new(1, 1).with_ema(0.5);

        layer.update(&[vec![2.0]], &[vec![1.0]]).unwrap();
        assert_eq!(layer.a_factor, vec![vec![4.0]]);

        layer.update(&[vec![0.0]], &[vec![3.0]]).unwrap();
        assert_eq!(layer.a_factor, vec![vec![2.0]]);
        assert_eq!(layer.g_factor, vec![vec![5.0]]);
    }

    #[test]
    fn test_kfac_reset() {
        let mut layer = KFACLayer::new(1, 1).with_ema(0.5);
        layer.update(&[vec![2.0]], &[vec![1.0]]).unwrap();

        layer.reset();
        assert_eq!(layer.a_factor, vec![vec![0.0]]);
        assert_eq!(layer.g_factor, vec![vec![0.0]]);

        // After a reset the next batch replaces the factors instead of being blended in.
        layer.update(&[vec![1.0]], &[vec![1.0]]).unwrap();
        assert_eq!(layer.a_factor, vec![vec![1.0]]);
    }

    #[test]
    fn test_kfac_natural_gradient() {
        let mut layer = KFACLayer::new(2, 2).with_damping(0.1);

        // Identity-like batches give A = G = 0.5 I.
        let activations = vec![vec![1.0, 0.0], vec![0.0, 1.0]];
        let gradients = vec![vec![1.0, 0.0], vec![0.0, 1.0]];

        layer.update(&activations, &gradients).unwrap();

        let weight_grad = vec![vec![0.1, 0.2], vec![0.3, 0.4]];

        let nat_grad = layer.natural_gradient(&weight_grad).unwrap();

        assert_eq!(nat_grad.len(), 2);
        assert_eq!(nat_grad[0].len(), 2);

        // Damping adds 0.1 * tr / n = 0.05, so both damped factors are 0.55 I.
        for i in 0..2 {
            for j in 0..2 {
                assert_close(nat_grad[i][j], weight_grad[i][j] / (0.55 * 0.55));
            }
        }
    }

    #[test]
    fn test_kfac_natural_gradient_scales_columns_by_a_inverse() {
        let mut layer = KFACLayer::new(2, 1).with_damping(0.1);
        layer
            .update(&[vec![2.0, 0.0], vec![0.0, 4.0]], &[vec![1.0], vec![1.0]])
            .unwrap();

        // A = diag(2, 8) + 0.5 I, G = 1 + 0.1; A⁻¹ acts on the input (column) side.
        let nat_grad = layer.natural_gradient(&[vec![1.0, 1.0]]).unwrap();
        assert_close(nat_grad[0][0], 1.0 / (1.1 * 2.5));
        assert_close(nat_grad[0][1], 1.0 / (1.1 * 8.5));
    }

    #[test]
    fn test_kfac_full_network() {
        let kfac = KFACApproximation::new(&[(10, 20), (20, 5)])
            .with_learning_rate(0.01)
            .with_damping(0.001);

        assert_eq!(kfac.num_layers(), 2);
    }

    #[test]
    fn test_kfac_natural_gradient_layer_step() {
        let mut kfac = KFACApproximation::new(&[(2, 2)])
            .with_learning_rate(0.5)
            .with_damping(0.1);
        let eye = vec![vec![1.0, 0.0], vec![0.0, 1.0]];
        kfac.update_layer(0, &eye, &eye).unwrap();

        let weight_grad = vec![vec![0.1, 0.2], vec![0.3, 0.4]];
        let step = kfac.natural_gradient_layer(0, &weight_grad).unwrap();
        for i in 0..2 {
            for j in 0..2 {
                assert_close(step[i][j], -0.5 * weight_grad[i][j] / (0.55 * 0.55));
            }
        }

        assert!(kfac.update_layer(1, &eye, &eye).is_err());
        assert!(kfac.natural_gradient_layer(1, &weight_grad).is_err());
    }
}