use scirs2_core::ndarray::{Array1, Array2, Axis, ScalarOperand};
use scirs2_core::numeric::{Float, FromPrimitive};
use std::fmt::{Debug, Display};

use crate::error::{Result, TimeSeriesError};

/// Options controlling the gradient-based optimizers.
#[derive(Debug, Clone)]
pub struct OptimizationOptions<F> {
    /// Maximum number of iterations.
    pub max_iter: usize,
    /// Convergence tolerance on the change in objective value.
    pub tolerance: F,
    /// Initial step size (reserved; the Armijo search currently starts from a unit step).
    pub initial_step: F,
    /// Sufficient-decrease (Armijo) constant for the backtracking line search.
    pub line_search_alpha: F,
    /// Step-shrinkage factor for the backtracking line search.
    pub line_search_beta: F,
    /// Convergence tolerance on the gradient norm.
    pub grad_tolerance: F,
}

impl<F: Float + FromPrimitive> Default for OptimizationOptions<F> {
    fn default() -> Self {
        Self {
            max_iter: 1000,
            tolerance: F::from(1e-8).unwrap(),
            initial_step: F::from(0.1).unwrap(),
            line_search_alpha: F::from(0.3).unwrap(),
            line_search_beta: F::from(0.8).unwrap(),
            grad_tolerance: F::from(1e-6).unwrap(),
        }
    }
}

/// Result of an optimization run.
#[derive(Debug, Clone)]
pub struct OptimizationResult<F> {
    /// Final parameter values.
    pub x: Array1<F>,
    /// Final objective value.
    pub fval: F,
    /// Number of iterations performed.
    pub iterations: usize,
    /// Whether the run converged before `max_iter` was reached.
    pub converged: bool,
    /// Gradient norm at the final point.
    pub grad_norm: F,
}

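/// Limited-memory BFGS (L-BFGS) optimizer using a two-loop recursion over a
/// bounded history of curvature pairs.
///
/// A minimal usage sketch (illustrative only; marked `ignore` because the
/// path at which this type is re-exported is not assumed here):
///
/// ```ignore
/// use scirs2_core::ndarray::array;
///
/// let mut opt = LBFGSOptimizer::new(OptimizationOptions::default());
/// let result = opt
///     .optimize(|x| x.dot(x), |x| 2.0 * x, &array![1.0, 2.0])
///     .unwrap();
/// assert!(result.converged);
/// ```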
pub struct LBFGSOptimizer<F> {
    /// Optimization options.
    options: OptimizationOptions<F>,
    /// Number of curvature pairs retained.
    memory_size: usize,
    /// History of position differences `s_k = x_{k+1} - x_k`.
    s_history: Vec<Array1<F>>,
    /// History of gradient differences `y_k = g_{k+1} - g_k`.
    y_history: Vec<Array1<F>>,
    /// History of `rho_k = 1 / (y_k . s_k)`.
    rho_history: Vec<F>,
}

impl<F> LBFGSOptimizer<F>
where
    F: Float + FromPrimitive + Debug + Display + ScalarOperand,
{
    /// Creates a new L-BFGS optimizer with the given options and a default
    /// history size of 10 curvature pairs.
    pub fn new(options: OptimizationOptions<F>) -> Self {
        Self {
            options,
            memory_size: 10,
            s_history: Vec::new(),
            y_history: Vec::new(),
            rho_history: Vec::new(),
        }
    }

    /// Minimizes `f` starting from `x0`, using `grad` to evaluate gradients.
    pub fn optimize<Func, Grad>(
        &mut self,
        f: Func,
        grad: Grad,
        x0: &Array1<F>,
    ) -> Result<OptimizationResult<F>>
    where
        Func: Fn(&Array1<F>) -> F,
        Grad: Fn(&Array1<F>) -> Array1<F>,
    {
        let mut x = x0.clone();
        let mut g = grad(&x);
        let mut fval = f(&x);

        self.s_history.clear();
        self.y_history.clear();
        self.rho_history.clear();

        for iter in 0..self.options.max_iter {
            // Convergence check on the gradient norm.
            let grad_norm = g.dot(&g).sqrt();
            if grad_norm < self.options.grad_tolerance {
                return Ok(OptimizationResult {
                    x,
                    fval,
                    iterations: iter,
                    converged: true,
                    grad_norm,
                });
            }

            // Search direction from the two-loop recursion.
            let d = self.compute_direction(&g)?;

            // Backtracking line search along `d`.
            let alpha = line_search_armijo(&x, &d, &f, &grad, &self.options)?;

            // Take the step and evaluate at the new point.
            let x_new = &x + &(&d * alpha);
            let g_new = grad(&x_new);
            let fval_new = f(&x_new);

            // Curvature pair for the L-BFGS history.
            let s = &x_new - &x;
            let y = &g_new - &g;
            let sy = s.dot(&y);

            // Store the pair only when the curvature condition holds;
            // otherwise rho = 1 / (y . s) is ill-defined or negative.
            if sy > F::from(1e-10).unwrap() {
                let rho = F::one() / sy;

                if self.s_history.len() >= self.memory_size {
                    self.s_history.remove(0);
                    self.y_history.remove(0);
                    self.rho_history.remove(0);
                }

                self.s_history.push(s);
                self.y_history.push(y);
                self.rho_history.push(rho);
            }

            // Converged if the objective change falls below tolerance.
            if (fval - fval_new).abs() < self.options.tolerance {
                return Ok(OptimizationResult {
                    x: x_new,
                    fval: fval_new,
                    iterations: iter + 1,
                    converged: true,
                    grad_norm,
                });
            }

            x = x_new;
            g = g_new;
            fval = fval_new;
        }

        Ok(OptimizationResult {
            x,
            fval,
            iterations: self.options.max_iter,
            converged: false,
            grad_norm: g.dot(&g).sqrt(),
        })
    }

    /// Computes the search direction via the standard L-BFGS two-loop
    /// recursion over the stored curvature pairs.
    fn compute_direction(&self, g: &Array1<F>) -> Result<Array1<F>> {
        // With no history yet, fall back to steepest descent.
        if self.s_history.is_empty() {
            return Ok(g.mapv(|x| -x));
        }

        let mut q = g.clone();
        let mut alpha = vec![F::zero(); self.s_history.len()];

        // First loop: newest to oldest pair.
        for i in (0..self.s_history.len()).rev() {
            alpha[i] = self.rho_history[i] * self.s_history[i].dot(&q);
            q = &q - &(&self.y_history[i] * alpha[i]);
        }

        // Initial Hessian scaling: gamma = (s_k . y_k) / (y_k . y_k).
        let s_last = self.s_history.last().unwrap();
        let y_last = self.y_history.last().unwrap();
        let gamma = s_last.dot(y_last) / y_last.dot(y_last);
        let mut r = &q * gamma;

        // Second loop: oldest to newest pair.
        for (i, alpha_val) in alpha.iter().enumerate() {
            let beta = self.rho_history[i] * self.y_history[i].dot(&r);
            r = &r + &(&self.s_history[i] * (*alpha_val - beta));
        }

        Ok(r.mapv(|x| -x))
    }
}

/// BFGS optimizer maintaining a dense approximation of the inverse Hessian.
pub struct BFGSOptimizer<F> {
    /// Optimization options.
    options: OptimizationOptions<F>,
    /// Inverse-Hessian approximation, lazily initialized to the identity.
    h_inv: Option<Array2<F>>,
}

impl<F> BFGSOptimizer<F>
where
    F: Float + FromPrimitive + Debug + Display + ScalarOperand,
{
    /// Creates a new BFGS optimizer with the given options.
    pub fn new(options: OptimizationOptions<F>) -> Self {
        Self {
            options,
            h_inv: None,
        }
    }

    /// Minimizes `f` starting from `x0`, using `grad` to evaluate gradients.
    pub fn optimize<Func, Grad>(
        &mut self,
        f: Func,
        grad: Grad,
        x0: &Array1<F>,
    ) -> Result<OptimizationResult<F>>
    where
        Func: Fn(&Array1<F>) -> F,
        Grad: Fn(&Array1<F>) -> Array1<F>,
    {
        let n = x0.len();
        let mut x = x0.clone();
        let mut g = grad(&x);
        let mut fval = f(&x);

        // Lazily initialize the inverse-Hessian approximation to the identity.
        if self.h_inv.is_none() {
            self.h_inv = Some(Array2::eye(n));
        }
        let h_inv = self.h_inv.as_mut().unwrap();

        for iter in 0..self.options.max_iter {
            // Convergence check on the gradient norm.
            let grad_norm = g.dot(&g).sqrt();
            if grad_norm < self.options.grad_tolerance {
                return Ok(OptimizationResult {
                    x,
                    fval,
                    iterations: iter,
                    converged: true,
                    grad_norm,
                });
            }

            // Quasi-Newton direction: d = -H * g.
            let d = -h_inv.dot(&g);

            // Backtracking line search along `d`.
            let alpha = line_search_armijo(&x, &d, &f, &grad, &self.options)?;

            // Take the step and evaluate at the new point.
            let x_new = &x + &(&d * alpha);
            let g_new = grad(&x_new);
            let fval_new = f(&x_new);

            // Curvature pair for the BFGS update.
            let s = &x_new - &x;
            let y = &g_new - &g;
            let sy = s.dot(&y);

            // Update only when the curvature condition y . s > 0 holds.
            if sy > F::from(1e-8).unwrap() {
                let rho = F::one() / sy;

                // Rank-one outer products needed for the update.
                let s_col = s.clone().insert_axis(Axis(1));
                let s_row = s.insert_axis(Axis(0));
                let y_col = y.clone().insert_axis(Axis(1));
                let y_row = y.insert_axis(Axis(0));
                let sy_outer = s_col.dot(&y_row);
                let ys_outer = y_col.dot(&s_row);
                let ss_outer = s_col.dot(&s_row);

                let i_minus_rho_sy = Array2::eye(n) - &sy_outer * rho;
                let i_minus_rho_ys = Array2::eye(n) - &ys_outer * rho;

                // BFGS inverse-Hessian update:
                // H <- (I - rho s y^T) H (I - rho y s^T) + rho s s^T.
                *h_inv = i_minus_rho_sy.dot(h_inv).dot(&i_minus_rho_ys) + &ss_outer * rho;
            }

            // Converged if the objective change falls below tolerance.
            if (fval - fval_new).abs() < self.options.tolerance {
                return Ok(OptimizationResult {
                    x: x_new,
                    fval: fval_new,
                    iterations: iter + 1,
                    converged: true,
                    grad_norm,
                });
            }

            x = x_new;
            g = g_new;
            fval = fval_new;
        }

        Ok(OptimizationResult {
            x,
            fval,
            iterations: self.options.max_iter,
            converged: false,
            grad_norm: g.dot(&g).sqrt(),
        })
    }
}

/// Backtracking line search enforcing the Armijo sufficient-decrease
/// condition: `f(x + alpha d) <= f(x) + c * alpha * (g . d)`.
fn line_search_armijo<F, Func, Grad>(
    x: &Array1<F>,
    d: &Array1<F>,
    f: &Func,
    grad: &Grad,
    options: &OptimizationOptions<F>,
) -> Result<F>
where
    F: Float + FromPrimitive + Debug + Display + ScalarOperand,
    Func: Fn(&Array1<F>) -> F,
    Grad: Fn(&Array1<F>) -> Array1<F>,
{
    let mut alpha = F::one();
    let f0 = f(x);
    let g0 = grad(x);
    let dg0 = g0.dot(d);

    // A valid descent direction must have a non-positive directional derivative.
    if dg0 > F::zero() {
        return Err(TimeSeriesError::ComputationError(
            "Invalid search direction".to_string(),
        ));
    }

    // Shrink the step until sufficient decrease is achieved.
    while alpha > F::from(1e-10).unwrap() {
        let x_new = x + &(d * alpha);
        let f_new = f(&x_new);

        if f_new <= f0 + options.line_search_alpha * alpha * dg0 {
            return Ok(alpha);
        }

        alpha = alpha * options.line_search_beta;
    }

    Ok(F::from(1e-10).unwrap())
}

#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    #[test]
    fn test_lbfgs_optimizer() {
        // Minimize the convex quadratic f(x) = x . x, whose minimum is at 0.
        let f = |x: &Array1<f64>| x.dot(x);
        let grad = |x: &Array1<f64>| 2.0 * x;

        let mut optimizer = LBFGSOptimizer::new(OptimizationOptions::default());
        let x0 = array![1.0, 2.0, 3.0];
        let result = optimizer.optimize(f, grad, &x0).unwrap();

        assert!(result.converged);
        assert!(result.fval < 1e-6);
        assert!(result.grad_norm < 1e-6);
    }
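
    // A companion check (a sketch added here, not part of the original suite)
    // for the dense BFGS optimizer on the same convex quadratic as the L-BFGS
    // test above; the name `test_bfgs_optimizer` is introduced for it.
    #[test]
    fn test_bfgs_optimizer() {
        let f = |x: &Array1<f64>| x.dot(x);
        let grad = |x: &Array1<f64>| 2.0 * x;

        let mut optimizer = BFGSOptimizer::new(OptimizationOptions::default());
        let x0 = array![1.0, 2.0, 3.0];
        let result = optimizer.optimize(f, grad, &x0).unwrap();

        assert!(result.converged);
        assert!(result.fval < 1e-6);
        assert!(result.grad_norm < 1e-6);
    }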

    #[test]
    fn test_rosenbrock_function() {
        // Rosenbrock function, with its minimum at (1, 1).
        let f = |x: &Array1<f64>| {
            let a = 1.0 - x[0];
            let b = x[1] - x[0] * x[0];
            a * a + 100.0 * b * b
        };

        let grad = |x: &Array1<f64>| {
            let dx = -2.0 * (1.0 - x[0]) - 400.0 * x[0] * (x[1] - x[0] * x[0]);
            let dy = 200.0 * (x[1] - x[0] * x[0]);
            array![dx, dy]
        };

        let mut optimizer = LBFGSOptimizer::new(OptimizationOptions {
            max_iter: 1000,
            tolerance: 1e-8,
            grad_tolerance: 1e-6,
            ..Default::default()
        });

        let x0 = array![-1.0, 1.0];
        let result = optimizer.optimize(f, grad, &x0).unwrap();

        assert!(result.converged);
        assert!((result.x[0] - 1.0).abs() < 0.01);
        assert!((result.x[1] - 1.0).abs() < 0.01);
    }
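
    // A small added check (sketch): the Armijo line search should reject a
    // direction with a positive directional derivative (an ascent direction).
    // The name `test_line_search_rejects_ascent_direction` is introduced here.
    #[test]
    fn test_line_search_rejects_ascent_direction() {
        let f = |x: &Array1<f64>| x.dot(x);
        let grad = |x: &Array1<f64>| 2.0 * x;

        let x = array![1.0, 1.0];
        // Points uphill: g . d = 4.0 > 0 at x = (1, 1).
        let d = array![1.0, 1.0];
        let options = OptimizationOptions::<f64>::default();

        assert!(line_search_armijo(&x, &d, &f, &grad, &options).is_err());
    }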
}