Skip to main content

ruvector_math/information_geometry/
natural_gradient.rs

1//! Natural Gradient Descent
2//!
3//! Natural gradient descent rescales gradient updates to account for the
4//! curvature of the parameter space, leading to faster convergence.
5//!
6//! ## Algorithm
7//!
8//! θ_{t+1} = θ_t - η F(θ_t)⁻¹ ∇L(θ_t)
9//!
10//! where F is the Fisher Information Matrix.
11//!
12//! ## Benefits
13//!
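//! - **Invariant to reparameterization**: to first order (in the small-step limit) the update does not depend on how the model is parameterized
//! - **Faster convergence on ill-conditioned problems**: rescaling by the inverse FIM evens out poorly-scaled directions, often reducing iteration counts compared with plain SGD
//! - **Geometry-aware steps**: each update is steepest descent with respect to the Fisher (KL-divergence) metric rather than the Euclidean metric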
17
18use super::FisherInformation;
19use crate::error::{MathError, Result};
20use crate::utils::EPS;
21
22/// Natural gradient optimizer state
23#[derive(Debug, Clone)]
24pub struct NaturalGradient {
25    /// Learning rate
26    learning_rate: f64,
27    /// Damping factor for FIM
28    damping: f64,
29    /// Whether to use diagonal approximation
30    use_diagonal: bool,
31    /// Exponential moving average factor for FIM
32    ema_factor: f64,
33    /// Running FIM estimate
34    fim_estimate: Option<FimEstimate>,
35}
36
/// Stored Fisher Information Matrix estimate, kept in dense or diagonal form.
#[derive(Debug, Clone)]
enum FimEstimate {
    /// Dense matrix stored as a vector of rows.
    Full(Vec<Vec<f64>>),
    /// Only the diagonal entries of the matrix.
    Diagonal(Vec<f64>),
}
42
43impl NaturalGradient {
44    /// Create a new natural gradient optimizer
45    ///
46    /// # Arguments
47    /// * `learning_rate` - Step size (0.01-0.1 typical)
48    pub fn new(learning_rate: f64) -> Self {
49        Self {
50            learning_rate: learning_rate.max(EPS),
51            damping: 1e-4,
52            use_diagonal: false,
53            ema_factor: 0.9,
54            fim_estimate: None,
55        }
56    }
57
58    /// Set damping factor
59    pub fn with_damping(mut self, damping: f64) -> Self {
60        self.damping = damping.max(EPS);
61        self
62    }
63
64    /// Use diagonal FIM approximation (faster, less memory)
65    pub fn with_diagonal(mut self, use_diagonal: bool) -> Self {
66        self.use_diagonal = use_diagonal;
67        self
68    }
69
70    /// Set EMA factor for FIM smoothing
71    pub fn with_ema(mut self, ema: f64) -> Self {
72        self.ema_factor = ema.clamp(0.0, 1.0);
73        self
74    }
75
76    /// Compute natural gradient step
77    ///
78    /// # Arguments
79    /// * `gradient` - Standard gradient ∇L
80    /// * `gradient_samples` - Optional gradient samples for FIM estimation
81    pub fn step(
82        &mut self,
83        gradient: &[f64],
84        gradient_samples: Option<&[Vec<f64>]>,
85    ) -> Result<Vec<f64>> {
86        // Update FIM estimate if samples provided
87        if let Some(samples) = gradient_samples {
88            self.update_fim(samples)?;
89        }
90
91        // Compute natural gradient
92        let nat_grad = match &self.fim_estimate {
93            Some(FimEstimate::Full(fim)) => {
94                let fisher = FisherInformation::new().with_damping(self.damping);
95                fisher.natural_gradient(fim, gradient)?
96            }
97            Some(FimEstimate::Diagonal(diag)) => {
98                // Element-wise: nat_grad = grad / diag
99                gradient
100                    .iter()
101                    .zip(diag.iter())
102                    .map(|(&g, &d)| g / (d + self.damping))
103                    .collect()
104            }
105            None => {
106                // No FIM estimate, use gradient as-is
107                gradient.to_vec()
108            }
109        };
110
111        // Scale by learning rate
112        Ok(nat_grad.iter().map(|&g| -self.learning_rate * g).collect())
113    }
114
115    /// Update running FIM estimate
116    fn update_fim(&mut self, gradient_samples: &[Vec<f64>]) -> Result<()> {
117        let fisher = FisherInformation::new().with_damping(0.0);
118
119        if self.use_diagonal {
120            let new_diag = fisher.diagonal_fim(gradient_samples)?;
121
122            self.fim_estimate = Some(FimEstimate::Diagonal(match &self.fim_estimate {
123                Some(FimEstimate::Diagonal(old)) => {
124                    // EMA update
125                    old.iter()
126                        .zip(new_diag.iter())
127                        .map(|(&o, &n)| self.ema_factor * o + (1.0 - self.ema_factor) * n)
128                        .collect()
129                }
130                _ => new_diag,
131            }));
132        } else {
133            let new_fim = fisher.empirical_fim(gradient_samples)?;
134            let dim = new_fim.len();
135
136            self.fim_estimate = Some(FimEstimate::Full(match &self.fim_estimate {
137                Some(FimEstimate::Full(old)) if old.len() == dim => {
138                    // EMA update
139                    (0..dim)
140                        .map(|i| {
141                            (0..dim)
142                                .map(|j| {
143                                    self.ema_factor * old[i][j]
144                                        + (1.0 - self.ema_factor) * new_fim[i][j]
145                                })
146                                .collect()
147                        })
148                        .collect()
149                }
150                _ => new_fim,
151            }));
152        }
153
154        Ok(())
155    }
156
157    /// Apply update to parameters
158    pub fn apply_update(parameters: &mut [f64], update: &[f64]) -> Result<()> {
159        if parameters.len() != update.len() {
160            return Err(MathError::dimension_mismatch(
161                parameters.len(),
162                update.len(),
163            ));
164        }
165
166        for (p, &u) in parameters.iter_mut().zip(update.iter()) {
167            *p += u;
168        }
169
170        Ok(())
171    }
172
173    /// Full optimization step: compute and apply update
174    pub fn optimize_step(
175        &mut self,
176        parameters: &mut [f64],
177        gradient: &[f64],
178        gradient_samples: Option<&[Vec<f64>]>,
179    ) -> Result<f64> {
180        let update = self.step(gradient, gradient_samples)?;
181
182        let update_norm: f64 = update.iter().map(|&u| u * u).sum::<f64>().sqrt();
183
184        Self::apply_update(parameters, &update)?;
185
186        Ok(update_norm)
187    }
188
189    /// Reset optimizer state
190    pub fn reset(&mut self) {
191        self.fim_estimate = None;
192    }
193}
194
/// Natural gradient with diagonal preconditioning, AdaGrad-style: each coordinate's
/// step is divided by the square root of its accumulated squared gradients, which act
/// as a running diagonal Fisher approximation.
#[derive(Debug, Clone)]
pub struct DiagonalNaturalGradient {
    /// Step size.
    learning_rate: f64,
    /// Added to the square-root denominator to keep the division bounded.
    damping: f64,
    /// Running sum of squared gradients, one entry per parameter.
    accumulator: Vec<f64>,
}
205
206impl DiagonalNaturalGradient {
207    /// Create new diagonal natural gradient optimizer
208    pub fn new(learning_rate: f64, dim: usize) -> Self {
209        Self {
210            learning_rate: learning_rate.max(EPS),
211            damping: 1e-8,
212            accumulator: vec![0.0; dim],
213        }
214    }
215
216    /// Set damping factor
217    pub fn with_damping(mut self, damping: f64) -> Self {
218        self.damping = damping.max(EPS);
219        self
220    }
221
222    /// Compute and apply update
223    pub fn step(&mut self, parameters: &mut [f64], gradient: &[f64]) -> Result<f64> {
224        if parameters.len() != gradient.len() || parameters.len() != self.accumulator.len() {
225            return Err(MathError::dimension_mismatch(
226                parameters.len(),
227                gradient.len(),
228            ));
229        }
230
231        let mut update_norm_sq = 0.0;
232
233        for (i, (p, &g)) in parameters.iter_mut().zip(gradient.iter()).enumerate() {
234            // Accumulate squared gradient (Fisher diagonal approximation)
235            self.accumulator[i] += g * g;
236
237            // Natural gradient step
238            let update = -self.learning_rate * g / (self.accumulator[i].sqrt() + self.damping);
239            *p += update;
240            update_norm_sq += update * update;
241        }
242
243        Ok(update_norm_sq.sqrt())
244    }
245
246    /// Reset accumulator
247    pub fn reset(&mut self) {
248        self.accumulator.fill(0.0);
249    }
250}
251
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_natural_gradient_step() {
        let mut ng = NaturalGradient::new(0.1).with_diagonal(true);

        let gradient = vec![1.0, 2.0, 3.0];

        // Without a FIM estimate the raw gradient is used, scaled by -lr.
        let update = ng.step(&gradient, None).unwrap();

        assert_eq!(update.len(), gradient.len());
        for (u, g) in update.iter().zip(&gradient) {
            assert!((u + 0.1 * g).abs() < 1e-10);
        }
        assert!(ng.fim_estimate.is_none());
    }

    #[test]
    fn test_natural_gradient_with_fim() {
        let mut ng = NaturalGradient::new(0.1)
            .with_diagonal(true)
            .with_damping(0.0);

        let gradient = vec![2.0, 4.0];

        // Provide gradient samples for FIM estimation.
        let samples = vec![vec![1.0, 0.0], vec![0.0, 1.0], vec![1.0, 1.0]];

        let update = ng.step(&gradient, Some(&samples)).unwrap();
        assert_eq!(update.len(), 2);

        // Diagonal mode stores a diagonal estimate, and the update must be the
        // gradient divided element-wise by (diag + damping), scaled by -lr.
        let diag = match &ng.fim_estimate {
            Some(FimEstimate::Diagonal(d)) => d.clone(),
            other => panic!("expected a diagonal FIM estimate, got {:?}", other),
        };
        assert_eq!(diag.len(), 2);
        for ((u, g), d) in update.iter().zip(&gradient).zip(&diag) {
            let expected = -ng.learning_rate * (g / (d + ng.damping));
            assert!((u - expected).abs() <= 1e-12 * expected.abs().max(1.0));
        }
    }

    #[test]
    fn test_with_ema_clamps() {
        assert_eq!(NaturalGradient::new(0.1).with_ema(2.0).ema_factor, 1.0);
        assert_eq!(NaturalGradient::new(0.1).with_ema(-1.0).ema_factor, 0.0);
    }

    #[test]
    fn test_apply_update() {
        let mut params = vec![1.0, 2.0];

        // Mismatched lengths are rejected without touching the parameters.
        assert!(NaturalGradient::apply_update(&mut params, &[1.0]).is_err());
        assert_eq!(params, vec![1.0, 2.0]);

        NaturalGradient::apply_update(&mut params, &[0.5, -0.5]).unwrap();
        assert_eq!(params, vec![1.5, 1.5]);
    }

    #[test]
    fn test_optimize_step_without_fim() {
        let mut ng = NaturalGradient::new(0.5);
        let mut params = vec![1.0, 1.0];

        let norm = ng.optimize_step(&mut params, &[2.0, 0.0], None).unwrap();
        assert_eq!(params, vec![0.0, 1.0]);
        assert!((norm - 1.0).abs() < 1e-12);

        // A gradient of the wrong length is an error and leaves parameters unchanged.
        assert!(ng.optimize_step(&mut params, &[1.0], None).is_err());
        assert_eq!(params, vec![0.0, 1.0]);
    }

    #[test]
    fn test_diagonal_natural_gradient() {
        let mut dng = DiagonalNaturalGradient::new(1.0, 2);

        let mut params = vec![0.0, 0.0];
        let gradient = vec![1.0, 2.0];

        let norm = dng.step(&mut params, &gradient).unwrap();

        // First step: accumulator = g², so each coordinate moves by
        // -lr * g / (|g| + damping), which is approximately -1.
        assert_eq!(dng.accumulator, vec![1.0, 4.0]);
        assert!((params[0] + 1.0).abs() < 1e-6);
        assert!((params[1] + 1.0).abs() < 1e-6);
        assert!((norm - 2f64.sqrt()).abs() < 1e-6);
    }

    #[test]
    fn test_diagonal_natural_gradient_dimension_mismatch() {
        let mut dng = DiagonalNaturalGradient::new(0.1, 2);
        let mut params = vec![0.0; 3];

        assert!(dng.step(&mut params, &[1.0, 1.0, 1.0]).is_err());
        // Nothing is modified on error.
        assert_eq!(params, vec![0.0; 3]);
        assert_eq!(dng.accumulator, vec![0.0, 0.0]);
    }

    #[test]
    fn test_diagonal_reset_restores_first_step() {
        let mut dng = DiagonalNaturalGradient::new(0.5, 1);

        let mut first_params = vec![0.0];
        let first_norm = dng.step(&mut first_params, &[3.0]).unwrap();

        dng.reset();
        assert_eq!(dng.accumulator, vec![0.0]);

        let mut second_params = vec![0.0];
        let second_norm = dng.step(&mut second_params, &[3.0]).unwrap();

        assert_eq!(first_norm, second_norm);
        assert_eq!(first_params, second_params);
    }

    #[test]
    fn test_optimizer_reset() {
        let mut ng = NaturalGradient::new(0.1);

        let samples = vec![vec![1.0, 2.0]];
        let _ = ng.step(&[1.0, 1.0], Some(&samples));

        ng.reset();
        assert!(ng.fim_estimate.is_none());
    }
}