1use crate::error::OptimizeError;
8use crate::simd_ops::{SimdConfig, SimdVectorOps};
9use crate::unconstrained::line_search::backtracking_line_search;
10use crate::unconstrained::{Bounds, OptimizeResult, Options};
11use scirs2_core::ndarray::{Array1, Array2, ArrayView1};
12
/// Configuration for the SIMD-accelerated BFGS optimizer.
#[derive(Debug, Clone)]
pub struct SimdBfgsOptions {
    /// General optimizer options (tolerances, iteration limit, bounds, ...).
    pub base_options: Options,
    /// Explicit SIMD capability configuration; auto-detected when `None`.
    pub simd_config: Option<SimdConfig>,
    /// Take the SIMD path unconditionally, ignoring problem size and
    /// detected CPU capabilities.
    pub force_simd: bool,
    /// Minimum problem dimension required before the SIMD path is used
    /// (smaller problems fall back to the scalar BFGS implementation).
    pub simd_threshold: usize,
}
25
26impl Default for SimdBfgsOptions {
27 fn default() -> Self {
28 Self {
29 base_options: Options::default(),
30 simd_config: None,
31 force_simd: false,
32 simd_threshold: 8, }
34 }
35}
36
/// Mutable state carried across SIMD BFGS iterations.
struct SimdBfgsState {
    /// Current approximation of the inverse Hessian (initialized to identity).
    hessian_inv: Array2<f64>,
    /// SIMD-backed vector kernels used for all dense arithmetic.
    simd_ops: SimdVectorOps,
    /// Gradient at `position`.
    gradient: Array1<f64>,
    /// Gradient at `prev_position`.
    prev_gradient: Array1<f64>,
    /// Current iterate.
    position: Array1<f64>,
    /// Iterate from the previous iteration.
    prev_position: Array1<f64>,
    /// Objective value at `position`.
    function_value: f64,
    /// Number of objective-function evaluations so far.
    nfev: usize,
    /// Number of gradient evaluations so far.
    njev: usize,
}
58
59impl SimdBfgsState {
60 fn new(x0: &Array1<f64>, simd_config: Option<SimdConfig>) -> Self {
61 let n = x0.len();
62 let simd_ops = if let Some(config) = simd_config {
63 SimdVectorOps::with_config(config)
64 } else {
65 SimdVectorOps::new()
66 };
67
68 Self {
69 hessian_inv: Array2::eye(n),
70 simd_ops,
71 gradient: Array1::zeros(n),
72 prev_gradient: Array1::zeros(n),
73 position: x0.clone(),
74 prev_position: x0.clone(),
75 function_value: 0.0,
76 nfev: 0,
77 njev: 0,
78 }
79 }
80
81 fn update_hessian(&mut self) {
83 let n = self.position.len();
84
85 let s = self
87 .simd_ops
88 .sub(&self.position.view(), &self.prev_position.view());
89
90 let y = self
92 .simd_ops
93 .sub(&self.gradient.view(), &self.prev_gradient.view());
94
95 let s_dot_y = self.simd_ops.dot_product(&s.view(), &y.view());
97
98 if s_dot_y.abs() < 1e-14 {
99 return;
101 }
102
103 let rho = 1.0 / s_dot_y;
104
105 let hy = self.matrix_vector_multiply_simd(&self.hessian_inv.view(), &y.view());
114
115 let ythhy = self.simd_ops.dot_product(&y.view(), &hy.view());
117
118 for i in 0..n {
120 for j in 0..n {
121 let hess_update = -hy[i] * hy[j] / ythhy + rho * s[i] * s[j];
122 self.hessian_inv[[i, j]] += hess_update;
123 }
124 }
125 }
126
127 fn matrix_vector_multiply_simd(
129 &self,
130 matrix: &scirs2_core::ndarray::ArrayView2<f64>,
131 vector: &ArrayView1<f64>,
132 ) -> Array1<f64> {
133 self.simd_ops.matvec(matrix, vector)
134 }
135
136 #[allow(dead_code)]
138 fn vector_matrix_multiply_simd(
139 &self,
140 vector: &ArrayView1<f64>,
141 matrix: &scirs2_core::ndarray::ArrayView2<f64>,
142 ) -> Array1<f64> {
143 let n = matrix.ncols();
144 let mut result = Array1::zeros(n);
145
146 for j in 0..n {
148 let column = matrix.column(j);
149 result[j] = self.simd_ops.dot_product(vector, &column);
150 }
151
152 result
153 }
154
155 fn compute_search_direction(&self) -> Array1<f64> {
157 let neg_grad = self.simd_ops.scale(-1.0, &self.gradient.view());
159 self.matrix_vector_multiply_simd(&self.hessian_inv.view(), &neg_grad.view())
160 }
161}
162
/// Minimize `fun` with a BFGS quasi-Newton iteration whose dense vector
/// arithmetic runs through SIMD kernels.
///
/// * `fun`  — objective, called repeatedly (hence `FnMut`); `Clone` is
///   required by the scalar-BFGS fallback path.
/// * `grad` — optional analytic gradient; when `None`, a forward-difference
///   approximation is used.
/// * `x0`   — starting point.
/// * `options` — SIMD-specific options; defaults apply when `None`.
///
/// Falls back to `crate::unconstrained::bfgs::minimize_bfgs` when the
/// problem is too small or the CPU reports no SIMD support (unless
/// `force_simd` is set). Returns the usual `OptimizeResult`, with
/// `hessian` holding the final inverse-Hessian approximation.
#[allow(dead_code)]
pub fn minimize_simd_bfgs<F, G>(
    mut fun: F,
    grad: Option<G>,
    x0: Array1<f64>,
    options: Option<SimdBfgsOptions>,
) -> Result<OptimizeResult<f64>, OptimizeError>
where
    F: FnMut(&ArrayView1<f64>) -> f64 + Clone,
    G: Fn(&ArrayView1<f64>) -> Array1<f64>,
{
    let options = options.unwrap_or_default();
    let n = x0.len();

    // Take the SIMD path when forced, or when the problem is large enough
    // AND the (explicit or detected) SIMD configuration reports support.
    let use_simd = options.force_simd
        || (n >= options.simd_threshold
            && options
                .simd_config
                .as_ref()
                .map_or_else(|| SimdConfig::detect().has_simd(), |c| c.has_simd()));

    if !use_simd {
        // Scalar fallback keeps behavior for small problems / plain CPUs.
        return crate::unconstrained::bfgs::minimize_bfgs(fun, grad, x0, &options.base_options);
    }

    let mut state = SimdBfgsState::new(&x0, options.simd_config);

    // Initial objective and gradient evaluation.
    state.function_value = fun(&state.position.view());
    state.nfev += 1;

    state.gradient = if let Some(ref grad_fn) = grad {
        grad_fn(&state.position.view())
    } else {
        compute_gradient_finite_diff(&mut fun, &state.position, &mut state.nfev)
    };
    state.njev += 1;

    let mut prev_f = state.function_value;

    for iteration in 0..options.base_options.max_iter {
        // Convergence on gradient norm.
        let grad_norm = state.simd_ops.norm(&state.gradient.view());
        if grad_norm < options.base_options.gtol {
            return Ok(OptimizeResult {
                x: state.position,
                fun: state.function_value,
                nit: iteration,
                func_evals: state.nfev,
                nfev: state.nfev,
                jacobian: Some(state.gradient),
                hessian: Some(state.hessian_inv),
                success: true,
                message: "SIMD BFGS optimization terminated successfully.".to_string(),
            });
        }

        // Convergence on (absolute) objective change — only meaningful
        // after at least one completed step.
        if iteration > 0 {
            let f_change = (prev_f - state.function_value).abs();
            if f_change < options.base_options.ftol {
                return Ok(OptimizeResult {
                    x: state.position,
                    fun: state.function_value,
                    nit: iteration,
                    func_evals: state.nfev,
                    nfev: state.nfev,
                    jacobian: Some(state.gradient),
                    hessian: Some(state.hessian_inv),
                    success: true,
                    message: "SIMD BFGS optimization terminated successfully.".to_string(),
                });
            }
        }

        // Snapshot the current point/gradient before stepping; these feed
        // the s/y vectors of the Hessian update after the step.
        state.prev_position = state.position.clone();
        state.prev_gradient = state.gradient.clone();
        prev_f = state.function_value;

        let search_direction = state.compute_search_direction();

        // A non-descent direction means the inverse-Hessian approximation
        // has gone bad: reset it to identity and take a small fixed-size
        // steepest-descent step instead of line-searching.
        let directional_derivative = state
            .simd_ops
            .dot_product(&state.gradient.view(), &search_direction.view());
        if directional_derivative >= 0.0 {
            state.hessian_inv = Array2::eye(n);
            let neg_grad = state.simd_ops.scale(-1.0, &state.gradient.view());
            state.position = state.simd_ops.add(
                &state.position.view(),
                &state.simd_ops.scale(0.001, &neg_grad.view()).view(),
            );
        } else {
            // Backtracking line search along the quasi-Newton direction
            // (initial step 1.0, Armijo c1 = 1e-4, shrink factor 0.9).
            let (step_size, line_search_nfev) = backtracking_line_search(
                &mut |x| fun(x),
                &state.position.view(),
                state.function_value,
                &search_direction.view(),
                &state.gradient.view(),
                1.0,
                1e-4,
                0.9,
                options.base_options.bounds.as_ref(),
            );
            state.nfev += line_search_nfev as usize;

            let step_vec = state.simd_ops.scale(step_size, &search_direction.view());
            state.position = state.simd_ops.add(&state.position.view(), &step_vec.view());
        }

        // Project back into the feasible box, if bounds were supplied.
        if let Some(ref bounds) = options.base_options.bounds {
            apply_bounds(&mut state.position, bounds);
        }

        // Re-evaluate objective and gradient at the new point.
        state.function_value = fun(&state.position.view());
        state.nfev += 1;

        state.gradient = if let Some(ref grad_fn) = grad {
            grad_fn(&state.position.view())
        } else {
            compute_gradient_finite_diff(&mut fun, &state.position, &mut state.nfev)
        };
        state.njev += 1;

        // NOTE(review): the curvature update is skipped on iteration 0 even
        // though prev_position/prev_gradient are valid there — presumably to
        // let the first step run with the identity H; confirm intent.
        if iteration > 0 {
            state.update_hessian();
        }

        // Convergence on step size.
        let position_change = state
            .simd_ops
            .sub(&state.position.view(), &state.prev_position.view());
        let position_change_norm = state.simd_ops.norm(&position_change.view());
        if position_change_norm < options.base_options.xtol {
            return Ok(OptimizeResult {
                x: state.position,
                fun: state.function_value,
                nit: iteration + 1,
                func_evals: state.nfev,
                nfev: state.nfev,
                jacobian: Some(state.gradient),
                hessian: Some(state.hessian_inv),
                success: true,
                message: "SIMD BFGS optimization terminated successfully.".to_string(),
            });
        }
    }

    // Iteration budget exhausted without meeting any tolerance.
    Ok(OptimizeResult {
        x: state.position,
        fun: state.function_value,
        nit: options.base_options.max_iter,
        func_evals: state.nfev,
        nfev: state.nfev,
        jacobian: Some(state.gradient),
        hessian: Some(state.hessian_inv),
        success: false,
        message: "Maximum iterations reached in SIMD BFGS.".to_string(),
    })
}
344
345#[allow(dead_code)]
347fn compute_gradient_finite_diff<F>(fun: &mut F, x: &Array1<f64>, nfev: &mut usize) -> Array1<f64>
348where
349 F: FnMut(&ArrayView1<f64>) -> f64,
350{
351 let n = x.len();
352 let mut grad = Array1::zeros(n);
353 let eps = (f64::EPSILON).sqrt();
354 let f0 = fun(&x.view());
355 *nfev += 1;
356
357 for i in 0..n {
358 let mut x_plus = x.clone();
359 x_plus[i] += eps;
360 let f_plus = fun(&x_plus.view());
361 *nfev += 1;
362
363 grad[i] = (f_plus - f0) / eps;
364 }
365
366 grad
367}
368
369#[allow(dead_code)]
371fn apply_bounds(x: &mut Array1<f64>, bounds: &Bounds) {
372 for (i, xi) in x.iter_mut().enumerate() {
373 if i < bounds.lower.len() {
374 if let Some(lb) = bounds.lower[i] {
375 if *xi < lb {
376 *xi = lb;
377 }
378 }
379 }
380 if i < bounds.upper.len() {
381 if let Some(ub) = bounds.upper[i] {
382 if *xi > ub {
383 *xi = ub;
384 }
385 }
386 }
387 }
388}
389
390#[allow(dead_code)]
394pub fn minimize_simd_bfgs_default<F>(
395 fun: F,
396 x0: Array1<f64>,
397) -> Result<OptimizeResult<f64>, OptimizeError>
398where
399 F: FnMut(&ArrayView1<f64>) -> f64 + Clone,
400{
401 minimize_simd_bfgs(fun, None::<fn(&ArrayView1<f64>) -> Array1<f64>>, x0, None)
402}
403
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use scirs2_core::ndarray::array;

    /// Convex quadratic: the optimizer must drive every coordinate to 0.
    #[test]
    fn test_simd_bfgs_quadratic() {
        let fun = |x: &ArrayView1<f64>| x.iter().map(|&xi| xi.powi(2)).sum::<f64>();

        let x0 = array![1.0, 2.0, 3.0, 4.0];
        let options = SimdBfgsOptions {
            base_options: Options {
                max_iter: 100,
                gtol: 1e-8,
                ..Default::default()
            },
            force_simd: true,
            ..Default::default()
        };

        let result = minimize_simd_bfgs(
            fun,
            None::<fn(&ArrayView1<f64>) -> Array1<f64>>,
            x0,
            Some(options),
        )
        .expect("Operation failed");

        assert!(result.success);
        for &xi in result.x.iter() {
            assert_abs_diff_eq!(xi, 0.0, epsilon = 1e-6);
        }
        assert!(result.fun < 1e-10);
    }

    /// 4-D Rosenbrock: non-convex with a curved valley; the minimum is at
    /// all-ones. Uses looser tolerances since convergence is harder here.
    #[test]
    fn test_simd_bfgs_rosenbrock() {
        let rosenbrock = |x: &ArrayView1<f64>| {
            let mut sum = 0.0;
            for i in 0..x.len() - 1 {
                let a = 1.0 - x[i];
                let b = x[i + 1] - x[i].powi(2);
                sum += a.powi(2) + 100.0 * b.powi(2);
            }
            sum
        };

        let x0 = array![0.0, 0.0, 0.0, 0.0];
        let options = SimdBfgsOptions {
            base_options: Options {
                max_iter: 1000,
                gtol: 1e-6,
                ftol: 1e-9,
                ..Default::default()
            },
            force_simd: true,
            ..Default::default()
        };

        let result = minimize_simd_bfgs(
            rosenbrock,
            None::<fn(&ArrayView1<f64>) -> Array1<f64>>,
            x0,
            Some(options),
        )
        .expect("Operation failed");

        for &xi in result.x.iter() {
            assert_abs_diff_eq!(xi, 1.0, epsilon = 1e-3);
        }
        assert!(result.fun < 1e-6);
    }

    /// Unconstrained minimum at (-2, -2) lies outside the box [0,1]²,
    /// so the constrained solution must land on the corner (0, 0).
    #[test]
    fn test_simd_bfgs_with_bounds() {
        let fun = |x: &ArrayView1<f64>| (x[0] + 2.0).powi(2) + (x[1] + 2.0).powi(2);

        let bounds = Bounds::new(&[(Some(0.0), Some(1.0)), (Some(0.0), Some(1.0))]);
        let x0 = array![0.5, 0.5];
        let options = SimdBfgsOptions {
            base_options: Options {
                max_iter: 100,
                gtol: 1e-6,
                bounds: Some(bounds),
                ..Default::default()
            },
            force_simd: true,
            ..Default::default()
        };

        let result = minimize_simd_bfgs(
            fun,
            None::<fn(&ArrayView1<f64>) -> Array1<f64>>,
            x0,
            Some(options),
        )
        .expect("Operation failed");

        // Iterates must stay feasible, and the solution sits on the corner.
        assert!(result.x[0] >= 0.0 && result.x[0] <= 1.0);
        assert!(result.x[1] >= 0.0 && result.x[1] <= 1.0);
        assert_abs_diff_eq!(result.x[0], 0.0, epsilon = 1e-6);
        assert_abs_diff_eq!(result.x[1], 0.0, epsilon = 1e-6);
    }

    /// Smoke test: feed the detected SIMD configuration through the options
    /// path and check the optimizer runs without error (force_simd off, so
    /// this 1-D problem takes whichever path the threshold selects).
    #[test]
    fn test_simd_config_detection() {
        let config = SimdConfig::detect();
        println!("SIMD capabilities detected:");
        println!("  AVX2: {}", config.avx2_available);
        println!("  SSE4.1: {}", config.sse41_available);
        println!("  FMA: {}", config.fma_available);
        println!("  Vector width: {}", config.vector_width);

        let options = SimdBfgsOptions {
            simd_config: Some(config),
            force_simd: false,
            ..Default::default()
        };

        let fun = |x: &ArrayView1<f64>| x[0].powi(2);
        let x0 = array![1.0];
        let result = minimize_simd_bfgs(
            fun,
            None::<fn(&ArrayView1<f64>) -> Array1<f64>>,
            x0,
            Some(options),
        );
        assert!(result.is_ok());
    }

    /// Problem dimension (2) below the threshold (10) with force_simd off:
    /// must fall back to the scalar BFGS and still converge.
    #[test]
    fn test_fallback_to_regular_bfgs() {
        let fun = |x: &ArrayView1<f64>| x[0].powi(2) + x[1].powi(2);
        let x0 = array![1.0, 2.0];

        let options = SimdBfgsOptions {
            force_simd: false,
            simd_threshold: 10,
            ..Default::default()
        };

        let result = minimize_simd_bfgs(
            fun,
            None::<fn(&ArrayView1<f64>) -> Array1<f64>>,
            x0,
            Some(options),
        )
        .expect("Operation failed");
        assert!(result.success);
        assert_abs_diff_eq!(result.x[0], 0.0, epsilon = 1e-6);
        assert_abs_diff_eq!(result.x[1], 0.0, epsilon = 1e-6);
    }

    /// Same quadratic as above but with an analytic gradient supplied,
    /// exercising the `Some(grad)` code path.
    #[test]
    fn test_simd_bfgs_with_analytic_gradient() {
        let fun = |x: &ArrayView1<f64>| x.iter().map(|&xi| xi.powi(2)).sum::<f64>();
        let grad_fn = |x: &ArrayView1<f64>| x.mapv(|xi| 2.0 * xi);

        let x0 = array![1.0, 2.0, 3.0, 4.0];
        let options = SimdBfgsOptions {
            base_options: Options {
                max_iter: 100,
                gtol: 1e-8,
                ..Default::default()
            },
            force_simd: true,
            ..Default::default()
        };

        let result =
            minimize_simd_bfgs(fun, Some(grad_fn), x0, Some(options)).expect("Operation failed");

        assert!(result.success);
        for &xi in result.x.iter() {
            assert_abs_diff_eq!(xi, 0.0, epsilon = 1e-6);
        }
    }
}