ruvector_math/information_geometry/natural_gradient.rs

//! Natural-gradient optimizers built on Fisher information estimates.

use crate::error::{MathError, Result};
use crate::utils::EPS;
use super::FisherInformation;

/// Natural-gradient optimizer that preconditions gradients with a running
/// estimate of the Fisher information matrix (FIM), either full or diagonal.
#[derive(Debug, Clone)]
pub struct NaturalGradient {
    /// Step size applied to the natural gradient.
    learning_rate: f64,
    /// Tikhonov damping added to the FIM before inversion.
    damping: f64,
    /// If true, only the diagonal of the FIM is estimated and inverted.
    use_diagonal: bool,
    /// Exponential moving-average factor for the running FIM estimate.
    ema_factor: f64,
    /// Cached FIM estimate, refreshed from gradient samples.
    fim_estimate: Option<FimEstimate>,
}

/// Internal representation of the current FIM estimate.
#[derive(Debug, Clone)]
enum FimEstimate {
    Full(Vec<Vec<f64>>),
    Diagonal(Vec<f64>),
}

impl NaturalGradient {
    /// Creates an optimizer with the given learning rate and default settings
    /// (damping 1e-4, full FIM, EMA factor 0.9).
    pub fn new(learning_rate: f64) -> Self {
        Self {
            learning_rate: learning_rate.max(EPS),
            damping: 1e-4,
            use_diagonal: false,
            ema_factor: 0.9,
            fim_estimate: None,
        }
    }

    /// Sets the damping added to the FIM before it is inverted.
    pub fn with_damping(mut self, damping: f64) -> Self {
        self.damping = damping.max(EPS);
        self
    }

    /// Switches between the full FIM and its diagonal approximation.
    pub fn with_diagonal(mut self, use_diagonal: bool) -> Self {
        self.use_diagonal = use_diagonal;
        self
    }

    /// Sets the exponential moving-average factor, clamped to [0, 1].
    pub fn with_ema(mut self, ema: f64) -> Self {
        self.ema_factor = ema.clamp(0.0, 1.0);
        self
    }

    /// Computes a parameter update from `gradient`. If `gradient_samples` is
    /// provided, the FIM estimate is refreshed first; otherwise any cached
    /// estimate is reused, and with no estimate the raw gradient is used.
    pub fn step(
        &mut self,
        gradient: &[f64],
        gradient_samples: Option<&[Vec<f64>]>,
    ) -> Result<Vec<f64>> {
        if let Some(samples) = gradient_samples {
            self.update_fim(samples)?;
        }

        let nat_grad = match &self.fim_estimate {
            Some(FimEstimate::Full(fim)) => {
                let fisher = FisherInformation::new().with_damping(self.damping);
                fisher.natural_gradient(fim, gradient)?
            }
            Some(FimEstimate::Diagonal(diag)) => {
                // Element-wise preconditioning by the damped diagonal FIM.
                gradient
                    .iter()
                    .zip(diag.iter())
                    .map(|(&g, &d)| g / (d + self.damping))
                    .collect()
            }
            None => {
                // No FIM estimate yet: fall back to the plain gradient.
                gradient.to_vec()
            }
        };

        Ok(nat_grad.iter().map(|&g| -self.learning_rate * g).collect())
    }

    /// Refreshes the cached FIM estimate from per-sample gradients, blending
    /// it with the previous estimate via the EMA factor when shapes match.
    fn update_fim(&mut self, gradient_samples: &[Vec<f64>]) -> Result<()> {
        let fisher = FisherInformation::new().with_damping(0.0);

        if self.use_diagonal {
            let new_diag = fisher.diagonal_fim(gradient_samples)?;

            self.fim_estimate = Some(FimEstimate::Diagonal(match &self.fim_estimate {
                Some(FimEstimate::Diagonal(old)) => {
                    old.iter()
                        .zip(new_diag.iter())
                        .map(|(&o, &n)| self.ema_factor * o + (1.0 - self.ema_factor) * n)
                        .collect()
                }
                _ => new_diag,
            }));
        } else {
            let new_fim = fisher.empirical_fim(gradient_samples)?;
            let dim = new_fim.len();

            self.fim_estimate = Some(FimEstimate::Full(match &self.fim_estimate {
                Some(FimEstimate::Full(old)) if old.len() == dim => {
                    (0..dim)
                        .map(|i| {
                            (0..dim)
                                .map(|j| {
                                    self.ema_factor * old[i][j]
                                        + (1.0 - self.ema_factor) * new_fim[i][j]
                                })
                                .collect()
                        })
                        .collect()
                }
                _ => new_fim,
            }));
        }

        Ok(())
    }

    /// Adds `update` to `parameters` in place.
    pub fn apply_update(parameters: &mut [f64], update: &[f64]) -> Result<()> {
        if parameters.len() != update.len() {
            return Err(MathError::dimension_mismatch(parameters.len(), update.len()));
        }

        for (p, &u) in parameters.iter_mut().zip(update.iter()) {
            *p += u;
        }

        Ok(())
    }

    /// Convenience wrapper: computes a step, applies it to `parameters`, and
    /// returns the Euclidean norm of the applied update.
    pub fn optimize_step(
        &mut self,
        parameters: &mut [f64],
        gradient: &[f64],
        gradient_samples: Option<&[Vec<f64>]>,
    ) -> Result<f64> {
        let update = self.step(gradient, gradient_samples)?;

        let update_norm: f64 = update.iter().map(|&u| u * u).sum::<f64>().sqrt();

        Self::apply_update(parameters, &update)?;

        Ok(update_norm)
    }

    /// Clears the cached FIM estimate.
    pub fn reset(&mut self) {
        self.fim_estimate = None;
    }
}

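// Usage sketch (illustrative, not crate API): drives the optimizer on the toy
// quadratic f(x) = 0.5 * ||x||^2, whose gradient is x itself. The per-sample
// gradients used to seed the diagonal FIM estimate are made-up values; a real
// caller would supply gradients computed from individual data samples.
#[allow(dead_code)]
fn _example_natural_gradient_loop() -> Result<Vec<f64>> {
    let mut optimizer = NaturalGradient::new(0.1)
        .with_diagonal(true)
        .with_damping(1e-4)
        .with_ema(0.9);

    // Illustrative per-sample gradients for the FIM estimate.
    let samples = vec![
        vec![1.0, 0.5, 0.2],
        vec![0.8, 0.6, 0.1],
        vec![1.2, 0.4, 0.3],
    ];

    let mut params = vec![1.0, -2.0, 0.5];
    for i in 0..50 {
        // For f(x) = 0.5 * x.x the gradient equals the parameters themselves.
        let gradient = params.clone();
        // Refresh the FIM estimate on the first iteration, then reuse the cache.
        let fim_samples = if i == 0 { Some(samples.as_slice()) } else { None };
        optimizer.optimize_step(&mut params, &gradient, fim_samples)?;
    }

    Ok(params)
}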

/// Diagonal natural-gradient optimizer that preconditions each coordinate by
/// the square root of its accumulated squared gradient (AdaGrad-style).
#[derive(Debug, Clone)]
pub struct DiagonalNaturalGradient {
    /// Step size applied to the preconditioned gradient.
    learning_rate: f64,
    /// Small constant added to the denominator for numerical stability.
    damping: f64,
    /// Running sum of squared gradients, one entry per parameter.
    accumulator: Vec<f64>,
}

impl DiagonalNaturalGradient {
    /// Creates an optimizer for `dim` parameters with the given learning rate.
    pub fn new(learning_rate: f64, dim: usize) -> Self {
        Self {
            learning_rate: learning_rate.max(EPS),
            damping: 1e-8,
            accumulator: vec![0.0; dim],
        }
    }

    /// Sets the stabilizing constant added to the denominator.
    pub fn with_damping(mut self, damping: f64) -> Self {
        self.damping = damping.max(EPS);
        self
    }

    /// Performs one in-place update of `parameters` and returns the Euclidean
    /// norm of the applied update.
    pub fn step(&mut self, parameters: &mut [f64], gradient: &[f64]) -> Result<f64> {
        if parameters.len() != gradient.len() || parameters.len() != self.accumulator.len() {
            return Err(MathError::dimension_mismatch(
                parameters.len(),
                gradient.len(),
            ));
        }

        let mut update_norm_sq = 0.0;

        for (i, (p, &g)) in parameters.iter_mut().zip(gradient.iter()).enumerate() {
            // Accumulate the squared gradient for this coordinate.
            self.accumulator[i] += g * g;

            let update = -self.learning_rate * g / (self.accumulator[i].sqrt() + self.damping);
            *p += update;
            update_norm_sq += update * update;
        }

        Ok(update_norm_sq.sqrt())
    }

    /// Clears the squared-gradient accumulator.
    pub fn reset(&mut self) {
        self.accumulator.fill(0.0);
    }
}

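// Usage sketch (illustrative, not crate API): the accumulator-based rule above
// behaves like AdaGrad, so a minimal loop looks like any first-order optimizer.
// The quadratic f(x) = 0.5 * ||x||^2 (gradient = x) stands in for a real loss.
#[allow(dead_code)]
fn _example_diagonal_descent() -> Result<Vec<f64>> {
    let mut optimizer = DiagonalNaturalGradient::new(0.5, 2).with_damping(1e-8);

    let mut params = vec![3.0, -1.5];
    for _ in 0..100 {
        // Gradient of 0.5 * ||x||^2 is x itself.
        let gradient = params.clone();
        optimizer.step(&mut params, &gradient)?;
    }

    Ok(params)
}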

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_natural_gradient_step() {
        let mut ng = NaturalGradient::new(0.1).with_diagonal(true);

        let gradient = vec![1.0, 2.0, 3.0];

        let update = ng.step(&gradient, None).unwrap();

        assert_eq!(update.len(), 3);
        assert!((update[0] + 0.1).abs() < 1e-10);
    }
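
    // Sketch test (added for illustration): with no FIM estimate, `optimize_step`
    // reduces to plain gradient descent, so the applied update is -lr * gradient.
    #[test]
    fn test_optimize_step_applies_update() {
        let mut ng = NaturalGradient::new(0.1);

        let mut params = vec![1.0, 1.0];
        let gradient = vec![1.0, 1.0];

        let norm = ng.optimize_step(&mut params, &gradient, None).unwrap();

        assert!((params[0] - 0.9).abs() < 1e-10);
        assert!((params[1] - 0.9).abs() < 1e-10);
        assert!((norm - 0.1 * 2.0_f64.sqrt()).abs() < 1e-10);
    }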

    #[test]
    fn test_natural_gradient_with_fim() {
        let mut ng = NaturalGradient::new(0.1).with_diagonal(true).with_damping(0.0);

        let gradient = vec![2.0, 4.0];

        let samples = vec![
            vec![1.0, 0.0],
            vec![0.0, 1.0],
            vec![1.0, 1.0],
        ];

        let update = ng.step(&gradient, Some(&samples)).unwrap();

        assert_eq!(update.len(), 2);
    }
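
    // Sketch test (added for illustration): providing gradient samples should
    // cache a diagonal FIM estimate, which later sample-free steps reuse.
    #[test]
    fn test_fim_estimate_cached_between_steps() {
        let mut ng = NaturalGradient::new(0.1).with_diagonal(true);

        let samples = vec![
            vec![1.0, 0.0],
            vec![0.0, 1.0],
            vec![1.0, 1.0],
        ];

        let _ = ng.step(&[1.0, 1.0], Some(&samples)).unwrap();
        assert!(matches!(&ng.fim_estimate, Some(FimEstimate::Diagonal(_))));

        // No new samples: the cached estimate is still used.
        let update = ng.step(&[1.0, 1.0], None).unwrap();
        assert_eq!(update.len(), 2);
    }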

    #[test]
    fn test_diagonal_natural_gradient() {
        let mut dng = DiagonalNaturalGradient::new(1.0, 2);

        let mut params = vec![0.0, 0.0];
        let gradient = vec![1.0, 2.0];

        let norm = dng.step(&mut params, &gradient).unwrap();

        assert!(norm > 0.0);
        assert!(params[0] < 0.0);
    }
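
    // Sketch test (added for illustration): reset() clears the squared-gradient
    // accumulator, so repeating an identical step reproduces the first update.
    #[test]
    fn test_diagonal_reset_restores_initial_state() {
        let mut dng = DiagonalNaturalGradient::new(1.0, 2);
        let gradient = vec![1.0, 2.0];

        let mut first = vec![0.0, 0.0];
        let norm_first = dng.step(&mut first, &gradient).unwrap();

        dng.reset();

        let mut second = vec![0.0, 0.0];
        let norm_second = dng.step(&mut second, &gradient).unwrap();

        assert!((norm_first - norm_second).abs() < 1e-12);
        assert!((first[0] - second[0]).abs() < 1e-12);
    }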

    #[test]
    fn test_optimizer_reset() {
        let mut ng = NaturalGradient::new(0.1);

        let samples = vec![vec![1.0, 2.0]];
        let _ = ng.step(&[1.0, 1.0], Some(&samples));

        ng.reset();
        assert!(ng.fim_estimate.is_none());
    }
}