ruvector_math/information_geometry/natural_gradient.rs

//! Natural Gradient Descent
//!
//! Natural gradient descent rescales gradient updates to account for the
//! curvature of the parameter space, leading to faster convergence.
//!
//! ## Algorithm
//!
//! θ_{t+1} = θ_t - η F(θ_t)⁻¹ ∇L(θ_t)
//!
//! where F is the Fisher Information Matrix and η is the learning rate.
//!
//! ## Benefits
//!
//! - **Invariant to reparameterization**: the update direction does not depend on how the model is parameterized
//! - **Faster convergence**: often 3-5x fewer iterations than SGD/Adam, especially on ill-conditioned problems
//! - **Better generalization**: updates follow the steepest-descent direction in distribution space (Fisher metric) rather than in raw parameter space
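//!
//! ## Example
//!
//! Illustrative sketch of a training step (the import path below is assumed from
//! this file's location; adjust to your crate layout):
//!
//! ```ignore
//! use ruvector_math::information_geometry::natural_gradient::NaturalGradient;
//!
//! let mut optimizer = NaturalGradient::new(0.05).with_diagonal(true);
//! let mut params = vec![0.5, -0.3];
//!
//! // Per-sample gradients are used to estimate the (diagonal) FIM.
//! let gradient = vec![0.2, -0.1];
//! let samples = vec![vec![0.25, -0.05], vec![0.15, -0.15]];
//!
//! let update_norm = optimizer
//!     .optimize_step(&mut params, &gradient, Some(&samples))
//!     .unwrap();
//! assert!(update_norm >= 0.0);
//! ```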

use crate::error::{MathError, Result};
use crate::utils::EPS;
use super::FisherInformation;

/// Natural gradient optimizer state
#[derive(Debug, Clone)]
pub struct NaturalGradient {
    /// Learning rate
    learning_rate: f64,
    /// Damping factor for FIM
    damping: f64,
    /// Whether to use diagonal approximation
    use_diagonal: bool,
    /// Exponential moving average factor for FIM
    ema_factor: f64,
    /// Running FIM estimate
    fim_estimate: Option<FimEstimate>,
}

#[derive(Debug, Clone)]
enum FimEstimate {
    Full(Vec<Vec<f64>>),
    Diagonal(Vec<f64>),
}

impl NaturalGradient {
    /// Create a new natural gradient optimizer
    ///
    /// # Arguments
    /// * `learning_rate` - Step size (0.01-0.1 typical)
    pub fn new(learning_rate: f64) -> Self {
        Self {
            learning_rate: learning_rate.max(EPS),
            damping: 1e-4,
            use_diagonal: false,
            ema_factor: 0.9,
            fim_estimate: None,
        }
    }

    /// Set damping factor
    pub fn with_damping(mut self, damping: f64) -> Self {
        self.damping = damping.max(EPS);
        self
    }

    /// Use diagonal FIM approximation (faster, less memory)
    pub fn with_diagonal(mut self, use_diagonal: bool) -> Self {
        self.use_diagonal = use_diagonal;
        self
    }

    /// Set EMA factor for FIM smoothing
    pub fn with_ema(mut self, ema: f64) -> Self {
        self.ema_factor = ema.clamp(0.0, 1.0);
        self
    }

    /// Compute natural gradient step
    ///
    /// # Arguments
    /// * `gradient` - Standard gradient ∇L
    /// * `gradient_samples` - Optional gradient samples for FIM estimation
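    ///
    /// Returns the update -η F̃⁻¹ ∇L (not yet applied to the parameters), where F̃ is the
    /// current damped FIM estimate. A minimal sketch of the fallback behavior when no FIM
    /// estimate is available yet:
    ///
    /// ```ignore
    /// let mut ng = NaturalGradient::new(0.1);
    /// // With no FIM estimate, the step reduces to plain -lr * gradient.
    /// let update = ng.step(&[1.0, 2.0], None).unwrap();
    /// assert!((update[0] + 0.1).abs() < 1e-12);
    /// assert!((update[1] + 0.2).abs() < 1e-12);
    /// ```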
    pub fn step(
        &mut self,
        gradient: &[f64],
        gradient_samples: Option<&[Vec<f64>]>,
    ) -> Result<Vec<f64>> {
        // Update FIM estimate if samples provided
        if let Some(samples) = gradient_samples {
            self.update_fim(samples)?;
        }

        // Compute natural gradient
        let nat_grad = match &self.fim_estimate {
            Some(FimEstimate::Full(fim)) => {
                let fisher = FisherInformation::new().with_damping(self.damping);
                fisher.natural_gradient(fim, gradient)?
            }
            Some(FimEstimate::Diagonal(diag)) => {
                // Element-wise: nat_grad = grad / (diag + damping)
                gradient
                    .iter()
                    .zip(diag.iter())
                    .map(|(&g, &d)| g / (d + self.damping))
                    .collect()
            }
            None => {
                // No FIM estimate, use gradient as-is
                gradient.to_vec()
            }
        };

        // Scale by the learning rate and negate for a descent step
        Ok(nat_grad.iter().map(|&g| -self.learning_rate * g).collect())
    }

    /// Update running FIM estimate
    fn update_fim(&mut self, gradient_samples: &[Vec<f64>]) -> Result<()> {
        let fisher = FisherInformation::new().with_damping(0.0);

        if self.use_diagonal {
            let new_diag = fisher.diagonal_fim(gradient_samples)?;

            self.fim_estimate = Some(FimEstimate::Diagonal(match &self.fim_estimate {
                Some(FimEstimate::Diagonal(old)) => {
                    // EMA update
                    old.iter()
                        .zip(new_diag.iter())
                        .map(|(&o, &n)| self.ema_factor * o + (1.0 - self.ema_factor) * n)
                        .collect()
                }
                _ => new_diag,
            }));
        } else {
            let new_fim = fisher.empirical_fim(gradient_samples)?;
            let dim = new_fim.len();

            self.fim_estimate = Some(FimEstimate::Full(match &self.fim_estimate {
                Some(FimEstimate::Full(old)) if old.len() == dim => {
                    // EMA update
                    (0..dim)
                        .map(|i| {
                            (0..dim)
                                .map(|j| {
                                    self.ema_factor * old[i][j]
                                        + (1.0 - self.ema_factor) * new_fim[i][j]
                                })
                                .collect()
                        })
                        .collect()
                }
                _ => new_fim,
            }));
        }

        Ok(())
    }

    /// Apply update to parameters
    pub fn apply_update(parameters: &mut [f64], update: &[f64]) -> Result<()> {
        if parameters.len() != update.len() {
            return Err(MathError::dimension_mismatch(parameters.len(), update.len()));
        }

        for (p, &u) in parameters.iter_mut().zip(update.iter()) {
            *p += u;
        }

        Ok(())
    }

    /// Full optimization step: compute and apply update
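    ///
    /// Returns the Euclidean norm of the applied update, which can serve as a simple
    /// convergence signal. A minimal sketch:
    ///
    /// ```ignore
    /// let mut ng = NaturalGradient::new(0.1).with_diagonal(true);
    /// let mut params = vec![1.0, 1.0];
    /// // No gradient samples: falls back to a plain scaled-gradient step.
    /// let norm = ng.optimize_step(&mut params, &[0.5, -0.5], None).unwrap();
    /// assert!(norm > 0.0 && params[0] < 1.0);
    /// ```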
    pub fn optimize_step(
        &mut self,
        parameters: &mut [f64],
        gradient: &[f64],
        gradient_samples: Option<&[Vec<f64>]>,
    ) -> Result<f64> {
        let update = self.step(gradient, gradient_samples)?;

        let update_norm: f64 = update.iter().map(|&u| u * u).sum::<f64>().sqrt();

        Self::apply_update(parameters, &update)?;

        Ok(update_norm)
    }

    /// Reset optimizer state
    pub fn reset(&mut self) {
        self.fim_estimate = None;
    }
}

/// Natural gradient with diagonal preconditioning (AdaGrad-like)
#[derive(Debug, Clone)]
pub struct DiagonalNaturalGradient {
    /// Learning rate
    learning_rate: f64,
    /// Damping factor
    damping: f64,
    /// Accumulated squared gradients
    accumulator: Vec<f64>,
}

impl DiagonalNaturalGradient {
    /// Create new diagonal natural gradient optimizer
    pub fn new(learning_rate: f64, dim: usize) -> Self {
        Self {
            learning_rate: learning_rate.max(EPS),
            damping: 1e-8,
            accumulator: vec![0.0; dim],
        }
    }

    /// Set damping factor
    pub fn with_damping(mut self, damping: f64) -> Self {
        self.damping = damping.max(EPS);
        self
    }

    /// Compute and apply update
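    ///
    /// A minimal sketch; on the first call the accumulator equals g², so each
    /// coordinate moves by roughly the learning rate in magnitude:
    ///
    /// ```ignore
    /// let mut opt = DiagonalNaturalGradient::new(0.1, 2);
    /// let mut params = vec![1.0, -1.0];
    /// let norm = opt.step(&mut params, &[0.5, -0.5]).unwrap();
    /// assert!((params[0] - 0.9).abs() < 1e-6);
    /// assert!(norm > 0.0);
    /// ```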
    pub fn step(&mut self, parameters: &mut [f64], gradient: &[f64]) -> Result<f64> {
        if parameters.len() != gradient.len() || parameters.len() != self.accumulator.len() {
            return Err(MathError::dimension_mismatch(
                parameters.len(),
                gradient.len(),
            ));
        }

        let mut update_norm_sq = 0.0;

        for (i, (p, &g)) in parameters.iter_mut().zip(gradient.iter()).enumerate() {
            // Accumulate squared gradient (Fisher diagonal approximation)
            self.accumulator[i] += g * g;

            // Natural gradient step
            let update = -self.learning_rate * g / (self.accumulator[i].sqrt() + self.damping);
            *p += update;
            update_norm_sq += update * update;
        }

        Ok(update_norm_sq.sqrt())
    }

    /// Reset accumulator
    pub fn reset(&mut self) {
        self.accumulator.fill(0.0);
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_natural_gradient_step() {
        let mut ng = NaturalGradient::new(0.1).with_diagonal(true);

        let gradient = vec![1.0, 2.0, 3.0];

        // First step without FIM estimate uses gradient directly
        let update = ng.step(&gradient, None).unwrap();

        assert_eq!(update.len(), 3);
        // Should be -lr * gradient
        assert!((update[0] + 0.1).abs() < 1e-10);
    }

    #[test]
    fn test_natural_gradient_with_fim() {
        let mut ng = NaturalGradient::new(0.1).with_diagonal(true).with_damping(0.0);

        let gradient = vec![2.0, 4.0];

        // Provide gradient samples for FIM estimation
        let samples = vec![
            vec![1.0, 0.0],
            vec![0.0, 1.0],
            vec![1.0, 1.0],
        ];

        let update = ng.step(&gradient, Some(&samples)).unwrap();

        // With FIM, update should be preconditioned
        assert_eq!(update.len(), 2);
    }

    #[test]
    fn test_diagonal_natural_gradient() {
        let mut dng = DiagonalNaturalGradient::new(1.0, 2);

        let mut params = vec![0.0, 0.0];
        let gradient = vec![1.0, 2.0];

        let norm = dng.step(&mut params, &gradient).unwrap();

        assert!(norm > 0.0);
        // Parameters should have moved
        assert!(params[0] < 0.0); // Moved in negative gradient direction
    }

    #[test]
    fn test_optimizer_reset() {
        let mut ng = NaturalGradient::new(0.1);

        let samples = vec![vec![1.0, 2.0]];
        let _ = ng.step(&[1.0, 1.0], Some(&samples));

        ng.reset();
        assert!(ng.fim_estimate.is_none());
    }
}