trustformers_optim/second_order/lbfgs.rs

use anyhow::Result;
use std::collections::{HashMap, VecDeque};
use trustformers_core::tensor::Tensor;

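/// Limited-memory BFGS (L-BFGS) optimizer.
///
/// Keeps at most `history_size` curvature pairs `(s_k, y_k)` with
/// `s_k = x_k - x_{k-1}` and `y_k = g_k - g_{k-1}`, and uses the standard
/// two-loop recursion to approximate the inverse-Hessian-vector product when
/// computing the search direction. The configuration fields `line_search_fn`,
/// `max_iter`, `tolerance_grad`, and `tolerance_change` are stored but not yet
/// consulted by `step`, which always applies a fixed `learning_rate` update.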
#[derive(Debug)]
pub struct LBFGS {
    pub learning_rate: f32,
    pub history_size: usize,
    pub line_search_fn: Option<LineSearchMethod>,
    pub max_iter: usize,
    pub tolerance_grad: f32,
    pub tolerance_change: f32,

    // Internal state: step counter, curvature-pair history, and the previous
    // iterate/gradient used to form the next (s, y) pair.
    pub step: usize,
    pub s_history: VecDeque<HashMap<String, Vec<f32>>>,
    pub y_history: VecDeque<HashMap<String, Vec<f32>>>,
    pub rho_history: VecDeque<f32>,
    pub prev_params: HashMap<String, Vec<f32>>,
    pub prev_grads: HashMap<String, Vec<f32>>,
}

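/// Line search strategies selectable via the `line_search_fn` field.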
#[derive(Debug, Clone)]
pub enum LineSearchMethod {
    None,
    StrongWolfe,
    Backtracking,
}

impl Default for LBFGS {
    fn default() -> Self {
        Self {
            learning_rate: 1.0,
            history_size: 10,
            line_search_fn: Some(LineSearchMethod::StrongWolfe),
            max_iter: 20,
            tolerance_grad: 1e-7,
            tolerance_change: 1e-9,
            step: 0,
            s_history: VecDeque::new(),
            y_history: VecDeque::new(),
            rho_history: VecDeque::new(),
            prev_params: HashMap::new(),
            prev_grads: HashMap::new(),
        }
    }
}

impl LBFGS {
    /// Creates an L-BFGS optimizer with the given learning rate and default settings.
    pub fn new(learning_rate: f32) -> Self {
        Self {
            learning_rate,
            ..Default::default()
        }
    }

    /// Creates an L-BFGS optimizer with an explicit history size, line search
    /// method, and iteration limit.
    pub fn with_config(
        learning_rate: f32,
        history_size: usize,
        line_search_fn: Option<LineSearchMethod>,
        max_iter: usize,
    ) -> Self {
        Self {
            learning_rate,
            history_size,
            line_search_fn,
            max_iter,
            ..Default::default()
        }
    }

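    /// Performs one optimization step.
    ///
    /// The first call applies a plain gradient-descent update to obtain an
    /// initial displacement; later calls form the curvature pair
    /// `(s_k, y_k) = (x_k - x_{k-1}, g_k - g_{k-1})`, append it to the bounded
    /// history, and move along the two-loop-recursion search direction.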
    pub fn step(
        &mut self,
        parameters: &mut HashMap<String, Tensor>,
        gradients: &HashMap<String, Tensor>,
    ) -> Result<()> {
        if self.step == 0 {
            // First step: remember the starting point and gradient, then take a
            // plain gradient-descent step (no curvature information exists yet).
            for (name, param) in parameters.iter() {
                self.prev_params.insert(name.clone(), param.data()?);
            }
            for (name, grad) in gradients.iter() {
                self.prev_grads.insert(name.clone(), grad.data()?);
            }

            for (name, param) in parameters.iter_mut() {
                let grad = gradients
                    .get(name)
                    .ok_or_else(|| anyhow::anyhow!("Missing gradient for parameter: {}", name))?;
                let mut param_data = param.data()?;
                let grad_data = grad.data()?;

                for i in 0..param_data.len() {
                    param_data[i] -= self.learning_rate * grad_data[i];
                }

                *param = Tensor::new(param_data)?;
            }

            self.step += 1;
            return Ok(());
        }

        // Curvature pair: s_k = x_k - x_{k-1}, y_k = g_k - g_{k-1}.
        let mut s_k = HashMap::new();
        let mut y_k = HashMap::new();

        for (name, param) in parameters.iter() {
            let param_data = param.data()?;
            let prev_param = self.prev_params.get(name).unwrap();

            let s: Vec<f32> =
                param_data.iter().zip(prev_param.iter()).map(|(p, prev_p)| p - prev_p).collect();
            s_k.insert(name.clone(), s);
        }

        for (name, grad) in gradients.iter() {
            let grad_data = grad.data()?;
            let prev_grad = self.prev_grads.get(name).unwrap();

            let y: Vec<f32> =
                grad_data.iter().zip(prev_grad.iter()).map(|(g, prev_g)| g - prev_g).collect();
            y_k.insert(name.clone(), y);
        }

        // rho_k = 1 / (s_k . y_k), accumulated over all parameters.
        let mut rho = 0.0;
        for name in parameters.keys() {
            let s = s_k.get(name).unwrap();
            let y = y_k.get(name).unwrap();

            rho += s.iter().zip(y.iter()).map(|(s_i, y_i)| s_i * y_i).sum::<f32>();
        }

        if rho.abs() < 1e-10 {
            // Degenerate curvature: skip this pair (and this update) rather than
            // divide by ~zero.
            self.step += 1;
            return Ok(());
        }

        rho = 1.0 / rho;

        self.s_history.push_back(s_k);
        self.y_history.push_back(y_k);
        self.rho_history.push_back(rho);

        // Keep only the most recent `history_size` pairs.
        if self.s_history.len() > self.history_size {
            self.s_history.pop_front();
            self.y_history.pop_front();
            self.rho_history.pop_front();
        }

        let search_direction = self.compute_search_direction(gradients)?;

        // Remember the current iterate and gradient *before* updating, so the next
        // call can form s = x_{k+1} - x_k and y = g_{k+1} - g_k consistently
        // (saving the post-update parameters here would make the next s vanish).
        for (name, param) in parameters.iter() {
            self.prev_params.insert(name.clone(), param.data()?);
        }
        for (name, grad) in gradients.iter() {
            self.prev_grads.insert(name.clone(), grad.data()?);
        }

        // Update: x_{k+1} = x_k - learning_rate * d_k.
        for (name, param) in parameters.iter_mut() {
            let direction = search_direction.get(name).unwrap();
            let mut param_data = param.data()?;

            for i in 0..param_data.len() {
                param_data[i] -= self.learning_rate * direction[i];
            }

            *param = Tensor::new(param_data)?;
        }

        self.step += 1;
        Ok(())
    }

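    /// Computes the search direction with the standard L-BFGS two-loop recursion:
    /// d = H_k^{-1} g_k, where H_k^{-1} is the implicit inverse-Hessian approximation
    /// built from the stored (s, y) pairs; the caller then steps along -d.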
    fn compute_search_direction(
        &self,
        gradients: &HashMap<String, Tensor>,
    ) -> Result<HashMap<String, Vec<f32>>> {
        let mut q: HashMap<String, Vec<f32>> = HashMap::new();

        // Start from the current gradient: q = g_k.
        for (name, grad) in gradients.iter() {
            q.insert(name.clone(), grad.data()?);
        }

        let history_len = self.s_history.len();
        let mut alpha = vec![0.0; history_len];

        // First loop (newest to oldest): q <- q - alpha_i * y_i, alpha_i = rho_i * (s_i . q).
        for i in (0..history_len).rev() {
            let rho_i = self.rho_history[i];
            let s_i = &self.s_history[i];

            let mut alpha_i = 0.0;
            for name in gradients.keys() {
                let s_i_param = s_i.get(name).unwrap();
                let q_param = q.get(name).unwrap();

                alpha_i +=
                    s_i_param.iter().zip(q_param.iter()).map(|(s, q_val)| s * q_val).sum::<f32>();
            }
            alpha_i *= rho_i;
            alpha[i] = alpha_i;

            for name in gradients.keys() {
                let y_i_param = self.y_history[i].get(name).unwrap();
                let q_param = q.get_mut(name).unwrap();

                for j in 0..q_param.len() {
                    q_param[j] -= alpha_i * y_i_param[j];
                }
            }
        }

        // Scale by the initial Hessian approximation H_0 = gamma * I, with
        // gamma = (s . y) / (y . y) taken from the most recent pair.
        if !self.s_history.is_empty() {
            let recent_idx = self.s_history.len() - 1;
            let recent_s = &self.s_history[recent_idx];
            let recent_y = &self.y_history[recent_idx];

            let mut s_dot_y = 0.0;
            let mut y_dot_y = 0.0;

            for name in gradients.keys() {
                let s_param = recent_s.get(name).unwrap();
                let y_param = recent_y.get(name).unwrap();

                s_dot_y += s_param.iter().zip(y_param.iter()).map(|(s, y)| s * y).sum::<f32>();
                y_dot_y += y_param.iter().map(|y| y * y).sum::<f32>();
            }

            if y_dot_y > 1e-10 {
                let gamma = s_dot_y / y_dot_y;
                for (_, q_param) in q.iter_mut() {
                    for val in q_param.iter_mut() {
                        *val *= gamma;
                    }
                }
            }
        }

        // Second loop (oldest to newest): q <- q + (alpha_i - beta_i) * s_i,
        // where beta_i = rho_i * (y_i . q).
        for i in 0..history_len {
            let rho_i = self.rho_history[i];
            let y_i = &self.y_history[i];

            let mut beta = 0.0;
            for name in gradients.keys() {
                let y_i_param = y_i.get(name).unwrap();
                let q_param = q.get(name).unwrap();

                beta +=
                    y_i_param.iter().zip(q_param.iter()).map(|(y, q_val)| y * q_val).sum::<f32>();
            }
            beta *= rho_i;

            let correction = alpha[i] - beta;

            for name in gradients.keys() {
                let s_i_param = self.s_history[i].get(name).unwrap();
                let q_param = q.get_mut(name).unwrap();

                for j in 0..q_param.len() {
                    q_param[j] += correction * s_i_param[j];
                }
            }
        }

        Ok(q)
    }

    /// Resets all optimizer state (step counter, curvature history, and the
    /// previous parameters and gradients).
    pub fn reset(&mut self) {
        self.step = 0;
        self.s_history.clear();
        self.y_history.clear();
        self.rho_history.clear();
        self.prev_params.clear();
        self.prev_grads.clear();
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_lbfgs_creation() {
        let optimizer = LBFGS::new(0.01);
        assert_eq!(optimizer.learning_rate, 0.01);
        assert_eq!(optimizer.history_size, 10);
        assert_eq!(optimizer.step, 0);
    }

    #[test]
    fn test_lbfgs_with_config() {
        let optimizer = LBFGS::with_config(0.1, 5, None, 10);
        assert_eq!(optimizer.learning_rate, 0.1);
        assert_eq!(optimizer.history_size, 5);
        assert_eq!(optimizer.max_iter, 10);
    }

    #[test]
    fn test_lbfgs_reset() {
        let mut optimizer = LBFGS::new(0.01);
        optimizer.step = 5;
        optimizer.reset();
        assert_eq!(optimizer.step, 0);
        assert!(optimizer.s_history.is_empty());
        assert!(optimizer.y_history.is_empty());
        assert!(optimizer.rho_history.is_empty());
    }

    #[test]
    fn test_lbfgs_first_step() -> Result<(), Box<dyn std::error::Error>> {
        let mut optimizer = LBFGS::new(0.01);
        let mut parameters = HashMap::new();
        let mut gradients = HashMap::new();

        let param_data = vec![1.0, 2.0, 3.0];
        let grad_data = vec![0.1, 0.2, 0.3];

        parameters.insert("param1".to_string(), Tensor::new(param_data.clone()).unwrap());
        gradients.insert("param1".to_string(), Tensor::new(grad_data.clone()).unwrap());

        optimizer.step(&mut parameters, &gradients).unwrap();

        assert_eq!(optimizer.step, 1);

        let updated_data = parameters.get("param1").unwrap().data()?;
        for i in 0..updated_data.len() {
            let expected = param_data[i] - 0.01 * grad_data[i];
            assert!((updated_data[i] - expected).abs() < 1e-6);
        }
        Ok(())
    }
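
    // Illustrative follow-up check (a sketch relying on the same Tensor::new usage
    // as the test above): after a second step with a changed gradient, exactly one
    // (s, y) curvature pair and its rho value should have been recorded.
    #[test]
    fn test_lbfgs_second_step_records_history() {
        let mut optimizer = LBFGS::new(0.1);
        let mut parameters = HashMap::new();
        parameters.insert("param1".to_string(), Tensor::new(vec![1.0, 2.0, 3.0]).unwrap());

        let mut gradients = HashMap::new();
        gradients.insert("param1".to_string(), Tensor::new(vec![0.5, 0.5, 0.5]).unwrap());
        optimizer.step(&mut parameters, &gradients).unwrap();

        // Use a different gradient so y = g_1 - g_0 is nonzero and rho is well-defined.
        gradients.insert("param1".to_string(), Tensor::new(vec![0.2, 0.1, 0.3]).unwrap());
        optimizer.step(&mut parameters, &gradients).unwrap();

        assert_eq!(optimizer.step, 2);
        assert_eq!(optimizer.s_history.len(), 1);
        assert_eq!(optimizer.y_history.len(), 1);
        assert_eq!(optimizer.rho_history.len(), 1);
    }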
}