1use crate::error::OptimizeError;
8use crate::simd_ops::{SimdConfig, SimdVectorOps};
9use crate::unconstrained::line_search::backtracking_line_search;
10use crate::unconstrained::{Bounds, OptimizeResult, Options};
11use scirs2_core::ndarray::{Array1, Array2, ArrayView1};
12
/// Options controlling the SIMD-accelerated BFGS optimizer.
#[derive(Debug, Clone)]
pub struct SimdBfgsOptions {
    /// Base optimization options (iteration limit, tolerances, bounds, ...).
    pub base_options: Options,
    /// Explicit SIMD configuration; when `None`, capabilities are
    /// auto-detected at runtime via `SimdConfig::detect()`.
    pub simd_config: Option<SimdConfig>,
    /// Use the SIMD code path regardless of problem size or detected support.
    pub force_simd: bool,
    /// Minimum problem dimension at which the SIMD path is taken
    /// (ignored when `force_simd` is set).
    pub simd_threshold: usize,
}
25
26impl Default for SimdBfgsOptions {
27 fn default() -> Self {
28 Self {
29 base_options: Options::default(),
30 simd_config: None,
31 force_simd: false,
32 simd_threshold: 8, }
34 }
35}
36
/// Internal mutable state of one SIMD BFGS run.
struct SimdBfgsState {
    /// Current inverse-Hessian approximation (initialized to identity).
    hessian_inv: Array2<f64>,
    /// SIMD-capable vector primitives (dot, add, sub, scale, norm, matvec).
    simd_ops: SimdVectorOps,
    /// Gradient at the current iterate.
    gradient: Array1<f64>,
    /// Gradient at the previous iterate (for the quasi-Newton update).
    prev_gradient: Array1<f64>,
    /// Current iterate.
    position: Array1<f64>,
    /// Previous iterate.
    prev_position: Array1<f64>,
    /// Objective value at the current iterate.
    function_value: f64,
    /// Count of objective-function evaluations.
    nfev: usize,
    /// Count of gradient evaluations.
    njev: usize,
}
58
59impl SimdBfgsState {
60 fn new(x0: &Array1<f64>, simd_config: Option<SimdConfig>) -> Self {
61 let n = x0.len();
62 let simd_ops = if let Some(config) = simd_config {
63 SimdVectorOps::with_config(config)
64 } else {
65 SimdVectorOps::new()
66 };
67
68 Self {
69 hessian_inv: Array2::eye(n),
70 simd_ops,
71 gradient: Array1::zeros(n),
72 prev_gradient: Array1::zeros(n),
73 position: x0.clone(),
74 prev_position: x0.clone(),
75 function_value: 0.0,
76 nfev: 0,
77 njev: 0,
78 }
79 }
80
81 fn update_hessian(&mut self) {
83 let n = self.position.len();
84
85 let s = self
87 .simd_ops
88 .sub(&self.position.view(), &self.prev_position.view());
89
90 let y = self
92 .simd_ops
93 .sub(&self.gradient.view(), &self.prev_gradient.view());
94
95 let s_dot_y = self.simd_ops.dot_product(&s.view(), &y.view());
97
98 if s_dot_y.abs() < 1e-14 {
99 return;
101 }
102
103 let rho = 1.0 / s_dot_y;
104
105 let hy = self.matrix_vector_multiply_simd(&self.hessian_inv.view(), &y.view());
114
115 let ythhy = self.simd_ops.dot_product(&y.view(), &hy.view());
117
118 for i in 0..n {
122 for j in 0..n {
123 let hess_update = -hy[i] * hy[j] / ythhy + rho * s[i] * s[j];
124 self.hessian_inv[[i, j]] += hess_update;
125 }
126 }
127 }
128
129 fn matrix_vector_multiply_simd(
131 &self,
132 matrix: &scirs2_core::ndarray::ArrayView2<f64>,
133 vector: &ArrayView1<f64>,
134 ) -> Array1<f64> {
135 self.simd_ops.matvec(matrix, vector)
136 }
137
138 #[allow(dead_code)]
140 fn vector_matrix_multiply_simd(
141 &self,
142 vector: &ArrayView1<f64>,
143 matrix: &scirs2_core::ndarray::ArrayView2<f64>,
144 ) -> Array1<f64> {
145 let n = matrix.ncols();
146 let mut result = Array1::zeros(n);
147
148 for j in 0..n {
150 let column = matrix.column(j);
151 result[j] = self.simd_ops.dot_product(vector, &column);
152 }
153
154 result
155 }
156
157 fn compute_search_direction(&self) -> Array1<f64> {
159 let neg_grad = self.simd_ops.scale(-1.0, &self.gradient.view());
161 self.matrix_vector_multiply_simd(&self.hessian_inv.view(), &neg_grad.view())
162 }
163}
164
/// Minimize `fun` starting from `x0` with a BFGS method whose vector
/// algebra is routed through the SIMD backend.
///
/// Falls back to the scalar BFGS implementation when SIMD is neither
/// forced nor worthwhile (problem smaller than `simd_threshold`, or no
/// hardware SIMD support). Gradients are approximated by forward finite
/// differences, so `nfev` grows by roughly `n + 1` per iteration.
///
/// Convergence tests, in order per iteration: gradient norm (`gtol`),
/// absolute objective change (`ftol`), and step norm (`xtol`); hitting
/// `max_iter` returns `success: false`.
#[allow(dead_code)]
pub fn minimize_simd_bfgs<F>(
    mut fun: F,
    x0: Array1<f64>,
    options: Option<SimdBfgsOptions>,
) -> Result<OptimizeResult<f64>, OptimizeError>
where
    F: FnMut(&ArrayView1<f64>) -> f64 + Clone,
{
    let options = options.unwrap_or_default();
    let n = x0.len();

    // Take the SIMD path when forced, or when the problem is large enough
    // and the configured (or runtime-detected) hardware actually has SIMD.
    let use_simd = options.force_simd
        || (n >= options.simd_threshold
            && options
                .simd_config
                .as_ref()
                .map_or_else(|| SimdConfig::detect().has_simd(), |c| c.has_simd()));

    if !use_simd {
        // Scalar fallback; base options are forwarded unchanged.
        return crate::unconstrained::bfgs::minimize_bfgs(fun, x0, &options.base_options);
    }

    let mut state = SimdBfgsState::new(&x0, options.simd_config);

    // Initial objective value and finite-difference gradient.
    state.function_value = fun(&state.position.view());
    state.nfev += 1;

    state.gradient = compute_gradient_finite_diff(&mut fun, &state.position, &mut state.nfev);
    state.njev += 1;

    let mut prev_f = state.function_value;

    for iteration in 0..options.base_options.max_iter {
        // Convergence on gradient norm.
        let grad_norm = state.simd_ops.norm(&state.gradient.view());
        if grad_norm < options.base_options.gtol {
            return Ok(OptimizeResult {
                x: state.position,
                fun: state.function_value,
                nit: iteration,
                func_evals: state.nfev,
                nfev: state.nfev,
                jacobian: Some(state.gradient),
                hessian: Some(state.hessian_inv),
                success: true,
                message: "SIMD BFGS optimization terminated successfully.".to_string(),
            });
        }

        // Convergence on absolute objective change between iterations
        // (prev_f still holds the value from before the previous step).
        if iteration > 0 {
            let f_change = (prev_f - state.function_value).abs();
            if f_change < options.base_options.ftol {
                return Ok(OptimizeResult {
                    x: state.position,
                    fun: state.function_value,
                    nit: iteration,
                    func_evals: state.nfev,
                    nfev: state.nfev,
                    jacobian: Some(state.gradient),
                    hessian: Some(state.hessian_inv),
                    success: true,
                    message: "SIMD BFGS optimization terminated successfully.".to_string(),
                });
            }
        }

        // Snapshot the current point before stepping; used by the
        // quasi-Newton update and the xtol test below.
        state.prev_position = state.position.clone();
        state.prev_gradient = state.gradient.clone();
        prev_f = state.function_value;

        let search_direction = state.compute_search_direction();

        // If the direction is not a descent direction (the Hessian
        // estimate has gone stale/indefinite), reset it to identity and
        // take a small fixed steepest-descent step instead of line-searching.
        let directional_derivative = state
            .simd_ops
            .dot_product(&state.gradient.view(), &search_direction.view());
        if directional_derivative >= 0.0 {
            state.hessian_inv = Array2::eye(n);
            let neg_grad = state.simd_ops.scale(-1.0, &state.gradient.view());
            // position <- position - 0.001 * gradient
            state.position = state.simd_ops.add(
                &state.position.view(),
                &state.simd_ops.scale(0.001, &neg_grad.view()).view(),
            );
        } else {
            // Backtracking line search along the BFGS direction
            // (unit initial step, Armijo c1 = 1e-4, shrink factor 0.9).
            let (step_size, line_search_nfev) = backtracking_line_search(
                &mut |x| fun(x),
                &state.position.view(),
                state.function_value,
                &search_direction.view(),
                &state.gradient.view(),
                1.0,
                1e-4,
                0.9,
                options.base_options.bounds.as_ref(),
            );
            state.nfev += line_search_nfev as usize;

            let step_vec = state.simd_ops.scale(step_size, &search_direction.view());
            state.position = state.simd_ops.add(&state.position.view(), &step_vec.view());
        }

        // Project the new iterate back onto the feasible box, if any.
        if let Some(ref bounds) = options.base_options.bounds {
            apply_bounds(&mut state.position, bounds);
        }

        // Re-evaluate objective and gradient at the new point.
        state.function_value = fun(&state.position.view());
        state.nfev += 1;

        state.gradient = compute_gradient_finite_diff(&mut fun, &state.position, &mut state.nfev);
        state.njev += 1;

        // NOTE(review): the curvature pair from iteration 0 is discarded
        // even though prev_position/prev_gradient are valid there too —
        // confirm this skip is intentional.
        if iteration > 0 {
            state.update_hessian();
        }

        // Convergence on step length.
        let position_change = state
            .simd_ops
            .sub(&state.position.view(), &state.prev_position.view());
        let position_change_norm = state.simd_ops.norm(&position_change.view());
        if position_change_norm < options.base_options.xtol {
            return Ok(OptimizeResult {
                x: state.position,
                fun: state.function_value,
                nit: iteration + 1,
                func_evals: state.nfev,
                nfev: state.nfev,
                jacobian: Some(state.gradient),
                hessian: Some(state.hessian_inv),
                success: true,
                message: "SIMD BFGS optimization terminated successfully.".to_string(),
            });
        }
    }

    // Iteration budget exhausted without meeting any tolerance.
    Ok(OptimizeResult {
        x: state.position,
        fun: state.function_value,
        nit: options.base_options.max_iter,
        func_evals: state.nfev,
        nfev: state.nfev,
        jacobian: Some(state.gradient),
        hessian: Some(state.hessian_inv),
        success: false,
        message: "Maximum iterations reached in SIMD BFGS.".to_string(),
    })
}
329
330#[allow(dead_code)]
332fn compute_gradient_finite_diff<F>(fun: &mut F, x: &Array1<f64>, nfev: &mut usize) -> Array1<f64>
333where
334 F: FnMut(&ArrayView1<f64>) -> f64,
335{
336 let n = x.len();
337 let mut grad = Array1::zeros(n);
338 let eps = (f64::EPSILON).sqrt();
339 let f0 = fun(&x.view());
340 *nfev += 1;
341
342 for i in 0..n {
343 let mut x_plus = x.clone();
344 x_plus[i] += eps;
345 let f_plus = fun(&x_plus.view());
346 *nfev += 1;
347
348 grad[i] = (f_plus - f0) / eps;
349 }
350
351 grad
352}
353
354#[allow(dead_code)]
356fn apply_bounds(x: &mut Array1<f64>, bounds: &Bounds) {
357 for (i, xi) in x.iter_mut().enumerate() {
358 if i < bounds.lower.len() {
359 if let Some(lb) = bounds.lower[i] {
360 if *xi < lb {
361 *xi = lb;
362 }
363 }
364 }
365 if i < bounds.upper.len() {
366 if let Some(ub) = bounds.upper[i] {
367 if *xi > ub {
368 *xi = ub;
369 }
370 }
371 }
372 }
373}
374
375#[allow(dead_code)]
377pub fn minimize_simd_bfgs_default<F>(
378 fun: F,
379 x0: Array1<f64>,
380) -> Result<OptimizeResult<f64>, OptimizeError>
381where
382 F: FnMut(&ArrayView1<f64>) -> f64 + Clone,
383{
384 minimize_simd_bfgs(fun, x0, None)
385}
386
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use scirs2_core::ndarray::array;

    /// Convex quadratic f(x) = sum x_i^2: the optimizer must reach the
    /// origin with near-zero objective value.
    #[test]
    fn test_simd_bfgs_quadratic() {
        let fun = |x: &ArrayView1<f64>| x.iter().map(|&xi| xi.powi(2)).sum::<f64>();

        let x0 = array![1.0, 2.0, 3.0, 4.0];
        let options = SimdBfgsOptions {
            base_options: Options {
                max_iter: 100,
                gtol: 1e-8,
                ..Default::default()
            },
            // Force SIMD so the SIMD path is exercised even for n = 4.
            force_simd: true,
            ..Default::default()
        };

        let result = minimize_simd_bfgs(fun, x0, Some(options)).unwrap();

        assert!(result.success);
        for &xi in result.x.iter() {
            assert_abs_diff_eq!(xi, 0.0, epsilon = 1e-6);
        }
        assert!(result.fun < 1e-10);
    }

    /// Multidimensional Rosenbrock function: global minimum at (1, ..., 1).
    #[test]
    fn test_simd_bfgs_rosenbrock() {
        let rosenbrock = |x: &ArrayView1<f64>| {
            let mut sum = 0.0;
            for i in 0..x.len() - 1 {
                let a = 1.0 - x[i];
                let b = x[i + 1] - x[i].powi(2);
                sum += a.powi(2) + 100.0 * b.powi(2);
            }
            sum
        };

        let x0 = array![0.0, 0.0, 0.0, 0.0];
        let options = SimdBfgsOptions {
            base_options: Options {
                max_iter: 1000,
                gtol: 1e-6,
                ftol: 1e-9,
                ..Default::default()
            },
            force_simd: true,
            ..Default::default()
        };

        let result = minimize_simd_bfgs(rosenbrock, x0, Some(options)).unwrap();

        // Loose tolerance: Rosenbrock is ill-conditioned and the gradient
        // is only a finite-difference approximation.
        for &xi in result.x.iter() {
            assert_abs_diff_eq!(xi, 1.0, epsilon = 1e-3);
        }
        assert!(result.fun < 1e-6);
    }

    /// The unconstrained minimum (-2, -2) lies outside the box [0, 1]^2,
    /// so the bounded solution must be the nearest corner (0, 0) and all
    /// iterates must stay feasible.
    #[test]
    fn test_simd_bfgs_with_bounds() {
        let fun = |x: &ArrayView1<f64>| (x[0] + 2.0).powi(2) + (x[1] + 2.0).powi(2);

        let bounds = Bounds::new(&[(Some(0.0), Some(1.0)), (Some(0.0), Some(1.0))]);
        let x0 = array![0.5, 0.5];
        let options = SimdBfgsOptions {
            base_options: Options {
                max_iter: 100,
                gtol: 1e-6,
                bounds: Some(bounds),
                ..Default::default()
            },
            force_simd: true,
            ..Default::default()
        };

        let result = minimize_simd_bfgs(fun, x0, Some(options)).unwrap();

        assert!(result.x[0] >= 0.0 && result.x[0] <= 1.0);
        assert!(result.x[1] >= 0.0 && result.x[1] <= 1.0);
        assert_abs_diff_eq!(result.x[0], 0.0, epsilon = 1e-6);
        assert_abs_diff_eq!(result.x[1], 0.0, epsilon = 1e-6);
    }

    /// Smoke test: a runtime-detected SIMD configuration can be passed
    /// through the options and the optimizer still completes.
    #[test]
    fn test_simd_config_detection() {
        let config = SimdConfig::detect();
        println!("SIMD capabilities detected:");
        println!(" AVX2: {}", config.avx2_available);
        println!(" SSE4.1: {}", config.sse41_available);
        println!(" FMA: {}", config.fma_available);
        println!(" Vector width: {}", config.vector_width);

        let options = SimdBfgsOptions {
            simd_config: Some(config),
            force_simd: false,
            ..Default::default()
        };

        let fun = |x: &ArrayView1<f64>| x[0].powi(2);
        let x0 = array![1.0];
        let result = minimize_simd_bfgs(fun, x0, Some(options));
        assert!(result.is_ok());
    }

    /// With n = 2 below a threshold of 10 and SIMD not forced, the scalar
    /// BFGS fallback is taken; the result must still be correct.
    #[test]
    fn test_fallback_to_regular_bfgs() {
        let fun = |x: &ArrayView1<f64>| x[0].powi(2) + x[1].powi(2);
        let x0 = array![1.0, 2.0];

        let options = SimdBfgsOptions {
            force_simd: false,
            simd_threshold: 10,
            ..Default::default()
        };

        let result = minimize_simd_bfgs(fun, x0, Some(options)).unwrap();
        assert!(result.success);
        assert_abs_diff_eq!(result.x[0], 0.0, epsilon = 1e-6);
        assert_abs_diff_eq!(result.x[1], 0.0, epsilon = 1e-6);
    }
}
518}