trustformers_optim/second_order/lbfgs.rs

use anyhow::Result;
use std::collections::{HashMap, VecDeque};
use trustformers_core::tensor::Tensor;

/// Limited-memory Broyden-Fletcher-Goldfarb-Shanno (L-BFGS) optimizer.
///
/// L-BFGS is a quasi-Newton method that approximates curvature (inverse-Hessian)
/// information using only first-order gradients. It keeps a limited history of
/// parameter and gradient differences and combines them with the two-loop
/// recursion to build an implicit approximation of the inverse Hessian.
///
/// The line-search and tolerance fields are stored as configuration but are not
/// currently used by [`LBFGS::step`]; updates apply the fixed `learning_rate`.
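///
/// # Example
///
/// A minimal usage sketch (illustrative only; it assumes `Tensor::new(Vec<f32>)`
/// and `Tensor::data()` behave as exercised by the tests at the bottom of this
/// file):
///
/// ```ignore
/// use std::collections::HashMap;
/// use trustformers_core::tensor::Tensor;
///
/// let mut optimizer = LBFGS::new(0.1);
///
/// let mut parameters = HashMap::new();
/// parameters.insert("w".to_string(), Tensor::new(vec![1.0, 2.0, 3.0]).unwrap());
///
/// let mut gradients = HashMap::new();
/// gradients.insert("w".to_string(), Tensor::new(vec![0.1, 0.2, 0.3]).unwrap());
///
/// // Each call performs one in-place update of `parameters`.
/// optimizer.step(&mut parameters, &gradients).unwrap();
/// ```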
#[derive(Debug)]
pub struct LBFGS {
    pub learning_rate: f32,
    pub history_size: usize,
    pub line_search_fn: Option<LineSearchMethod>,
    pub max_iter: usize,
    pub tolerance_grad: f32,
    pub tolerance_change: f32,

    // Internal state
    pub step: usize,
    pub s_history: VecDeque<HashMap<String, Vec<f32>>>, // parameter differences
    pub y_history: VecDeque<HashMap<String, Vec<f32>>>, // gradient differences
    pub rho_history: VecDeque<f32>,                     // 1 / (y^T s)
    pub prev_params: HashMap<String, Vec<f32>>,
    pub prev_grads: HashMap<String, Vec<f32>>,
}

/// Line-search strategies selectable via `line_search_fn`.
///
/// Currently configuration only: [`LBFGS::step`] does not yet perform a line search.
#[derive(Debug, Clone)]
pub enum LineSearchMethod {
    None,
    StrongWolfe,
    Backtracking,
}

impl Default for LBFGS {
    fn default() -> Self {
        Self {
            learning_rate: 1.0,
            history_size: 10,
            line_search_fn: Some(LineSearchMethod::StrongWolfe),
            max_iter: 20,
            tolerance_grad: 1e-7,
            tolerance_change: 1e-9,
            step: 0,
            s_history: VecDeque::new(),
            y_history: VecDeque::new(),
            rho_history: VecDeque::new(),
            prev_params: HashMap::new(),
            prev_grads: HashMap::new(),
        }
    }
}

impl LBFGS {
    pub fn new(learning_rate: f32) -> Self {
        Self {
            learning_rate,
            ..Default::default()
        }
    }

    pub fn with_config(
        learning_rate: f32,
        history_size: usize,
        line_search_fn: Option<LineSearchMethod>,
        max_iter: usize,
    ) -> Self {
        Self {
            learning_rate,
            history_size,
            line_search_fn,
            max_iter,
            ..Default::default()
        }
    }

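    /// Performs one optimization step, updating `parameters` in place.
    ///
    /// The first call falls back to plain gradient descent because no curvature
    /// pairs exist yet. Subsequent calls record the parameter difference `s_k`
    /// and gradient difference `y_k`, keep at most `history_size` pairs, and use
    /// the two-loop recursion to approximate `H_k * grad` before stepping against
    /// that direction scaled by `learning_rate`.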
    pub fn step(
        &mut self,
        parameters: &mut HashMap<String, Tensor>,
        gradients: &HashMap<String, Tensor>,
    ) -> Result<()> {
        // First step - store current state
        if self.step == 0 {
            for (name, param) in parameters.iter() {
                self.prev_params.insert(name.clone(), param.data()?);
            }
            for (name, grad) in gradients.iter() {
                self.prev_grads.insert(name.clone(), grad.data()?);
            }

            // Simple gradient descent for first step
            for (name, param) in parameters.iter_mut() {
                let grad = gradients
                    .get(name)
                    .ok_or_else(|| anyhow::anyhow!("Missing gradient for parameter: {}", name))?;
                let mut param_data = param.data()?;
                let grad_data = grad.data()?;

                for i in 0..param_data.len() {
                    param_data[i] -= self.learning_rate * grad_data[i];
                }

                *param = Tensor::new(param_data)?;
            }

            self.step += 1;
            return Ok(());
        }

        // Subsequent steps - use L-BFGS
        let mut s_k = HashMap::new();
        let mut y_k = HashMap::new();

        // Compute parameter and gradient differences
        for (name, param) in parameters.iter() {
            let param_data = param.data()?;
            let prev_param = self.prev_params.get(name).unwrap();

            let s: Vec<f32> =
                param_data.iter().zip(prev_param.iter()).map(|(p, prev_p)| p - prev_p).collect();
            s_k.insert(name.clone(), s);
        }

        for (name, grad) in gradients.iter() {
            let grad_data = grad.data()?;
            let prev_grad = self.prev_grads.get(name).unwrap();

            let y: Vec<f32> =
                grad_data.iter().zip(prev_grad.iter()).map(|(g, prev_g)| g - prev_g).collect();
            y_k.insert(name.clone(), y);
        }

        // Compute rho = 1 / (y^T s)
        let mut rho = 0.0;
        for name in parameters.keys() {
            let s = s_k.get(name).unwrap();
            let y = y_k.get(name).unwrap();

            rho += s.iter().zip(y.iter()).map(|(s_i, y_i)| s_i * y_i).sum::<f32>();
        }

        // Reject the pair unless y^T s is sufficiently positive; a non-positive
        // curvature pair would break positive definiteness of the implicit
        // inverse-Hessian approximation.
        if rho <= 1e-10 {
            self.step += 1;
            return Ok(());
        }

        rho = 1.0 / rho;

        // Store in history
        self.s_history.push_back(s_k);
        self.y_history.push_back(y_k);
        self.rho_history.push_back(rho);

        // Maintain history size
        if self.s_history.len() > self.history_size {
            self.s_history.pop_front();
            self.y_history.pop_front();
            self.rho_history.pop_front();
        }

        // Compute search direction using two-loop recursion
        let search_direction = self.compute_search_direction(gradients)?;

        // Apply update
        for (name, param) in parameters.iter_mut() {
            let direction = search_direction.get(name).unwrap();
            let mut param_data = param.data()?;

            for i in 0..param_data.len() {
                param_data[i] -= self.learning_rate * direction[i];
            }

            *param = Tensor::new(param_data)?;
        }

        // Update stored state
        for (name, param) in parameters.iter() {
            self.prev_params.insert(name.clone(), param.data()?);
        }
        for (name, grad) in gradients.iter() {
            self.prev_grads.insert(name.clone(), grad.data()?);
        }

        self.step += 1;
        Ok(())
    }

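    /// Estimates `H_k * grad` with the standard L-BFGS two-loop recursion over the
    /// stored `(s_i, y_i)` pairs.
    ///
    /// Backward pass: `alpha_i = rho_i * s_i^T q`, then `q -= alpha_i * y_i`.
    /// The intermediate result is scaled by `gamma = s^T y / y^T y` from the most
    /// recent pair (i.e. `H_0 = gamma * I`). Forward pass: `beta = rho_i * y_i^T q`,
    /// then `q += (alpha_i - beta) * s_i`. The caller negates the result implicitly
    /// by subtracting `learning_rate * q` from the parameters.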
    fn compute_search_direction(
        &self,
        gradients: &HashMap<String, Tensor>,
    ) -> Result<HashMap<String, Vec<f32>>> {
        let mut q: HashMap<String, Vec<f32>> = HashMap::new();

        // Initialize q with current gradients
        for (name, grad) in gradients.iter() {
            q.insert(name.clone(), grad.data()?);
        }

        let history_len = self.s_history.len();
        let mut alpha = vec![0.0; history_len];

        // First loop (backward)
        for i in (0..history_len).rev() {
            let rho_i = self.rho_history[i];
            let s_i = &self.s_history[i];

            let mut alpha_i = 0.0;
            for name in gradients.keys() {
                let s_i_param = s_i.get(name).unwrap();
                let q_param = q.get(name).unwrap();

                alpha_i +=
                    s_i_param.iter().zip(q_param.iter()).map(|(s, q_val)| s * q_val).sum::<f32>();
            }
            alpha_i *= rho_i;
            alpha[i] = alpha_i;

            // Update q
            for name in gradients.keys() {
                let y_i_param = self.y_history[i].get(name).unwrap();
                let q_param = q.get_mut(name).unwrap();

                for j in 0..q_param.len() {
                    q_param[j] -= alpha_i * y_i_param[j];
                }
            }
        }

        // Scale by the initial inverse-Hessian approximation H_0 = gamma * I,
        // where gamma = s^T y / y^T y for the most recent curvature pair.
        if !self.s_history.is_empty() {
            let recent_idx = self.s_history.len() - 1;
            let recent_s = &self.s_history[recent_idx];
            let recent_y = &self.y_history[recent_idx];

            let mut s_dot_y = 0.0;
            let mut y_dot_y = 0.0;

            for name in gradients.keys() {
                let s_param = recent_s.get(name).unwrap();
                let y_param = recent_y.get(name).unwrap();

                s_dot_y += s_param.iter().zip(y_param.iter()).map(|(s, y)| s * y).sum::<f32>();
                y_dot_y += y_param.iter().map(|y| y * y).sum::<f32>();
            }

            if y_dot_y > 1e-10 {
                let gamma = s_dot_y / y_dot_y;
                for q_param in q.values_mut() {
                    for val in q_param.iter_mut() {
                        *val *= gamma;
                    }
                }
            }
        }

        // Second loop (forward)
        for i in 0..history_len {
            let rho_i = self.rho_history[i];
            let y_i = &self.y_history[i];

            let mut beta = 0.0;
            for name in gradients.keys() {
                let y_i_param = y_i.get(name).unwrap();
                let q_param = q.get(name).unwrap();

                beta +=
                    y_i_param.iter().zip(q_param.iter()).map(|(y, q_val)| y * q_val).sum::<f32>();
            }
            beta *= rho_i;

            let correction = alpha[i] - beta;

            // Update q
            for name in gradients.keys() {
                let s_i_param = self.s_history[i].get(name).unwrap();
                let q_param = q.get_mut(name).unwrap();

                for j in 0..q_param.len() {
                    q_param[j] += correction * s_i_param[j];
                }
            }
        }

        Ok(q)
    }

    /// Clears all optimizer state so the next call to `step` behaves like the first.
    pub fn reset(&mut self) {
        self.step = 0;
        self.s_history.clear();
        self.y_history.clear();
        self.rho_history.clear();
        self.prev_params.clear();
        self.prev_grads.clear();
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_lbfgs_creation() {
        let optimizer = LBFGS::new(0.01);
        assert_eq!(optimizer.learning_rate, 0.01);
        assert_eq!(optimizer.history_size, 10);
        assert_eq!(optimizer.step, 0);
    }

    #[test]
    fn test_lbfgs_with_config() {
        let optimizer = LBFGS::with_config(0.1, 5, None, 10);
        assert_eq!(optimizer.learning_rate, 0.1);
        assert_eq!(optimizer.history_size, 5);
        assert_eq!(optimizer.max_iter, 10);
    }

    #[test]
    fn test_lbfgs_reset() {
        let mut optimizer = LBFGS::new(0.01);
        optimizer.step = 5;
        optimizer.reset();
        assert_eq!(optimizer.step, 0);
        assert!(optimizer.s_history.is_empty());
        assert!(optimizer.y_history.is_empty());
        assert!(optimizer.rho_history.is_empty());
    }

    #[test]
    fn test_lbfgs_first_step() -> Result<(), Box<dyn std::error::Error>> {
        let mut optimizer = LBFGS::new(0.01);
        let mut parameters = HashMap::new();
        let mut gradients = HashMap::new();

        let param_data = vec![1.0, 2.0, 3.0];
        let grad_data = vec![0.1, 0.2, 0.3];

        parameters.insert(
            "param1".to_string(),
            Tensor::new(param_data.clone()).unwrap(),
        );
        gradients.insert(
            "param1".to_string(),
            Tensor::new(grad_data.clone()).unwrap(),
        );

        optimizer.step(&mut parameters, &gradients).unwrap();

        assert_eq!(optimizer.step, 1);

        let updated_data = parameters.get("param1").unwrap().data()?;
        for i in 0..updated_data.len() {
            let expected = param_data[i] - 0.01 * grad_data[i];
            assert!((updated_data[i] - expected).abs() < 1e-6);
        }
        Ok(())
    }
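
    // Illustrative addition (a sketch, not from the original suite): after a second
    // step, exactly one curvature pair (s, y) should be recorded, since y^T s > 0
    // when the gradient is the parameter vector itself (f(w) = 0.5 * ||w||^2).
    // Assumes the same `Tensor::new` / `Tensor::data` behavior as the tests above.
    #[test]
    fn test_lbfgs_second_step_records_history() -> Result<(), Box<dyn std::error::Error>> {
        let mut optimizer = LBFGS::new(0.1);

        let mut parameters = HashMap::new();
        parameters.insert("w".to_string(), Tensor::new(vec![1.0, 2.0, 3.0]).unwrap());

        // Gradient of f(w) = 0.5 * ||w||^2 is w itself; re-read it after each update.
        let grad1 = parameters.get("w").unwrap().data()?;
        let mut gradients = HashMap::new();
        gradients.insert("w".to_string(), Tensor::new(grad1).unwrap());
        optimizer.step(&mut parameters, &gradients).unwrap();

        let grad2 = parameters.get("w").unwrap().data()?;
        gradients.insert("w".to_string(), Tensor::new(grad2).unwrap());
        optimizer.step(&mut parameters, &gradients).unwrap();

        assert_eq!(optimizer.step, 2);
        assert_eq!(optimizer.s_history.len(), 1);
        assert_eq!(optimizer.y_history.len(), 1);
        Ok(())
    }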
}