ruvector_math/information_geometry/kfac.rs

//! K-FAC: Kronecker-Factored Approximate Curvature
//!
//! K-FAC approximates the Fisher Information Matrix (FIM) of a neural network using
//! Kronecker products. For a layer with p parameters this reduces storage from
//! O(p²) to O(p) and inversion cost from O(p³) to O(p^{3/2}).
//!
//! ## Theory
//!
//! For a layer with weights W ∈ R^{m×n} and a single example:
//! - Gradient: ∇W = g aᵀ, the outer product of the back-propagated output gradient g
//!   and the input activation a (so vec(∇W) = g ⊗ a)
//! - FIM block: F_W ≈ E[ggᵀ] ⊗ E[aaᵀ] = G ⊗ A (Kronecker factorization)
//!
//! ## Benefits
//!
//! - **Memory efficient**: Store two small factor matrices instead of one huge one
//! - **Fast inversion**: (G ⊗ A)⁻¹ = G⁻¹ ⊗ A⁻¹
//! - **Practical natural gradient**: Scales to large networks
//!
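//! ## Example
//!
//! A minimal sketch of the single-layer workflow; the `use` path is an assumption
//! based on this module's location.
//!
//! ```ignore
//! use ruvector_math::information_geometry::kfac::KFACLayer;
//!
//! let mut layer = KFACLayer::new(3, 2).with_damping(1e-2);
//!
//! // One mini-batch: input activations a (shape [batch, 3]) and gradients of the
//! // loss w.r.t. the layer's pre-activation outputs g (shape [batch, 2]).
//! let activations = vec![vec![1.0, 0.5, -0.5], vec![0.0, 1.0, 1.0]];
//! let gradients = vec![vec![0.2, -0.1], vec![0.4, 0.3]];
//! layer.update(&activations, &gradients).unwrap();
//!
//! // Precondition a weight gradient (shape [2, 3]) with the factored FIM.
//! let weight_grad = vec![vec![0.1, 0.0, 0.2], vec![0.3, 0.1, 0.0]];
//! let nat_grad = layer.natural_gradient(&weight_grad).unwrap();
//! assert_eq!(nat_grad.len(), 2);
//! ```
//!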
//! ## References
//!
//! - Martens & Grosse (2015): "Optimizing Neural Networks with Kronecker-factored
//!   Approximate Curvature"

use crate::error::{MathError, Result};
use crate::utils::EPS;

/// K-FAC approximation for a single layer
#[derive(Debug, Clone)]
pub struct KFACLayer {
    /// Input-side factor A = E[aaᵀ]
    pub a_factor: Vec<Vec<f64>>,
    /// Output-side factor G = E[ggᵀ]
    pub g_factor: Vec<Vec<f64>>,
    /// Damping factor
    damping: f64,
    /// EMA factor for running estimates
    ema_factor: f64,
    /// Number of updates
    num_updates: usize,
}

impl KFACLayer {
    /// Create a new K-FAC layer approximation
    ///
    /// # Arguments
    /// * `input_dim` - Size of input activations
    /// * `output_dim` - Size of output gradients
    pub fn new(input_dim: usize, output_dim: usize) -> Self {
        Self {
            a_factor: vec![vec![0.0; input_dim]; input_dim],
            g_factor: vec![vec![0.0; output_dim]; output_dim],
            damping: 1e-3,
            ema_factor: 0.95,
            num_updates: 0,
        }
    }

    /// Set damping factor
    pub fn with_damping(mut self, damping: f64) -> Self {
        self.damping = damping.max(EPS);
        self
    }

    /// Set EMA factor
    pub fn with_ema(mut self, ema: f64) -> Self {
        self.ema_factor = ema.clamp(0.0, 1.0);
        self
    }

    /// Update factors with a new batch of activations and gradients
    ///
    /// # Arguments
    /// * `activations` - Layer input activations a, shape [batch, input_dim]
    /// * `gradients` - Gradients of the loss w.r.t. the layer's pre-activation outputs g,
    ///   shape [batch, output_dim]
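    ///
    /// After the first batch, the factor estimates are blended with an exponential
    /// moving average: A ← ρ·A + (1 − ρ)·A_batch (and likewise for G), where ρ is
    /// the EMA factor.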
    pub fn update(
        &mut self,
        activations: &[Vec<f64>],
        gradients: &[Vec<f64>],
    ) -> Result<()> {
        if activations.is_empty() || gradients.is_empty() {
            return Err(MathError::empty_input("batch"));
        }

        let batch_size = activations.len();
        if gradients.len() != batch_size {
            return Err(MathError::dimension_mismatch(batch_size, gradients.len()));
        }

        let input_dim = self.a_factor.len();
        let output_dim = self.g_factor.len();

        // Compute A = E[aaᵀ]
        let mut new_a = vec![vec![0.0; input_dim]; input_dim];
        for act in activations {
            if act.len() != input_dim {
                return Err(MathError::dimension_mismatch(input_dim, act.len()));
            }
            for i in 0..input_dim {
                for j in 0..input_dim {
                    new_a[i][j] += act[i] * act[j] / batch_size as f64;
                }
            }
        }

        // Compute G = E[ggᵀ]
        let mut new_g = vec![vec![0.0; output_dim]; output_dim];
        for grad in gradients {
            if grad.len() != output_dim {
                return Err(MathError::dimension_mismatch(output_dim, grad.len()));
            }
            for i in 0..output_dim {
                for j in 0..output_dim {
                    new_g[i][j] += grad[i] * grad[j] / batch_size as f64;
                }
            }
        }

        // EMA update: the first batch initializes the estimates directly
        if self.num_updates == 0 {
            self.a_factor = new_a;
            self.g_factor = new_g;
        } else {
            for i in 0..input_dim {
                for j in 0..input_dim {
                    self.a_factor[i][j] = self.ema_factor * self.a_factor[i][j]
                        + (1.0 - self.ema_factor) * new_a[i][j];
                }
            }
            for i in 0..output_dim {
                for j in 0..output_dim {
                    self.g_factor[i][j] = self.ema_factor * self.g_factor[i][j]
                        + (1.0 - self.ema_factor) * new_g[i][j];
                }
            }
        }

        self.num_updates += 1;
        Ok(())
    }

    /// Compute the natural gradient for the layer's weight matrix
    ///
    /// nat_grad = G⁻¹ ∇W A⁻¹
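    ///
    /// Under the row-major vectorization of W this is (G ⊗ A)⁻¹ vec(∇W) reshaped
    /// back to a matrix, so the full Kronecker product is never materialized.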
    ///
    /// # Arguments
    /// * `weight_grad` - Gradient w.r.t. weights, shape [output_dim, input_dim]
    pub fn natural_gradient(&self, weight_grad: &[Vec<f64>]) -> Result<Vec<Vec<f64>>> {
        let output_dim = self.g_factor.len();
        let input_dim = self.a_factor.len();

        if weight_grad.len() != output_dim {
            return Err(MathError::dimension_mismatch(output_dim, weight_grad.len()));
        }
        for row in weight_grad {
            if row.len() != input_dim {
                return Err(MathError::dimension_mismatch(input_dim, row.len()));
            }
        }

        // Add damping to factors
        let a_damped = self.add_damping(&self.a_factor);
        let g_damped = self.add_damping(&self.g_factor);

        // Invert factors
        let a_inv = self.invert_matrix(&a_damped)?;
        let g_inv = self.invert_matrix(&g_damped)?;

        // Compute G⁻¹ ∇W A⁻¹
        // First: ∇W A⁻¹
        let mut grad_a_inv = vec![vec![0.0; input_dim]; output_dim];
        for i in 0..output_dim {
            for j in 0..input_dim {
                for k in 0..input_dim {
                    grad_a_inv[i][j] += weight_grad[i][k] * a_inv[k][j];
                }
            }
        }

        // Then: G⁻¹ (∇W A⁻¹)
        let mut nat_grad = vec![vec![0.0; input_dim]; output_dim];
        for i in 0..output_dim {
            for j in 0..input_dim {
                for k in 0..output_dim {
                    nat_grad[i][j] += g_inv[i][k] * grad_a_inv[k][j];
                }
            }
        }

        Ok(nat_grad)
    }

    /// Add damping to the diagonal of a factor matrix
    fn add_damping(&self, matrix: &[Vec<f64>]) -> Vec<Vec<f64>> {
        let n = matrix.len();
        let mut damped = matrix.to_vec();

        // Tikhonov damping scaled by the mean diagonal entry (trace / n), in the
        // spirit of K-FAC's π-damping heuristic, so its strength tracks the factor's scale
        let trace: f64 = (0..n).map(|i| matrix[i][i]).sum();
        let pi_damping = (self.damping * trace / n as f64).max(EPS);

        for i in 0..n {
            damped[i][i] += pi_damping;
        }

        damped
    }

    /// Invert a symmetric positive-definite matrix via Cholesky decomposition
    fn invert_matrix(&self, matrix: &[Vec<f64>]) -> Result<Vec<Vec<f64>>> {
        let n = matrix.len();

        // Cholesky: A = LLᵀ
        let mut l = vec![vec![0.0; n]; n];

        for i in 0..n {
            for j in 0..=i {
                let mut sum = matrix[i][j];
                for k in 0..j {
                    sum -= l[i][k] * l[j][k];
                }

                if i == j {
                    if sum <= 0.0 {
                        return Err(MathError::numerical_instability(
                            "Matrix not positive definite in K-FAC",
                        ));
                    }
                    l[i][j] = sum.sqrt();
                } else {
                    l[i][j] = sum / l[j][j];
                }
            }
        }

        // L⁻¹ via forward substitution
        let mut l_inv = vec![vec![0.0; n]; n];
        for i in 0..n {
            l_inv[i][i] = 1.0 / l[i][i];
            for j in (i + 1)..n {
                let mut sum = 0.0;
                for k in i..j {
                    sum -= l[j][k] * l_inv[k][i];
                }
                l_inv[j][i] = sum / l[j][j];
            }
        }

        // A⁻¹ = L⁻ᵀL⁻¹
        let mut inv = vec![vec![0.0; n]; n];
        for i in 0..n {
            for j in 0..n {
                for k in 0..n {
                    inv[i][j] += l_inv[k][i] * l_inv[k][j];
                }
            }
        }

        Ok(inv)
    }

    /// Reset factor estimates
    pub fn reset(&mut self) {
        let input_dim = self.a_factor.len();
        let output_dim = self.g_factor.len();

        self.a_factor = vec![vec![0.0; input_dim]; input_dim];
        self.g_factor = vec![vec![0.0; output_dim]; output_dim];
        self.num_updates = 0;
    }
}

/// K-FAC approximation for a full network
#[derive(Debug, Clone)]
pub struct KFACApproximation {
    /// Per-layer K-FAC factors
    layers: Vec<KFACLayer>,
    /// Learning rate
    learning_rate: f64,
    /// Global damping
    damping: f64,
}

impl KFACApproximation {
    /// Create K-FAC optimizer for a network
    ///
    /// # Arguments
    /// * `layer_dims` - List of (input_dim, output_dim) for each layer
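    ///
    /// # Example
    ///
    /// A minimal sketch of the per-layer update/precondition loop; the `use` path is
    /// an assumption based on this module's location.
    ///
    /// ```ignore
    /// use ruvector_math::information_geometry::kfac::KFACApproximation;
    ///
    /// // Two layers: 3 -> 2 and 2 -> 2.
    /// let mut kfac = KFACApproximation::new(&[(3, 2), (2, 2)])
    ///     .with_learning_rate(0.01)
    ///     .with_damping(1e-3);
    ///
    /// // Curvature statistics for layer 0: inputs a (shape [batch, 3]) and
    /// // output gradients g (shape [batch, 2]).
    /// let activations = vec![vec![1.0, 0.5, -0.5], vec![0.0, 1.0, 1.0]];
    /// let gradients = vec![vec![0.2, -0.1], vec![0.4, 0.3]];
    /// kfac.update_layer(0, &activations, &gradients).unwrap();
    ///
    /// // Learning-rate-scaled natural-gradient step for layer 0 (shape [2, 3]).
    /// let weight_grad = vec![vec![0.1, 0.0, 0.2], vec![0.3, 0.1, 0.0]];
    /// let step = kfac.natural_gradient_layer(0, &weight_grad).unwrap();
    /// assert_eq!(step.len(), 2);
    /// ```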
    pub fn new(layer_dims: &[(usize, usize)]) -> Self {
        let layers = layer_dims
            .iter()
            .map(|&(input, output)| KFACLayer::new(input, output))
            .collect();

        Self {
            layers,
            learning_rate: 0.01,
            damping: 1e-3,
        }
    }

    /// Set learning rate
    pub fn with_learning_rate(mut self, lr: f64) -> Self {
        self.learning_rate = lr.max(EPS);
        self
    }

    /// Set damping
    pub fn with_damping(mut self, damping: f64) -> Self {
        self.damping = damping.max(EPS);
        for layer in &mut self.layers {
            layer.damping = damping;
        }
        self
    }

    /// Update factors for a layer
    pub fn update_layer(
        &mut self,
        layer_idx: usize,
        activations: &[Vec<f64>],
        gradients: &[Vec<f64>],
    ) -> Result<()> {
        if layer_idx >= self.layers.len() {
            return Err(MathError::invalid_parameter(
                "layer_idx",
                "index out of bounds",
            ));
        }

        self.layers[layer_idx].update(activations, gradients)
    }

    /// Compute the natural-gradient update for a layer's weights, scaled by
    /// −learning_rate (i.e. the step to add to the weights)
    pub fn natural_gradient_layer(
        &self,
        layer_idx: usize,
        weight_grad: &[Vec<f64>],
    ) -> Result<Vec<Vec<f64>>> {
        if layer_idx >= self.layers.len() {
            return Err(MathError::invalid_parameter(
                "layer_idx",
                "index out of bounds",
            ));
        }

        let mut nat_grad = self.layers[layer_idx].natural_gradient(weight_grad)?;

        // Scale by -learning_rate so the result is a descent step
        for row in &mut nat_grad {
            for val in row {
                *val *= -self.learning_rate;
            }
        }

        Ok(nat_grad)
    }

    /// Get number of layers
    pub fn num_layers(&self) -> usize {
        self.layers.len()
    }

    /// Reset all layer estimates
    pub fn reset(&mut self) {
        for layer in &mut self.layers {
            layer.reset();
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_kfac_layer_creation() {
        let layer = KFACLayer::new(10, 5);

        assert_eq!(layer.a_factor.len(), 10);
        assert_eq!(layer.g_factor.len(), 5);
    }

    #[test]
    fn test_kfac_layer_update() {
        let mut layer = KFACLayer::new(3, 2);

        let activations = vec![vec![1.0, 0.0, 1.0], vec![0.0, 1.0, 1.0]];

        let gradients = vec![vec![0.5, 0.5], vec![0.3, 0.7]];

        layer.update(&activations, &gradients).unwrap();

        // Factors should be updated
        assert!(layer.a_factor[0][0] > 0.0);
        assert!(layer.g_factor[0][0] > 0.0);
    }

    #[test]
    fn test_kfac_natural_gradient() {
        let mut layer = KFACLayer::new(2, 2).with_damping(0.1);

        // Initialize with identity-like factors
        let activations = vec![vec![1.0, 0.0], vec![0.0, 1.0]];
        let gradients = vec![vec![1.0, 0.0], vec![0.0, 1.0]];

        layer.update(&activations, &gradients).unwrap();

        let weight_grad = vec![vec![0.1, 0.2], vec![0.3, 0.4]];

        let nat_grad = layer.natural_gradient(&weight_grad).unwrap();

        assert_eq!(nat_grad.len(), 2);
        assert_eq!(nat_grad[0].len(), 2);
    }
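
    // A sanity check added as an illustration: the orthonormal batch below gives
    // A = G = 0.5·I; with damping 0.1 the damped factors are 0.55·I, so the natural
    // gradient should equal the raw gradient scaled by 1/0.55².
    #[test]
    fn test_kfac_natural_gradient_identity_scaling() {
        let mut layer = KFACLayer::new(2, 2).with_damping(0.1);

        let activations = vec![vec![1.0, 0.0], vec![0.0, 1.0]];
        let gradients = vec![vec![1.0, 0.0], vec![0.0, 1.0]];
        layer.update(&activations, &gradients).unwrap();

        let weight_grad = vec![vec![0.1, 0.2], vec![0.3, 0.4]];
        let nat_grad = layer.natural_gradient(&weight_grad).unwrap();

        let scale = 1.0 / (0.55 * 0.55);
        for i in 0..2 {
            for j in 0..2 {
                assert!((nat_grad[i][j] - scale * weight_grad[i][j]).abs() < 1e-6);
            }
        }
    }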

    #[test]
    fn test_kfac_full_network() {
        let kfac = KFACApproximation::new(&[(10, 20), (20, 5)])
            .with_learning_rate(0.01)
            .with_damping(0.001);

        assert_eq!(kfac.num_layers(), 2);
    }
}