// ruvector_attention/info_geometry/natural_gradient.rs

use super::fisher::{FisherConfig, FisherMetric};
use serde::{Deserialize, Serialize};
9#[derive(Debug, Clone, Serialize, Deserialize)]
11pub struct NaturalGradientConfig {
12 pub lr: f32,
14 pub fisher: FisherConfig,
16 pub use_diagonal: bool,
18}
19
20impl Default for NaturalGradientConfig {
21 fn default() -> Self {
22 Self {
23 lr: 0.1,
24 fisher: FisherConfig::default(),
25 use_diagonal: false,
26 }
27 }
28}
29
30#[derive(Debug, Clone)]
32pub struct NaturalGradient {
33 config: NaturalGradientConfig,
34 fisher: FisherMetric,
35}
36
37impl NaturalGradient {
38 pub fn new(config: NaturalGradientConfig) -> Self {
40 let fisher = FisherMetric::new(config.fisher.clone());
41 Self { config, fisher }
42 }
43
44 pub fn step_logits(&self, logits: &[f32], grad_logits: &[f32]) -> Vec<f32> {
47 let probs = Self::softmax(logits);
48
49 let nat_grad = if self.config.use_diagonal {
51 self.fisher.apply_inverse_approx(&probs, grad_logits)
52 } else {
53 self.fisher.solve_cg(&probs, grad_logits)
54 };
55
56 let mut new_logits = logits.to_vec();
58 for i in 0..new_logits.len() {
59 new_logits[i] -= self.config.lr * nat_grad[i];
60 }
61
62 new_logits
63 }
64
65 pub fn step_diagonal(&self, params: &[f32], grads: &[f32], fisher_diag: &[f32]) -> Vec<f32> {
68 let n = params.len();
69 let mut new_params = params.to_vec();
70 let eps = self.config.fisher.eps;
71
72 for i in 0..n {
73 let f_inv = 1.0 / (fisher_diag[i].abs() + eps);
74 new_params[i] -= self.config.lr * grads[i] * f_inv;
75 }
76
77 new_params
78 }
79
80 pub fn step_attention_logits(&self, logits: &[f32], grad_logits: &[f32]) -> Vec<f32> {
83 self.step_logits(logits, grad_logits)
84 }
85
86 fn softmax(logits: &[f32]) -> Vec<f32> {
88 if logits.is_empty() {
89 return vec![];
90 }
91
92 let max_logit = logits.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
93 let exp_logits: Vec<f32> = logits.iter().map(|&l| (l - max_logit).exp()).collect();
94 let sum: f32 = exp_logits.iter().sum();
95
96 if sum > 0.0 {
97 exp_logits.iter().map(|&e| e / sum).collect()
98 } else {
99 vec![1.0 / logits.len() as f32; logits.len()]
100 }
101 }
102}
103
#[cfg(test)]
mod tests {
    use super::*;

    /// A CG-preconditioned step should actually move the logits.
    #[test]
    fn test_natural_gradient_step() {
        let config = NaturalGradientConfig {
            lr: 0.1,
            ..Default::default()
        };
        let ng = NaturalGradient::new(config);

        let logits = vec![1.0, 2.0, 0.5, 0.5];
        let grads = vec![0.1, -0.1, 0.05, -0.05];

        let new_logits = ng.step_logits(&logits, &grads);

        assert_eq!(new_logits.len(), 4);
        // At least one coordinate must have changed.
        assert!(
            (new_logits[0] - logits[0]).abs() > 1e-6 || (new_logits[1] - logits[1]).abs() > 1e-6
        );
    }

    /// Step sizes should scale inversely with the Fisher diagonal.
    #[test]
    fn test_diagonal_step() {
        let ng = NaturalGradient::new(NaturalGradientConfig::default());

        let params = vec![1.0, 2.0, 3.0];
        let grads = vec![0.1, 0.1, 0.1];
        let fisher_diag = vec![1.0, 2.0, 0.5];

        let new_params = ng.step_diagonal(&params, &grads, &fisher_diag);

        assert_eq!(new_params.len(), 3);
        let step0 = (new_params[0] - params[0]).abs();
        let step1 = (new_params[1] - params[1]).abs();
        let step2 = (new_params[2] - params[2]).abs();
        // Higher curvature (diag = 2.0) => smaller step; lower (0.5) => larger.
        assert!(step1 < step0);
        assert!(step0 < step2);
    }

    /// The attention-logits wrapper preserves the input length.
    #[test]
    fn test_attention_logits_step() {
        let ng = NaturalGradient::new(NaturalGradientConfig::default());

        let logits = vec![0.0; 10];
        let grads = vec![0.1; 10];

        let new_logits = ng.step_attention_logits(&logits, &grads);

        assert_eq!(new_logits.len(), 10);
    }
}