ruvector_attention/info_geometry/natural_gradient.rs

use super::fisher::{FisherMetric, FisherConfig};
use serde::{Deserialize, Serialize};

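/// Configuration for a natural-gradient update: step size, Fisher metric
/// settings, and whether to use the cheaper diagonal inverse approximation.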
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NaturalGradientConfig {
    /// Learning rate applied along the natural-gradient direction.
    pub lr: f32,
    /// Configuration forwarded to the underlying `FisherMetric`.
    pub fisher: FisherConfig,
    /// If true, precondition with the diagonal approximation of the inverse
    /// Fisher instead of the full CG-based solve.
    pub use_diagonal: bool,
}

impl Default for NaturalGradientConfig {
    fn default() -> Self {
        Self {
            lr: 0.1,
            fisher: FisherConfig::default(),
            use_diagonal: false,
        }
    }
}

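/// Natural-gradient optimizer backed by a [`FisherMetric`].
///
/// Instead of a plain gradient step, the gradient is preconditioned by the
/// inverse Fisher information of the softmax distribution,
/// `theta <- theta - lr * F(theta)^-1 * grad`, so updates follow the geometry
/// of the distribution rather than raw parameter space.
///
/// Illustrative sketch (values are only examples):
///
/// ```ignore
/// let ng = NaturalGradient::new(NaturalGradientConfig::default());
/// let new_logits = ng.step_logits(&[1.0, 2.0, 0.5], &[0.1, -0.1, 0.0]);
/// assert_eq!(new_logits.len(), 3);
/// ```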
#[derive(Debug, Clone)]
pub struct NaturalGradient {
    config: NaturalGradientConfig,
    fisher: FisherMetric,
}

impl NaturalGradient {
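    /// Builds the optimizer, constructing its Fisher metric from the embedded
    /// `FisherConfig`.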
    pub fn new(config: NaturalGradientConfig) -> Self {
        let fisher = FisherMetric::new(config.fisher.clone());
        Self { config, fisher }
    }

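    /// Takes one natural-gradient step on a vector of logits.
    ///
    /// The Fisher metric is evaluated at `softmax(logits)`; the gradient is
    /// then preconditioned either with the diagonal inverse approximation or
    /// with the full CG-based solve, depending on `use_diagonal`.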
    pub fn step_logits(&self, logits: &[f32], grad_logits: &[f32]) -> Vec<f32> {
        let probs = Self::softmax(logits);

        // Precondition the gradient with the (approximate) inverse Fisher.
        let nat_grad = if self.config.use_diagonal {
            self.fisher.apply_inverse_approx(&probs, grad_logits)
        } else {
            self.fisher.solve_cg(&probs, grad_logits)
        };

        // Descend along the natural-gradient direction.
        let mut new_logits = logits.to_vec();
        for i in 0..new_logits.len() {
            new_logits[i] -= self.config.lr * nat_grad[i];
        }

        new_logits
    }

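    /// Takes a natural-gradient step using a caller-supplied diagonal of the
    /// Fisher matrix: each gradient component is scaled by
    /// `1 / (|fisher_diag[i]| + eps)` before the learning rate is applied.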
    pub fn step_diagonal(
        &self,
        params: &[f32],
        grads: &[f32],
        fisher_diag: &[f32],
    ) -> Vec<f32> {
        let n = params.len();
        let mut new_params = params.to_vec();
        let eps = self.config.fisher.eps;

        for i in 0..n {
            let f_inv = 1.0 / (fisher_diag[i].abs() + eps);
            new_params[i] -= self.config.lr * grads[i] * f_inv;
        }

        new_params
    }

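    /// Convenience wrapper for attention-score logits; currently delegates to
    /// `step_logits`.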
    pub fn step_attention_logits(
        &self,
        logits: &[f32],
        grad_logits: &[f32],
    ) -> Vec<f32> {
        self.step_logits(logits, grad_logits)
    }

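    /// Numerically stable (max-shifted) softmax. Falls back to a uniform
    /// distribution if all exponentials underflow to zero.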
    fn softmax(logits: &[f32]) -> Vec<f32> {
        if logits.is_empty() {
            return vec![];
        }

        let max_logit = logits.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
        let exp_logits: Vec<f32> = logits.iter().map(|&l| (l - max_logit).exp()).collect();
        let sum: f32 = exp_logits.iter().sum();

        if sum > 0.0 {
            exp_logits.iter().map(|&e| e / sum).collect()
        } else {
            vec![1.0 / logits.len() as f32; logits.len()]
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_natural_gradient_step() {
        let config = NaturalGradientConfig {
            lr: 0.1,
            ..Default::default()
        };
        let ng = NaturalGradient::new(config);

        let logits = vec![1.0, 2.0, 0.5, 0.5];
        let grads = vec![0.1, -0.1, 0.05, -0.05];

        let new_logits = ng.step_logits(&logits, &grads);

        assert_eq!(new_logits.len(), 4);
        // The update should move at least one of the first two logits.
        assert!(
            (new_logits[0] - logits[0]).abs() > 1e-6
                || (new_logits[1] - logits[1]).abs() > 1e-6
        );
    }

    #[test]
    fn test_diagonal_step() {
        let ng = NaturalGradient::new(NaturalGradientConfig::default());

        let params = vec![1.0, 2.0, 3.0];
        let grads = vec![0.1, 0.1, 0.1];
        let fisher_diag = vec![1.0, 2.0, 0.5];

        let new_params = ng.step_diagonal(&params, &grads, &fisher_diag);

        assert_eq!(new_params.len(), 3);
        // A larger Fisher diagonal entry should yield a smaller step.
        let step0 = (new_params[0] - params[0]).abs();
        let step1 = (new_params[1] - params[1]).abs();
        let step2 = (new_params[2] - params[2]).abs();
        assert!(step1 < step0);
        assert!(step0 < step2);
    }

    #[test]
    fn test_attention_logits_step() {
        let ng = NaturalGradient::new(NaturalGradientConfig::default());

        let logits = vec![0.0; 10];
        let grads = vec![0.1; 10];

        let new_logits = ng.step_attention_logits(&logits, &grads);

        assert_eq!(new_logits.len(), 10);
    }
}