1use crate::error::OptimizeError;
8use crate::simd_ops::{SimdConfig, SimdVectorOps};
9use crate::unconstrained::line_search::backtracking_line_search;
10use crate::unconstrained::{Bounds, OptimizeResult, Options};
11use scirs2_core::ndarray::{Array1, Array2, ArrayView1};
12
13#[derive(Debug, Clone)]
15pub struct SimdBfgsOptions {
16    pub base_options: Options,
18    pub simd_config: Option<SimdConfig>,
20    pub force_simd: bool,
22    pub simd_threshold: usize,
24}
25
26impl Default for SimdBfgsOptions {
27    fn default() -> Self {
28        Self {
29            base_options: Options::default(),
30            simd_config: None,
31            force_simd: false,
32            simd_threshold: 8, }
34    }
35}
36
37struct SimdBfgsState {
39    hessian_inv: Array2<f64>,
41    simd_ops: SimdVectorOps,
43    gradient: Array1<f64>,
45    prev_gradient: Array1<f64>,
47    position: Array1<f64>,
49    prev_position: Array1<f64>,
51    function_value: f64,
53    nfev: usize,
55    njev: usize,
57}
58
59impl SimdBfgsState {
60    fn new(x0: &Array1<f64>, simd_config: Option<SimdConfig>) -> Self {
61        let n = x0.len();
62        let simd_ops = if let Some(config) = simd_config {
63            SimdVectorOps::with_config(config)
64        } else {
65            SimdVectorOps::new()
66        };
67
68        Self {
69            hessian_inv: Array2::eye(n),
70            simd_ops,
71            gradient: Array1::zeros(n),
72            prev_gradient: Array1::zeros(n),
73            position: x0.clone(),
74            prev_position: x0.clone(),
75            function_value: 0.0,
76            nfev: 0,
77            njev: 0,
78        }
79    }
80
81    fn update_hessian(&mut self) {
83        let n = self.position.len();
84
85        let s = self
87            .simd_ops
88            .sub(&self.position.view(), &self.prev_position.view());
89
90        let y = self
92            .simd_ops
93            .sub(&self.gradient.view(), &self.prev_gradient.view());
94
95        let s_dot_y = self.simd_ops.dot_product(&s.view(), &y.view());
97
98        if s_dot_y.abs() < 1e-14 {
99            return;
101        }
102
103        let rho = 1.0 / s_dot_y;
104
105        let hy = self.matrix_vector_multiply_simd(&self.hessian_inv.view(), &y.view());
114
115        let ythhy = self.simd_ops.dot_product(&y.view(), &hy.view());
117
118        for i in 0..n {
122            for j in 0..n {
123                let hess_update = -hy[i] * hy[j] / ythhy + rho * s[i] * s[j];
124                self.hessian_inv[[i, j]] += hess_update;
125            }
126        }
127    }
128
129    fn matrix_vector_multiply_simd(
131        &self,
132        matrix: &scirs2_core::ndarray::ArrayView2<f64>,
133        vector: &ArrayView1<f64>,
134    ) -> Array1<f64> {
135        self.simd_ops.matvec(matrix, vector)
136    }
137
138    #[allow(dead_code)]
140    fn vector_matrix_multiply_simd(
141        &self,
142        vector: &ArrayView1<f64>,
143        matrix: &scirs2_core::ndarray::ArrayView2<f64>,
144    ) -> Array1<f64> {
145        let n = matrix.ncols();
146        let mut result = Array1::zeros(n);
147
148        for j in 0..n {
150            let column = matrix.column(j);
151            result[j] = self.simd_ops.dot_product(vector, &column);
152        }
153
154        result
155    }
156
157    fn compute_search_direction(&self) -> Array1<f64> {
159        let neg_grad = self.simd_ops.scale(-1.0, &self.gradient.view());
161        self.matrix_vector_multiply_simd(&self.hessian_inv.view(), &neg_grad.view())
162    }
163}
164
165#[allow(dead_code)]
167pub fn minimize_simd_bfgs<F>(
168    mut fun: F,
169    x0: Array1<f64>,
170    options: Option<SimdBfgsOptions>,
171) -> Result<OptimizeResult<f64>, OptimizeError>
172where
173    F: FnMut(&ArrayView1<f64>) -> f64 + Clone,
174{
175    let options = options.unwrap_or_default();
176    let n = x0.len();
177
178    let use_simd = options.force_simd
180        || (n >= options.simd_threshold
181            && options
182                .simd_config
183                .as_ref()
184                .map_or_else(|| SimdConfig::detect().has_simd(), |c| c.has_simd()));
185
186    if !use_simd {
187        return crate::unconstrained::bfgs::minimize_bfgs(fun, x0, &options.base_options);
189    }
190
191    let mut state = SimdBfgsState::new(&x0, options.simd_config);
192
193    state.function_value = fun(&state.position.view());
195    state.nfev += 1;
196
197    state.gradient = compute_gradient_finite_diff(&mut fun, &state.position, &mut state.nfev);
199    state.njev += 1;
200
201    let mut prev_f = state.function_value;
202
203    for iteration in 0..options.base_options.max_iter {
204        let grad_norm = state.simd_ops.norm(&state.gradient.view());
206        if grad_norm < options.base_options.gtol {
207            return Ok(OptimizeResult {
208                x: state.position,
209                fun: state.function_value,
210                nit: iteration,
211                func_evals: state.nfev,
212                nfev: state.nfev,
213                jacobian: Some(state.gradient),
214                hessian: Some(state.hessian_inv),
215                success: true,
216                message: "SIMD BFGS optimization terminated successfully.".to_string(),
217            });
218        }
219
220        if iteration > 0 {
222            let f_change = (prev_f - state.function_value).abs();
223            if f_change < options.base_options.ftol {
224                return Ok(OptimizeResult {
225                    x: state.position,
226                    fun: state.function_value,
227                    nit: iteration,
228                    func_evals: state.nfev,
229                    nfev: state.nfev,
230                    jacobian: Some(state.gradient),
231                    hessian: Some(state.hessian_inv),
232                    success: true,
233                    message: "SIMD BFGS optimization terminated successfully.".to_string(),
234                });
235            }
236        }
237
238        state.prev_position = state.position.clone();
240        state.prev_gradient = state.gradient.clone();
241        prev_f = state.function_value;
242
243        let search_direction = state.compute_search_direction();
245
246        let directional_derivative = state
248            .simd_ops
249            .dot_product(&state.gradient.view(), &search_direction.view());
250        if directional_derivative >= 0.0 {
251            state.hessian_inv = Array2::eye(n);
253            let neg_grad = state.simd_ops.scale(-1.0, &state.gradient.view());
254            state.position = state.simd_ops.add(
255                &state.position.view(),
256                &state.simd_ops.scale(0.001, &neg_grad.view()).view(),
257            );
258        } else {
259            let (step_size, line_search_nfev) = backtracking_line_search(
261                &mut |x| fun(x),
262                &state.position.view(),
263                state.function_value,
264                &search_direction.view(),
265                &state.gradient.view(),
266                1.0,
267                1e-4,
268                0.9,
269                options.base_options.bounds.as_ref(),
270            );
271            state.nfev += line_search_nfev as usize;
272
273            let step_vec = state.simd_ops.scale(step_size, &search_direction.view());
275            state.position = state.simd_ops.add(&state.position.view(), &step_vec.view());
276        }
277
278        if let Some(ref bounds) = options.base_options.bounds {
280            apply_bounds(&mut state.position, bounds);
281        }
282
283        state.function_value = fun(&state.position.view());
285        state.nfev += 1;
286
287        state.gradient = compute_gradient_finite_diff(&mut fun, &state.position, &mut state.nfev);
289        state.njev += 1;
290
291        if iteration > 0 {
293            state.update_hessian();
294        }
295
296        let position_change = state
298            .simd_ops
299            .sub(&state.position.view(), &state.prev_position.view());
300        let position_change_norm = state.simd_ops.norm(&position_change.view());
301        if position_change_norm < options.base_options.xtol {
302            return Ok(OptimizeResult {
303                x: state.position,
304                fun: state.function_value,
305                nit: iteration + 1,
306                func_evals: state.nfev,
307                nfev: state.nfev,
308                jacobian: Some(state.gradient),
309                hessian: Some(state.hessian_inv),
310                success: true,
311                message: "SIMD BFGS optimization terminated successfully.".to_string(),
312            });
313        }
314    }
315
316    Ok(OptimizeResult {
318        x: state.position,
319        fun: state.function_value,
320        nit: options.base_options.max_iter,
321        func_evals: state.nfev,
322        nfev: state.nfev,
323        jacobian: Some(state.gradient),
324        hessian: Some(state.hessian_inv),
325        success: false,
326        message: "Maximum iterations reached in SIMD BFGS.".to_string(),
327    })
328}
329
330#[allow(dead_code)]
332fn compute_gradient_finite_diff<F>(fun: &mut F, x: &Array1<f64>, nfev: &mut usize) -> Array1<f64>
333where
334    F: FnMut(&ArrayView1<f64>) -> f64,
335{
336    let n = x.len();
337    let mut grad = Array1::zeros(n);
338    let eps = (f64::EPSILON).sqrt();
339    let f0 = fun(&x.view());
340    *nfev += 1;
341
342    for i in 0..n {
343        let mut x_plus = x.clone();
344        x_plus[i] += eps;
345        let f_plus = fun(&x_plus.view());
346        *nfev += 1;
347
348        grad[i] = (f_plus - f0) / eps;
349    }
350
351    grad
352}
353
354#[allow(dead_code)]
356fn apply_bounds(x: &mut Array1<f64>, bounds: &Bounds) {
357    for (i, xi) in x.iter_mut().enumerate() {
358        if i < bounds.lower.len() {
359            if let Some(lb) = bounds.lower[i] {
360                if *xi < lb {
361                    *xi = lb;
362                }
363            }
364        }
365        if i < bounds.upper.len() {
366            if let Some(ub) = bounds.upper[i] {
367                if *xi > ub {
368                    *xi = ub;
369                }
370            }
371        }
372    }
373}
374
375#[allow(dead_code)]
377pub fn minimize_simd_bfgs_default<F>(
378    fun: F,
379    x0: Array1<f64>,
380) -> Result<OptimizeResult<f64>, OptimizeError>
381where
382    F: FnMut(&ArrayView1<f64>) -> f64 + Clone,
383{
384    minimize_simd_bfgs(fun, x0, None)
385}
386
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use scirs2_core::ndarray::array;

    /// Wraps `base_options` in SIMD BFGS options that force the SIMD code path.
    fn forced_simd(base_options: Options) -> SimdBfgsOptions {
        SimdBfgsOptions {
            base_options,
            force_simd: true,
            ..Default::default()
        }
    }

    /// A separable convex quadratic must converge to the origin.
    #[test]
    fn test_simd_bfgs_quadratic() {
        let sum_of_squares = |x: &ArrayView1<f64>| x.iter().map(|&v| v.powi(2)).sum::<f64>();
        let start = array![1.0, 2.0, 3.0, 4.0];
        let opts = forced_simd(Options {
            max_iter: 100,
            gtol: 1e-8,
            ..Default::default()
        });

        let res = minimize_simd_bfgs(sum_of_squares, start, Some(opts)).unwrap();

        assert!(res.success);
        res.x
            .iter()
            .for_each(|&v| assert_abs_diff_eq!(v, 0.0, epsilon = 1e-6));
        assert!(res.fun < 1e-10);
    }

    /// The 4-D Rosenbrock function must reach its minimum at all-ones.
    #[test]
    fn test_simd_bfgs_rosenbrock() {
        let rosenbrock = |x: &ArrayView1<f64>| {
            (0..x.len() - 1)
                .map(|i| {
                    let a = 1.0 - x[i];
                    let b = x[i + 1] - x[i].powi(2);
                    a.powi(2) + 100.0 * b.powi(2)
                })
                .sum::<f64>()
        };
        let start = array![0.0, 0.0, 0.0, 0.0];
        let opts = forced_simd(Options {
            max_iter: 1000,
            gtol: 1e-6,
            ftol: 1e-9,
            ..Default::default()
        });

        let res = minimize_simd_bfgs(rosenbrock, start, Some(opts)).unwrap();

        res.x
            .iter()
            .for_each(|&v| assert_abs_diff_eq!(v, 1.0, epsilon = 1e-3));
        assert!(res.fun < 1e-6);
    }

    /// An unconstrained minimum outside the box must be clamped to the nearest corner.
    #[test]
    fn test_simd_bfgs_with_bounds() {
        let shifted = |x: &ArrayView1<f64>| (x[0] + 2.0).powi(2) + (x[1] + 2.0).powi(2);
        let bounds = Bounds::new(&[(Some(0.0), Some(1.0)), (Some(0.0), Some(1.0))]);
        let opts = forced_simd(Options {
            max_iter: 100,
            gtol: 1e-6,
            bounds: Some(bounds),
            ..Default::default()
        });

        let res = minimize_simd_bfgs(shifted, array![0.5, 0.5], Some(opts)).unwrap();

        for &v in res.x.iter() {
            assert!((0.0..=1.0).contains(&v));
        }
        assert_abs_diff_eq!(res.x[0], 0.0, epsilon = 1e-6);
        assert_abs_diff_eq!(res.x[1], 0.0, epsilon = 1e-6);
    }

    /// Detected SIMD capabilities can be passed through without forcing SIMD.
    #[test]
    fn test_simd_config_detection() {
        let detected = SimdConfig::detect();
        println!("SIMD capabilities detected:");
        println!("  AVX2: {}", detected.avx2_available);
        println!("  SSE4.1: {}", detected.sse41_available);
        println!("  FMA: {}", detected.fma_available);
        println!("  Vector width: {}", detected.vector_width);

        let opts = SimdBfgsOptions {
            simd_config: Some(detected),
            force_simd: false,
            ..Default::default()
        };

        let outcome = minimize_simd_bfgs(|x: &ArrayView1<f64>| x[0].powi(2), array![1.0], Some(opts));
        assert!(outcome.is_ok());
    }

    /// Problems below the SIMD threshold must still be solved (via regular BFGS).
    #[test]
    fn test_fallback_to_regular_bfgs() {
        let paraboloid = |x: &ArrayView1<f64>| x[0].powi(2) + x[1].powi(2);
        let opts = SimdBfgsOptions {
            force_simd: false,
            simd_threshold: 10,
            ..Default::default()
        };

        let res = minimize_simd_bfgs(paraboloid, array![1.0, 2.0], Some(opts)).unwrap();

        assert!(res.success);
        assert_abs_diff_eq!(res.x[0], 0.0, epsilon = 1e-6);
        assert_abs_diff_eq!(res.x[1], 0.0, epsilon = 1e-6);
    }
}