ruvector_attention/info_geometry/natural_gradient.rs

//! Natural Gradient Descent
//!
//! Update parameters using the natural gradient: F^{-1} * grad
//! where F is the Fisher information matrix.
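//!
//! # Example
//!
//! A minimal usage sketch. The import path below assumes this module is
//! exposed as `info_geometry::natural_gradient` from the crate root, so the
//! snippet is marked `ignore` rather than run as a doctest:
//!
//! ```ignore
//! use ruvector_attention::info_geometry::natural_gradient::{
//!     NaturalGradient, NaturalGradientConfig,
//! };
//!
//! let ng = NaturalGradient::new(NaturalGradientConfig::default());
//!
//! let logits = vec![1.0_f32, 2.0, 0.5, 0.5];
//! let grad_logits = vec![0.1_f32, -0.1, 0.05, -0.05];
//!
//! // One preconditioned update: logits - lr * F^{-1} * grad
//! let updated = ng.step_logits(&logits, &grad_logits);
//! assert_eq!(updated.len(), logits.len());
//! ```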

use super::fisher::{FisherMetric, FisherConfig};
use serde::{Deserialize, Serialize};

/// Natural gradient configuration
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NaturalGradientConfig {
    /// Learning rate
    pub lr: f32,
    /// Fisher metric config
    pub fisher: FisherConfig,
    /// Use diagonal approximation (faster but less accurate)
    pub use_diagonal: bool,
}

impl Default for NaturalGradientConfig {
    fn default() -> Self {
        Self {
            lr: 0.1,
            fisher: FisherConfig::default(),
            use_diagonal: false,
        }
    }
}

/// Natural gradient optimizer
#[derive(Debug, Clone)]
pub struct NaturalGradient {
    config: NaturalGradientConfig,
    fisher: FisherMetric,
}

impl NaturalGradient {
    /// Create new natural gradient optimizer
    pub fn new(config: NaturalGradientConfig) -> Self {
        let fisher = FisherMetric::new(config.fisher.clone());
        Self { config, fisher }
    }

    /// Compute natural gradient step for logits
    /// Returns updated logits
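    ///
    /// Internally this computes `p = softmax(logits)` and preconditions the
    /// gradient with the inverse Fisher metric of that distribution:
    /// `use_diagonal` routes through the cheaper
    /// `FisherMetric::apply_inverse_approx`, otherwise `F * x = grad` is
    /// solved with `FisherMetric::solve_cg`.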
    pub fn step_logits(&self, logits: &[f32], grad_logits: &[f32]) -> Vec<f32> {
        let probs = Self::softmax(logits);

        // Compute natural gradient direction
        let nat_grad = if self.config.use_diagonal {
            self.fisher.apply_inverse_approx(&probs, grad_logits)
        } else {
            self.fisher.solve_cg(&probs, grad_logits)
        };

        // Update logits
        let mut new_logits = logits.to_vec();
        for i in 0..new_logits.len() {
            new_logits[i] -= self.config.lr * nat_grad[i];
        }

        new_logits
    }

    /// Compute natural gradient step for general parameters with diagonal Fisher
    /// Fisher diag should be pre-computed from data
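    ///
    /// Each parameter is updated as
    /// `params[i] -= lr * grads[i] / (|fisher_diag[i]| + eps)`,
    /// so coordinates with larger Fisher values take smaller steps.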
    pub fn step_diagonal(
        &self,
        params: &[f32],
        grads: &[f32],
        fisher_diag: &[f32],
    ) -> Vec<f32> {
        let n = params.len();
        let mut new_params = params.to_vec();
        let eps = self.config.fisher.eps;

        for i in 0..n {
            let f_inv = 1.0 / (fisher_diag[i].abs() + eps);
            new_params[i] -= self.config.lr * grads[i] * f_inv;
        }

        new_params
    }

    /// Compute natural gradient for attention logits
    /// Uses the Fisher metric on the output probability distribution
    pub fn step_attention_logits(
        &self,
        logits: &[f32],
        grad_logits: &[f32],
    ) -> Vec<f32> {
        self.step_logits(logits, grad_logits)
    }

    /// Stable softmax
    fn softmax(logits: &[f32]) -> Vec<f32> {
        if logits.is_empty() {
            return vec![];
        }

        let max_logit = logits.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
        let exp_logits: Vec<f32> = logits.iter().map(|&l| (l - max_logit).exp()).collect();
        let sum: f32 = exp_logits.iter().sum();

        if sum > 0.0 {
            exp_logits.iter().map(|&e| e / sum).collect()
        } else {
            vec![1.0 / logits.len() as f32; logits.len()]
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_natural_gradient_step() {
        let config = NaturalGradientConfig {
            lr: 0.1,
            ..Default::default()
        };
        let ng = NaturalGradient::new(config);

        let logits = vec![1.0, 2.0, 0.5, 0.5];
        let grads = vec![0.1, -0.1, 0.05, -0.05];

        let new_logits = ng.step_logits(&logits, &grads);

        assert_eq!(new_logits.len(), 4);
        // Should be different from original
        assert!((new_logits[0] - logits[0]).abs() > 1e-6 ||
                (new_logits[1] - logits[1]).abs() > 1e-6);
    }

    #[test]
    fn test_diagonal_step() {
        let ng = NaturalGradient::new(NaturalGradientConfig::default());

        let params = vec![1.0, 2.0, 3.0];
        let grads = vec![0.1, 0.1, 0.1]; // Equal gradients
        let fisher_diag = vec![1.0, 2.0, 0.5]; // Different Fisher values

        let new_params = ng.step_diagonal(&params, &grads, &fisher_diag);

        assert_eq!(new_params.len(), 3);
        // Larger Fisher = smaller step (with equal gradients)
        let step0 = (new_params[0] - params[0]).abs();
        let step1 = (new_params[1] - params[1]).abs();
        let step2 = (new_params[2] - params[2]).abs();
        // Fisher[1] > Fisher[0] > Fisher[2], so step1 < step0 < step2
        assert!(step1 < step0);
        assert!(step0 < step2);
    }

    #[test]
    fn test_attention_logits_step() {
        let ng = NaturalGradient::new(NaturalGradientConfig::default());

        let logits = vec![0.0; 10];
        let grads = vec![0.1; 10];

        let new_logits = ng.step_attention_logits(&logits, &grads);

        assert_eq!(new_logits.len(), 10);
    }
}