// ruvector_math/information_geometry/natural_gradient.rs

//! Natural-gradient optimizers built on Fisher information estimates.

use super::FisherInformation;
use crate::error::{MathError, Result};
use crate::utils::EPS;

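/// Natural-gradient optimizer.
///
/// Preconditions a raw gradient with an estimate of the Fisher information
/// matrix (full or diagonal), smoothed across steps by an exponential moving
/// average, then scales the result by the learning rate.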
#[derive(Debug, Clone)]
pub struct NaturalGradient {
    /// Step size applied to the natural gradient.
    learning_rate: f64,
    /// Damping added to the FIM before it is inverted.
    damping: f64,
    /// If true, estimate only the diagonal of the FIM.
    use_diagonal: bool,
    /// Exponential-moving-average factor for the running FIM estimate.
    ema_factor: f64,
    /// Current FIM estimate; `None` until gradient samples are seen.
    fim_estimate: Option<FimEstimate>,
}

/// A Fisher information estimate, stored as a full matrix or its diagonal.
#[derive(Debug, Clone)]
enum FimEstimate {
    Full(Vec<Vec<f64>>),
    Diagonal(Vec<f64>),
}

impl NaturalGradient {
    /// Create an optimizer with the given learning rate, clamped to at least `EPS`.
    pub fn new(learning_rate: f64) -> Self {
        Self {
            learning_rate: learning_rate.max(EPS),
            damping: 1e-4,
            use_diagonal: false,
            ema_factor: 0.9,
            fim_estimate: None,
        }
    }

    /// Set the FIM damping (builder style), clamped to at least `EPS`.
    pub fn with_damping(mut self, damping: f64) -> Self {
        self.damping = damping.max(EPS);
        self
    }

    /// Choose a diagonal (`true`) or full (`false`) FIM estimate (builder style).
    pub fn with_diagonal(mut self, use_diagonal: bool) -> Self {
        self.use_diagonal = use_diagonal;
        self
    }

    /// Set the EMA smoothing factor (builder style), clamped to `[0, 1]`.
    pub fn with_ema(mut self, ema: f64) -> Self {
        self.ema_factor = ema.clamp(0.0, 1.0);
        self
    }

    /// Compute a parameter update for `gradient`.
    ///
    /// If `gradient_samples` is provided, the FIM estimate is refreshed from
    /// those per-sample gradients first. Until a FIM estimate exists, the
    /// method falls back to plain (Euclidean) gradient descent.
    pub fn step(
        &mut self,
        gradient: &[f64],
        gradient_samples: Option<&[Vec<f64>]>,
    ) -> Result<Vec<f64>> {
        if let Some(samples) = gradient_samples {
            self.update_fim(samples)?;
        }

        let nat_grad = match &self.fim_estimate {
            // Full FIM: delegate the damped solve to `FisherInformation`.
            Some(FimEstimate::Full(fim)) => {
                let fisher = FisherInformation::new().with_damping(self.damping);
                fisher.natural_gradient(fim, gradient)?
            }
            // Diagonal FIM: elementwise division by the damped diagonal.
            Some(FimEstimate::Diagonal(diag)) => gradient
                .iter()
                .zip(diag.iter())
                .map(|(&g, &d)| g / (d + self.damping))
                .collect(),
            // No estimate yet: use the raw gradient.
            None => gradient.to_vec(),
        };

        Ok(nat_grad.iter().map(|&g| -self.learning_rate * g).collect())
    }

    /// Refresh the FIM estimate from per-sample gradients, blending it with
    /// the previous estimate via the EMA factor when dimensions match.
    fn update_fim(&mut self, gradient_samples: &[Vec<f64>]) -> Result<()> {
        // Damping is applied at solve time in `step`, not while estimating.
        let fisher = FisherInformation::new().with_damping(0.0);

        if self.use_diagonal {
            let new_diag = fisher.diagonal_fim(gradient_samples)?;

            self.fim_estimate = Some(FimEstimate::Diagonal(match &self.fim_estimate {
                // Blend with the previous diagonal estimate of the same length.
                Some(FimEstimate::Diagonal(old)) if old.len() == new_diag.len() => old
                    .iter()
                    .zip(new_diag.iter())
                    .map(|(&o, &n)| self.ema_factor * o + (1.0 - self.ema_factor) * n)
                    .collect(),
                // First estimate, or shape changed: take the new one as-is.
                _ => new_diag,
            }));
        } else {
            let new_fim = fisher.empirical_fim(gradient_samples)?;
            let dim = new_fim.len();

            self.fim_estimate = Some(FimEstimate::Full(match &self.fim_estimate {
                // Blend entrywise with the previous full estimate.
                Some(FimEstimate::Full(old)) if old.len() == dim => (0..dim)
                    .map(|i| {
                        (0..dim)
                            .map(|j| {
                                self.ema_factor * old[i][j]
                                    + (1.0 - self.ema_factor) * new_fim[i][j]
                            })
                            .collect()
                    })
                    .collect(),
                // First estimate, or dimension changed: take the new one as-is.
                _ => new_fim,
            }));
        }

        Ok(())
    }

    /// Add `update` elementwise to `parameters` in place.
    pub fn apply_update(parameters: &mut [f64], update: &[f64]) -> Result<()> {
        if parameters.len() != update.len() {
            return Err(MathError::dimension_mismatch(
                parameters.len(),
                update.len(),
            ));
        }

        for (p, &u) in parameters.iter_mut().zip(update.iter()) {
            *p += u;
        }

        Ok(())
    }

    /// Compute a step and apply it to `parameters`, returning the L2 norm
    /// of the applied update.
    pub fn optimize_step(
        &mut self,
        parameters: &mut [f64],
        gradient: &[f64],
        gradient_samples: Option<&[Vec<f64>]>,
    ) -> Result<f64> {
        let update = self.step(gradient, gradient_samples)?;

        let update_norm: f64 = update.iter().map(|&u| u * u).sum::<f64>().sqrt();

        Self::apply_update(parameters, &update)?;

        Ok(update_norm)
    }

    /// Discard the accumulated FIM estimate.
    pub fn reset(&mut self) {
        self.fim_estimate = None;
    }
}

/// Adagrad-style diagonal natural-gradient optimizer: keeps a running sum of
/// squared gradients per coordinate and scales each update by the inverse
/// square root of that accumulator.
#[derive(Debug, Clone)]
pub struct DiagonalNaturalGradient {
    /// Step size applied to each preconditioned coordinate.
    learning_rate: f64,
    /// Stabilizing constant added to the denominator.
    damping: f64,
    /// Per-coordinate running sum of squared gradients.
    accumulator: Vec<f64>,
}

impl DiagonalNaturalGradient {
    /// Create an optimizer for `dim` parameters; the learning rate is
    /// clamped to at least `EPS`.
    pub fn new(learning_rate: f64, dim: usize) -> Self {
        Self {
            learning_rate: learning_rate.max(EPS),
            damping: 1e-8,
            accumulator: vec![0.0; dim],
        }
    }

    /// Set the stabilizing damping term (builder style), clamped to at least `EPS`.
    pub fn with_damping(mut self, damping: f64) -> Self {
        self.damping = damping.max(EPS);
        self
    }

    /// Update `parameters` in place from `gradient`, returning the L2 norm
    /// of the applied update.
    pub fn step(&mut self, parameters: &mut [f64], gradient: &[f64]) -> Result<f64> {
        if parameters.len() != gradient.len() {
            return Err(MathError::dimension_mismatch(
                parameters.len(),
                gradient.len(),
            ));
        }
        if parameters.len() != self.accumulator.len() {
            return Err(MathError::dimension_mismatch(
                parameters.len(),
                self.accumulator.len(),
            ));
        }

        let mut update_norm_sq = 0.0;

        for (i, (p, &g)) in parameters.iter_mut().zip(gradient.iter()).enumerate() {
            // Accumulate the squared gradient for this coordinate.
            self.accumulator[i] += g * g;

            // Scale by the inverse root of the accumulated squared gradients.
            let update = -self.learning_rate * g / (self.accumulator[i].sqrt() + self.damping);
            *p += update;
            update_norm_sq += update * update;
        }

        Ok(update_norm_sq.sqrt())
    }

    /// Zero the squared-gradient accumulator.
    pub fn reset(&mut self) {
        self.accumulator.fill(0.0);
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_natural_gradient_step() {
        let mut ng = NaturalGradient::new(0.1).with_diagonal(true);

        let gradient = vec![1.0, 2.0, 3.0];

        // No samples yet, so this is plain gradient descent: -0.1 * g.
        let update = ng.step(&gradient, None).unwrap();

        assert_eq!(update.len(), 3);
        assert!((update[0] + 0.1).abs() < 1e-10);
    }
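
    /// Illustrative sketch, not part of the original file: after a step with
    /// gradient samples, a FIM estimate is retained and reused by later steps
    /// even when no new samples are supplied.
    #[test]
    fn test_fim_estimate_is_retained_sketch() {
        let mut ng = NaturalGradient::new(0.1).with_diagonal(true);

        let samples = vec![vec![1.0, 2.0], vec![2.0, 1.0]];
        ng.step(&[1.0, 1.0], Some(&samples)).unwrap();
        assert!(ng.fim_estimate.is_some());

        // A second step without samples still uses the stored estimate.
        let update = ng.step(&[1.0, 1.0], None).unwrap();
        assert_eq!(update.len(), 2);
    }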

    #[test]
    fn test_natural_gradient_with_fim() {
        let mut ng = NaturalGradient::new(0.1)
            .with_diagonal(true)
            .with_damping(0.0);

        let gradient = vec![2.0, 4.0];

        let samples = vec![vec![1.0, 0.0], vec![0.0, 1.0], vec![1.0, 1.0]];

        let update = ng.step(&gradient, Some(&samples)).unwrap();

        assert_eq!(update.len(), 2);
    }

    #[test]
    fn test_diagonal_natural_gradient() {
        let mut dng = DiagonalNaturalGradient::new(1.0, 2);

        let mut params = vec![0.0, 0.0];
        let gradient = vec![1.0, 2.0];

        let norm = dng.step(&mut params, &gradient).unwrap();

        assert!(norm > 0.0);
        // A positive gradient must move the parameter in the negative direction.
        assert!(params[0] < 0.0);
    }
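
    /// Illustrative sketch, not part of the original file: with a constant
    /// gradient the squared-gradient accumulator grows each step, so the
    /// Adagrad-style preconditioning shrinks successive updates.
    #[test]
    fn test_diagonal_steps_shrink_sketch() {
        let mut dng = DiagonalNaturalGradient::new(1.0, 2);
        let mut params = vec![0.0, 0.0];
        let gradient = vec![1.0, 1.0];

        let first = dng.step(&mut params, &gradient).unwrap();
        let second = dng.step(&mut params, &gradient).unwrap();

        // The first update is ~1 per coordinate, the second ~1/sqrt(2).
        assert!(second < first);
    }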

    #[test]
    fn test_optimizer_reset() {
        let mut ng = NaturalGradient::new(0.1);

        let samples = vec![vec![1.0, 2.0]];
        let _ = ng.step(&[1.0, 1.0], Some(&samples));

        ng.reset();
        assert!(ng.fim_estimate.is_none());
    }
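
    /// Illustrative sketch, not part of the original file: minimizing the
    /// quadratic f(x) = 0.5 * ||x||^2, whose gradient is x itself. With no
    /// gradient samples the optimizer reduces to plain gradient descent, so
    /// each step contracts every coordinate by (1 - learning_rate).
    #[test]
    fn test_optimize_step_quadratic_sketch() {
        let mut ng = NaturalGradient::new(0.1);
        let mut params = vec![1.0, -2.0];

        for _ in 0..10 {
            let gradient = params.clone();
            let norm = ng.optimize_step(&mut params, &gradient, None).unwrap();
            assert!(norm.is_finite());
        }

        // Ten contractions by 0.9 leave about 0.35 of each starting value.
        assert!(params[0].abs() < 1.0);
        assert!(params[1].abs() < 2.0);
    }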
}