// ruvector_attention/info_geometry/natural_gradient.rs

//! Natural Gradient Descent
//!
//! Update parameters using the natural gradient: F^{-1} * grad
//! where F is the Fisher information matrix.

use super::fisher::{FisherConfig, FisherMetric};
use serde::{Deserialize, Serialize};

9/// Natural gradient configuration
10#[derive(Debug, Clone, Serialize, Deserialize)]
11pub struct NaturalGradientConfig {
12    /// Learning rate
13    pub lr: f32,
14    /// Fisher metric config
15    pub fisher: FisherConfig,
16    /// Use diagonal approximation (faster but less accurate)
17    pub use_diagonal: bool,
18}
19
20impl Default for NaturalGradientConfig {
21    fn default() -> Self {
22        Self {
23            lr: 0.1,
24            fisher: FisherConfig::default(),
25            use_diagonal: false,
26        }
27    }
28}
29
30/// Natural gradient optimizer
31#[derive(Debug, Clone)]
32pub struct NaturalGradient {
33    config: NaturalGradientConfig,
34    fisher: FisherMetric,
35}
36
37impl NaturalGradient {
38    /// Create new natural gradient optimizer
39    pub fn new(config: NaturalGradientConfig) -> Self {
40        let fisher = FisherMetric::new(config.fisher.clone());
41        Self { config, fisher }
42    }
43
44    /// Compute natural gradient step for logits
45    /// Returns updated logits
46    pub fn step_logits(&self, logits: &[f32], grad_logits: &[f32]) -> Vec<f32> {
47        let probs = Self::softmax(logits);
48
49        // Compute natural gradient direction
50        let nat_grad = if self.config.use_diagonal {
51            self.fisher.apply_inverse_approx(&probs, grad_logits)
52        } else {
53            self.fisher.solve_cg(&probs, grad_logits)
54        };
55
56        // Update logits
57        let mut new_logits = logits.to_vec();
58        for i in 0..new_logits.len() {
59            new_logits[i] -= self.config.lr * nat_grad[i];
60        }
61
62        new_logits
63    }
64
65    /// Compute natural gradient step for general parameters with diagonal Fisher
66    /// Fisher diag should be pre-computed from data
67    pub fn step_diagonal(&self, params: &[f32], grads: &[f32], fisher_diag: &[f32]) -> Vec<f32> {
68        let n = params.len();
69        let mut new_params = params.to_vec();
70        let eps = self.config.fisher.eps;
71
72        for i in 0..n {
73            let f_inv = 1.0 / (fisher_diag[i].abs() + eps);
74            new_params[i] -= self.config.lr * grads[i] * f_inv;
75        }
76
77        new_params
78    }
79
80    /// Compute natural gradient for attention logits
81    /// Uses the Fisher metric on the output probability distribution
82    pub fn step_attention_logits(&self, logits: &[f32], grad_logits: &[f32]) -> Vec<f32> {
83        self.step_logits(logits, grad_logits)
84    }
85
86    /// Stable softmax
87    fn softmax(logits: &[f32]) -> Vec<f32> {
88        if logits.is_empty() {
89            return vec![];
90        }
91
92        let max_logit = logits.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
93        let exp_logits: Vec<f32> = logits.iter().map(|&l| (l - max_logit).exp()).collect();
94        let sum: f32 = exp_logits.iter().sum();
95
96        if sum > 0.0 {
97            exp_logits.iter().map(|&e| e / sum).collect()
98        } else {
99            vec![1.0 / logits.len() as f32; logits.len()]
100        }
101    }
102}
103
#[cfg(test)]
mod tests {
    use super::*;

    /// A full-metric step must actually move the logits.
    #[test]
    fn test_natural_gradient_step() {
        let ng = NaturalGradient::new(NaturalGradientConfig {
            lr: 0.1,
            ..Default::default()
        });

        let logits = [1.0, 2.0, 0.5, 0.5];
        let grads = [0.1, -0.1, 0.05, -0.05];

        let updated = ng.step_logits(&logits, &grads);

        assert_eq!(updated.len(), 4);
        // At least one of the first two coordinates should have changed.
        assert!(
            (updated[0] - logits[0]).abs() > 1e-6 || (updated[1] - logits[1]).abs() > 1e-6
        );
    }

    /// With equal gradients, a larger Fisher entry must yield a smaller step.
    #[test]
    fn test_diagonal_step() {
        let ng = NaturalGradient::new(NaturalGradientConfig::default());

        let params = [1.0, 2.0, 3.0];
        let grads = [0.1, 0.1, 0.1]; // identical gradient in every coordinate
        let fisher_diag = [1.0, 2.0, 0.5]; // differing curvature estimates

        let updated = ng.step_diagonal(&params, &grads, &fisher_diag);
        assert_eq!(updated.len(), 3);

        let steps: Vec<f32> = params
            .iter()
            .zip(&updated)
            .map(|(p, u)| (u - p).abs())
            .collect();
        // diag[1] > diag[0] > diag[2]  =>  |step1| < |step0| < |step2|.
        assert!(steps[1] < steps[0]);
        assert!(steps[0] < steps[2]);
    }

    /// The attention wrapper preserves the logit count.
    #[test]
    fn test_attention_logits_step() {
        let ng = NaturalGradient::new(NaturalGradientConfig::default());

        let logits = vec![0.0; 10];
        let grads = vec![0.1; 10];

        let updated = ng.step_attention_logits(&logits, &grads);

        assert_eq!(updated.len(), 10);
    }
}