1use crate::constrained::{Constraint, ConstraintFn, ConstraintKind, Options};
25use crate::error::{OptimizeError, OptimizeResult};
26use crate::result::OptimizeResults;
27use scirs2_core::ndarray::{Array1, ArrayBase, Data, Ix1};
28
/// Configuration of the sequential quadratic programming (SQP) solver.
///
/// Build one with `SQPSolver::new()` / `SQPSolver::default()` and override
/// individual fields with struct-update syntax, e.g.
/// `SQPSolver { max_iter: 500, ..Default::default() }`.
#[derive(Debug, Clone, PartialEq)]
pub struct SQPSolver {
    /// Maximum number of SQP iterations.
    pub max_iter: usize,
    /// Convergence tolerance on the Euclidean norm of the Lagrangian gradient.
    pub tol: f64,
    /// Tolerance on the total constraint violation
    /// `sum |e(x)| + sum max(g(x), 0)`.
    pub constraint_tol: f64,
    /// Step size used by the finite-difference derivative approximations.
    pub eps: f64,
    /// Initial value given to every Lagrange multiplier.
    pub lambda_init: f64,
}

impl Default for SQPSolver {
    /// `max_iter = 200`, `tol = 1e-8`, `constraint_tol = 1e-8`, `eps = 1e-7`,
    /// `lambda_init = 0.0`.
    fn default() -> Self {
        SQPSolver {
            max_iter: 200,
            tol: 1e-8,
            constraint_tol: 1e-8,
            eps: 1e-7,
            lambda_init: 0.0,
        }
    }
}
54
55impl SQPSolver {
56 pub fn new() -> Self {
58 SQPSolver::default()
59 }
60
61 pub fn solve<F, GF, E, G>(
73 &self,
74 f: F,
75 grad_f: Option<GF>,
76 eq_cons: &[E],
77 ineq_cons: &[G],
78 x0: &[f64],
79 ) -> OptimizeResult<OptimizeResults<f64>>
80 where
81 F: Fn(&[f64]) -> f64 + Clone,
82 GF: Fn(&[f64]) -> Vec<f64>,
83 E: Fn(&[f64]) -> f64 + Clone,
84 G: Fn(&[f64]) -> f64 + Clone,
85 {
86 let n = x0.len();
87 if n == 0 {
88 return Err(OptimizeError::InvalidInput(
89 "x0 must be non-empty".to_string(),
90 ));
91 }
92
93 let mut x = x0.to_vec();
94 let mut nfev = 0usize;
95 let mut njev = 0usize;
96 let mut nit = 0usize;
97
98 let mut hess: Vec<Vec<f64>> = (0..n)
100 .map(|i| {
101 let mut row = vec![0.0; n];
102 row[i] = 1.0;
103 row
104 })
105 .collect();
106
107 let n_eq = eq_cons.len();
109 let n_ineq = ineq_cons.len();
110 let n_lambda = n_eq + n_ineq;
111 let mut lambda = vec![self.lambda_init; n_lambda];
112
113 let compute_grad = |xv: &[f64], nfev: &mut usize, njev: &mut usize| -> Vec<f64> {
115 *njev += 1;
116 if let Some(ref gf) = grad_f {
117 gf(xv)
118 } else {
119 let h = self.eps;
120 let mut g = vec![0.0; n];
121 let mut xp = xv.to_vec();
122 let mut xm = xv.to_vec();
123 for i in 0..n {
124 xp[i] = xv[i] + h;
125 xm[i] = xv[i] - h;
126 *nfev += 2;
127 g[i] = (f(&xp) - f(&xm)) / (2.0 * h);
128 xp[i] = xv[i];
129 xm[i] = xv[i];
130 }
131 g
132 }
133 };
134
135 let compute_constraint_jac = |xv: &[f64], nfev: &mut usize| -> (Vec<f64>, Vec<Vec<f64>>) {
137 let h = self.eps;
138 let mut c_vals = Vec::with_capacity(n_lambda);
139 let mut c_jac: Vec<Vec<f64>> = Vec::with_capacity(n_lambda);
140
141 for e in eq_cons {
142 let cv = e(xv);
143 c_vals.push(cv);
144 *nfev += 1;
145 let mut jrow = vec![0.0; n];
146 let mut xp = xv.to_vec();
147 for j in 0..n {
148 xp[j] = xv[j] + h;
149 *nfev += 1;
150 jrow[j] = (e(&xp) - cv) / h;
151 xp[j] = xv[j];
152 }
153 c_jac.push(jrow);
154 }
155
156 for g in ineq_cons {
157 let cv = g(xv);
158 c_vals.push(cv);
159 *nfev += 1;
160 let mut jrow = vec![0.0; n];
161 let mut xp = xv.to_vec();
162 for j in 0..n {
163 xp[j] = xv[j] + h;
164 *nfev += 1;
165 jrow[j] = (g(&xp) - cv) / h;
166 xp[j] = xv[j];
167 }
168 c_jac.push(jrow);
169 }
170
171 (c_vals, c_jac)
172 };
173
174 let mut prev_grad_lag: Option<Vec<f64>> = None;
175 let mut prev_x: Option<Vec<f64>> = None;
176
177 for _iter in 0..self.max_iter {
178 nit += 1;
179 nfev += 1;
180 let fx = f(&x);
181
182 let grad = compute_grad(&x, &mut nfev, &mut njev);
183 let (c_vals, c_jac) = compute_constraint_jac(&x, &mut nfev);
184
185 let mut grad_lag: Vec<f64> = grad.clone();
187 for j in 0..n_lambda {
188 for i in 0..n {
189 grad_lag[i] += lambda[j] * c_jac[j][i];
190 }
191 }
192
193 let grad_lag_norm = grad_lag.iter().map(|v| v * v).sum::<f64>().sqrt();
195 let eq_viol: f64 = c_vals[..n_eq].iter().map(|v| v.abs()).sum();
196 let ineq_viol: f64 = c_vals[n_eq..].iter().map(|v| v.max(0.0)).sum();
197 let cv = eq_viol + ineq_viol;
198
199 if grad_lag_norm <= self.tol && cv <= self.constraint_tol {
200 return Ok(OptimizeResults {
201 x: Array1::from_vec(x),
202 fun: fx,
203 jac: Some(grad),
204 hess: None,
205 constr: Some(Array1::from_vec(c_vals)),
206 nit,
207 nfev,
208 njev,
209 nhev: 0,
210 maxcv: 0,
211 message: "KKT conditions satisfied".to_string(),
212 success: true,
213 status: 0,
214 });
215 }
216
217 if let (Some(ref px), Some(ref pg)) = (&prev_x, &prev_grad_lag) {
219 let s: Vec<f64> = x
220 .iter()
221 .zip(px.iter())
222 .map(|(&xi, &pxi)| xi - pxi)
223 .collect();
224 let y: Vec<f64> = grad_lag
225 .iter()
226 .zip(pg.iter())
227 .map(|(&gi, &pgi)| gi - pgi)
228 .collect();
229 let sy: f64 = s.iter().zip(y.iter()).map(|(&si, &yi)| si * yi).sum();
230 let hs: Vec<f64> = (0..n)
232 .map(|i| {
233 hess[i]
234 .iter()
235 .zip(s.iter())
236 .map(|(&h, &si)| h * si)
237 .sum::<f64>()
238 })
239 .collect();
240 let sths: f64 = s.iter().zip(hs.iter()).map(|(&si, &hsi)| si * hsi).sum();
241
242 let sy_damp = sy.max(0.2 * sths); if sy_damp.abs() > 1e-10 && sths.abs() > 1e-10 {
245 for i in 0..n {
247 for j in 0..n {
248 hess[i][j] += y[i] * y[j] / sy_damp - hs[i] * hs[j] / sths;
249 }
250 }
251 }
252 }
253
254 prev_x = Some(x.clone());
255 prev_grad_lag = Some(grad_lag.clone());
256
257 let active_flags: Vec<bool> = (0..n_ineq).map(|i| c_vals[n_eq + i] > -1e-3).collect();
267 let n_active = n_eq + active_flags.iter().filter(|&&a| a).count();
268
269 let d = if n_active > 0 {
270 solve_kkt_step(&hess, &grad, &c_jac, &c_vals, n_eq, &active_flags, n)
271 } else {
272 solve_newton_step(&hess, &grad, n)
273 };
274
275 let mu_merit = lambda.iter().map(|v| v.abs()).fold(1.0_f64, f64::max) + 1.0;
278 let mut merit_fn = |xv: &[f64]| -> f64 {
279 let fv = f(xv);
280 nfev += 1;
281 let cv_eq: f64 = eq_cons.iter().map(|e| e(xv).abs()).sum::<f64>();
282 let cv_ineq: f64 = ineq_cons.iter().map(|g| g(xv).max(0.0)).sum::<f64>();
283 fv + mu_merit * (cv_eq + cv_ineq)
284 };
285
286 let merit0 = merit_fn(&x);
287 let d_obj: f64 = grad.iter().zip(d.iter()).map(|(&gi, &di)| gi * di).sum();
289 let d_merit = d_obj - mu_merit * (eq_viol + ineq_viol);
290
291 let mut alpha = 1.0_f64;
292 let armijo_c = 1e-4;
293 let backtrack = 0.5;
294 let max_ls = 30;
295
296 for _ls in 0..max_ls {
297 let xnew: Vec<f64> = x
298 .iter()
299 .zip(d.iter())
300 .map(|(&xi, &di)| xi + alpha * di)
301 .collect();
302 let m_new = merit_fn(&xnew);
303 if m_new <= merit0 + armijo_c * alpha * d_merit {
304 break;
305 }
306 alpha *= backtrack;
307 }
308
309 let xnew: Vec<f64> = x
310 .iter()
311 .zip(d.iter())
312 .map(|(&xi, &di)| xi + alpha * di)
313 .collect();
314 x = xnew;
315
316 let (new_c_vals, new_c_jac) = compute_constraint_jac(&x, &mut nfev);
318 let new_grad = compute_grad(&x, &mut nfev, &mut njev);
319
320 update_lagrange_multipliers(
323 &new_grad,
324 &new_c_jac,
325 &new_c_vals,
326 &mut lambda,
327 n_eq,
328 n_ineq,
329 );
330 }
331
332 nfev += 1;
334 let fx = f(&x);
335 let (c_vals, _) = compute_constraint_jac(&x, &mut nfev);
336 let eq_viol: f64 = c_vals[..n_eq].iter().map(|v| v.abs()).sum();
337 let ineq_viol: f64 = c_vals[n_eq..].iter().map(|v| v.max(0.0)).sum();
338 let cv = eq_viol + ineq_viol;
339
340 Ok(OptimizeResults {
341 x: Array1::from_vec(x),
342 fun: fx,
343 jac: None,
344 hess: None,
345 constr: Some(Array1::from_vec(c_vals)),
346 nit,
347 nfev,
348 njev,
349 nhev: 0,
350 maxcv: 0,
351 message: format!("Maximum iterations reached (cv={:.2e})", cv),
352 success: cv <= self.constraint_tol,
353 status: if cv <= self.constraint_tol { 0 } else { 1 },
354 })
355 }
356}
357
/// Maximum Euclidean length of a search direction produced by the step solvers.
const MAX_STEP_NORM: f64 = 10.0;

/// Scales `d` down so that its Euclidean norm is at most `MAX_STEP_NORM`.
fn cap_step_norm(mut d: Vec<f64>) -> Vec<f64> {
    let dn = d.iter().map(|v| v * v).sum::<f64>().sqrt();
    if dn > MAX_STEP_NORM {
        let scale = MAX_STEP_NORM / dn;
        d.iter_mut().for_each(|v| *v *= scale);
    }
    d
}

/// Computes an SQP search direction from the KKT system
///
/// ```text
/// [ H + 1e-8 I   A^T ] [ d  ]   [ -grad_f   ]
/// [ A            0   ] [ nu ] = [ -c_active ]
/// ```
///
/// where `A` / `c_active` are the Jacobian rows / values of every equality
/// constraint (the first `n_eq` entries of `c_jac` / `c_vals`) followed by the
/// inequalities flagged in `active_flags` (entry `k` refers to index `n_eq + k`).
///
/// Only the primal part `d` is returned, capped to norm `MAX_STEP_NORM`. With an
/// empty working set, or when the KKT matrix is numerically singular (e.g.
/// redundant constraints), it falls back to `solve_newton_step_slice`.
fn solve_kkt_step(
    hess: &[Vec<f64>],
    grad_f: &[f64],
    c_jac: &[Vec<f64>],
    c_vals: &[f64],
    n_eq: usize,
    active_flags: &[bool],
    n: usize,
) -> Vec<f64> {
    // Row indices into c_jac / c_vals of the working constraints, equalities first.
    let working: Vec<usize> = (0..n_eq)
        .chain(
            active_flags
                .iter()
                .enumerate()
                .filter(|&(_, &is_active)| is_active)
                .map(|(k, _)| n_eq + k),
        )
        .collect();

    let m = working.len();
    if m == 0 {
        return solve_newton_step_slice(hess, grad_f, n);
    }

    let total = n + m;
    let mut kkt: Vec<Vec<f64>> = vec![vec![0.0; total]; total];
    let mut rhs: Vec<f64> = vec![0.0; total];

    // Upper-left block: Hessian with a tiny ridge; upper RHS: -grad_f.
    for i in 0..n {
        kkt[i][..n].copy_from_slice(&hess[i][..n]);
        kkt[i][i] += 1e-8;
        rhs[i] = -grad_f[i];
    }
    // Off-diagonal blocks A^T (upper-right) and A (lower-left); lower RHS: -c.
    for (k, &row) in working.iter().enumerate() {
        for i in 0..n {
            let a = c_jac[row][i];
            kkt[i][n + k] = a;
            kkt[n + k][i] = a;
        }
        rhs[n + k] = -c_vals[row];
    }

    match gaussian_elimination(&mut kkt, &mut rhs, total) {
        Some(sol) => cap_step_norm(sol[..n].to_vec()),
        None => solve_newton_step_slice(hess, grad_f, n),
    }
}

/// Solves `(hess + 1e-8 I) d = -grad` for a quasi-Newton direction, capped to
/// norm `MAX_STEP_NORM`. Falls back to `-0.01 * grad` if the system is singular.
fn solve_newton_step_slice(hess: &[Vec<f64>], grad: &[f64], n: usize) -> Vec<f64> {
    let mut a: Vec<Vec<f64>> = hess.to_vec();
    for (i, row) in a.iter_mut().enumerate().take(n) {
        row[i] += 1e-8;
    }
    let mut b: Vec<f64> = grad.iter().map(|&gi| -gi).collect();

    match gaussian_elimination(&mut a, &mut b, n) {
        Some(d) => cap_step_norm(d),
        None => grad.iter().map(|&gi| -gi * 0.01).collect(),
    }
}

/// Solves the leading `n x n` system `a x = b` by Gaussian elimination with
/// partial pivoting. `a` and `b` are overwritten with the reduced system.
///
/// Returns `None` if a pivot smaller than `1e-14` in magnitude is encountered.
fn gaussian_elimination(a: &mut [Vec<f64>], b: &mut [f64], n: usize) -> Option<Vec<f64>> {
    for col in 0..n {
        // Partial pivoting: first row at/below `col` with the largest |a[row][col]|.
        let mut max_row = col;
        let mut max_val = a[col][col].abs();
        for (row, a_row) in a.iter().enumerate().take(n).skip(col + 1) {
            if a_row[col].abs() > max_val {
                max_val = a_row[col].abs();
                max_row = row;
            }
        }
        if max_val < 1e-14 {
            return None;
        }
        a.swap(col, max_row);
        b.swap(col, max_row);

        // Eliminate column `col` from every row below the pivot row.
        let (upper, lower) = a.split_at_mut(col + 1);
        let pivot_row = &upper[col];
        let pivot = pivot_row[col];
        let b_col = b[col];
        for (offset, row) in lower.iter_mut().take(n - col - 1).enumerate() {
            let factor = row[col] / pivot;
            for k in col..n {
                row[k] -= pivot_row[k] * factor;
            }
            b[col + 1 + offset] -= b_col * factor;
        }
    }

    // Back substitution on the upper-triangular system.
    let mut x = vec![0.0; n];
    for i in (0..n).rev() {
        let mut sum = b[i];
        for j in (i + 1)..n {
            sum -= a[i][j] * x[j];
        }
        if a[i][i].abs() < 1e-14 {
            return None;
        }
        x[i] = sum / a[i][i];
    }
    Some(x)
}
520
/// Solves `hess * d = -grad` (no regularisation) for a quasi-Newton direction,
/// using Gaussian elimination with partial pivoting on the leading `n x n` block.
///
/// If a pivot smaller than `1e-12` in magnitude is met, the scaled steepest
/// descent direction `-0.01 * grad` is returned instead. A solved step whose
/// Euclidean norm exceeds 10 is scaled down to length 10.
fn solve_newton_step(hess: &[Vec<f64>], grad: &[f64], n: usize) -> Vec<f64> {
    let mut a: Vec<Vec<f64>> = hess.to_vec();
    let mut b: Vec<f64> = grad.iter().map(|&gi| -gi).collect();

    for col in 0..n {
        // Partial pivoting: first row at/below `col` with the largest |a[row][col]|.
        let mut max_row = col;
        let mut max_val = a[col][col].abs();
        for (row, a_row) in a.iter().enumerate().take(n).skip(col + 1) {
            if a_row[col].abs() > max_val {
                max_val = a_row[col].abs();
                max_row = row;
            }
        }

        if max_val < 1e-12 {
            return grad.iter().map(|&gi| -gi * 0.01).collect();
        }

        a.swap(col, max_row);
        b.swap(col, max_row);

        let (upper, lower) = a.split_at_mut(col + 1);
        let pivot_row = &upper[col];
        let pivot = pivot_row[col];
        let b_col = b[col];
        for (offset, row) in lower.iter_mut().take(n - col - 1).enumerate() {
            let factor = row[col] / pivot;
            for k in col..n {
                row[k] -= pivot_row[k] * factor;
            }
            b[col + 1 + offset] -= b_col * factor;
        }
    }

    // Back substitution. Every diagonal entry is a pivot that already passed the
    // 1e-12 check above and is never modified afterwards, so no zero test is needed.
    let mut d = vec![0.0; n];
    for i in (0..n).rev() {
        let mut sum = b[i];
        for j in (i + 1)..n {
            sum -= a[i][j] * d[j];
        }
        d[i] = sum / a[i][i];
    }

    let dn = d.iter().map(|v| v * v).sum::<f64>().sqrt();
    if dn > 10.0 {
        let scale = 10.0 / dn;
        d.iter_mut().for_each(|v| *v *= scale);
    }
    d
}
581
/// Re-estimates the Lagrange multipliers in place by least squares.
///
/// The working set `W` is every equality constraint (indices `0..n_eq`) plus
/// every inequality (indices `n_eq..n_eq + n_ineq`) with `c_i >= -1e-6`, i.e.
/// active or violated. The multipliers of `W` minimise
/// `|| grad + J_W^T lambda_W ||`, computed from the regularised normal equations
/// `(J_W J_W^T + 1e-8 I) z = J_W grad`, `lambda_W = -z`.
///
/// Inequality multipliers are then clamped to be non-negative (dual
/// feasibility for `g(x) <= 0`), and inactive inequalities get `lambda_i = 0`.
/// If the normal equations are numerically singular, `lambda` is left unchanged.
fn update_lagrange_multipliers(
    grad: &[f64],
    c_jac: &[Vec<f64>],
    c_vals: &[f64],
    lambda: &mut [f64],
    n_eq: usize,
    n_ineq: usize,
) {
    // Inequalities with c_i < -INACTIVE_TOL are considered strictly inactive.
    const INACTIVE_TOL: f64 = 1e-6;

    fn dot(a: &[f64], b: &[f64]) -> f64 {
        a.iter().zip(b).map(|(&p, &q)| p * q).sum()
    }

    if lambda.is_empty() {
        return;
    }

    // Inactive inequalities are excluded from the fit so they cannot absorb
    // part of the gradient and bias the remaining estimates.
    let working: Vec<usize> = (0..n_eq)
        .chain((n_eq..n_eq + n_ineq).filter(|&i| c_vals[i] >= -INACTIVE_TOL))
        .collect();

    let k = working.len();
    let mut jjt = vec![vec![0.0_f64; k]; k];
    let mut jg = vec![0.0_f64; k];
    for (p, &rp) in working.iter().enumerate() {
        for (q, &rq) in working.iter().enumerate() {
            jjt[p][q] = dot(&c_jac[rp], &c_jac[rq]);
        }
        // Small ridge keeps J J^T invertible for (nearly) dependent constraints.
        jjt[p][p] += 1e-8;
        jg[p] = dot(&c_jac[rp], grad);
    }

    if let Some(estimate) = solve_small_system_sqp(&jjt, &jg) {
        lambda[n_eq..n_eq + n_ineq].fill(0.0);
        for (&row, &z) in working.iter().zip(&estimate) {
            lambda[row] = if row < n_eq { -z } else { (-z).max(0.0) };
        }
    }
}

/// Solves the square system `a x = b` (dimension `b.len()`) by Gaussian
/// elimination with partial pivoting on copies of the inputs.
///
/// Returns `Some(vec![])` for an empty system and `None` if a pivot smaller
/// than `1e-12` in magnitude is encountered.
fn solve_small_system_sqp(a: &[Vec<f64>], b: &[f64]) -> Option<Vec<f64>> {
    let n = b.len();
    if n == 0 {
        return Some(vec![]);
    }
    let mut m: Vec<Vec<f64>> = a.to_vec();
    let mut r: Vec<f64> = b.to_vec();

    for col in 0..n {
        let mut max_row = col;
        let mut max_val = m[col][col].abs();
        for (row, m_row) in m.iter().enumerate().take(n).skip(col + 1) {
            if m_row[col].abs() > max_val {
                max_val = m_row[col].abs();
                max_row = row;
            }
        }
        if max_val < 1e-12 {
            return None;
        }
        m.swap(col, max_row);
        r.swap(col, max_row);

        let (upper, lower) = m.split_at_mut(col + 1);
        let pivot_row = &upper[col];
        let pivot = pivot_row[col];
        let r_col = r[col];
        for (offset, row) in lower.iter_mut().take(n - col - 1).enumerate() {
            let factor = row[col] / pivot;
            for k in col..n {
                row[k] -= pivot_row[k] * factor;
            }
            r[col + 1 + offset] -= r_col * factor;
        }
    }

    let mut x = vec![0.0; n];
    for i in (0..n).rev() {
        let mut sum = r[i];
        for j in (i + 1)..n {
            sum -= m[i][j] * x[j];
        }
        if m[i][i].abs() < 1e-12 {
            return None;
        }
        x[i] = sum / m[i][i];
    }
    Some(x)
}
690
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;

    #[test]
    fn test_sqp_equality_constrained() {
        // min x0^2 + x1^2  s.t.  x0 + x1 = 1  ->  (0.5, 0.5), f = 0.5
        let f = |x: &[f64]| x[0].powi(2) + x[1].powi(2);
        let h = |x: &[f64]| x[0] + x[1] - 1.0;

        let solver = SQPSolver::default();
        let result = solver
            .solve(
                f,
                None::<fn(&[f64]) -> Vec<f64>>,
                &[h],
                &[] as &[fn(&[f64]) -> f64],
                &[0.0, 0.0],
            )
            .expect("solve failed");

        assert_abs_diff_eq!(result.x[0], 0.5, epsilon = 1e-3);
        assert_abs_diff_eq!(result.x[1], 0.5, epsilon = 1e-3);
        assert_abs_diff_eq!(result.fun, 0.5, epsilon = 1e-3);
    }

    #[test]
    fn test_sqp_inequality_constrained() {
        // min x0^2 + x1^2  s.t.  1 - x0 - x1 <= 0  ->  (0.5, 0.5) on the boundary
        let f = |x: &[f64]| x[0].powi(2) + x[1].powi(2);
        let g = |x: &[f64]| 1.0 - x[0] - x[1];

        let solver = SQPSolver::default();
        let result = solver
            .solve(
                f,
                None::<fn(&[f64]) -> Vec<f64>>,
                &[] as &[fn(&[f64]) -> f64],
                &[g],
                &[2.0, 2.0],
            )
            .expect("solve failed");

        assert_abs_diff_eq!(result.x[0], 0.5, epsilon = 1e-2);
        assert_abs_diff_eq!(result.x[1], 0.5, epsilon = 1e-2);
        assert_abs_diff_eq!(result.fun, 0.5, epsilon = 1e-2);
    }

    #[test]
    fn test_sqp_unconstrained() {
        let f = |x: &[f64]| (x[0] - 1.0).powi(2) + (x[1] - 2.0).powi(2);

        let solver = SQPSolver::default();
        let result = solver
            .solve(
                f,
                None::<fn(&[f64]) -> Vec<f64>>,
                &[] as &[fn(&[f64]) -> f64],
                &[] as &[fn(&[f64]) -> f64],
                &[0.0, 0.0],
            )
            .expect("solve failed");

        assert_abs_diff_eq!(result.x[0], 1.0, epsilon = 1e-3);
        assert_abs_diff_eq!(result.x[1], 2.0, epsilon = 1e-3);
        assert_abs_diff_eq!(result.fun, 0.0, epsilon = 1e-4);
    }

    #[test]
    fn test_sqp_with_gradient() {
        // Analytic gradient; min x0^2 + x1^2  s.t.  x0 + x1 = 2  ->  (1, 1)
        let f = |x: &[f64]| x[0].powi(2) + x[1].powi(2);
        let gf = |x: &[f64]| vec![2.0 * x[0], 2.0 * x[1]];
        let h = |x: &[f64]| x[0] + x[1] - 2.0;

        let solver = SQPSolver::default();
        let result = solver
            .solve(f, Some(gf), &[h], &[] as &[fn(&[f64]) -> f64], &[0.0, 0.0])
            .expect("solve failed");

        assert_abs_diff_eq!(result.x[0], 1.0, epsilon = 1e-3);
        assert_abs_diff_eq!(result.x[1], 1.0, epsilon = 1e-3);
        assert_abs_diff_eq!(result.fun, 2.0, epsilon = 1e-3);
    }

    #[test]
    fn test_sqp_mixed_constraints() {
        // min (x0-2)^2 + (x1-2)^2  s.t.  x0 + x1 = 3,  x0 - 2 <= 0
        // The inequality is inactive at the optimum (1.5, 1.5), f = 0.5.
        let f = |x: &[f64]| (x[0] - 2.0).powi(2) + (x[1] - 2.0).powi(2);
        let h = |x: &[f64]| x[0] + x[1] - 3.0;
        let g = |x: &[f64]| x[0] - 2.0;

        let solver = SQPSolver {
            max_iter: 300,
            tol: 1e-6,
            constraint_tol: 1e-5,
            ..Default::default()
        };
        let result = solver
            .solve(f, None::<fn(&[f64]) -> Vec<f64>>, &[h], &[g], &[1.0, 2.0])
            .expect("solve failed");

        assert_abs_diff_eq!(result.x[0], 1.5, epsilon = 1e-3);
        assert_abs_diff_eq!(result.x[1], 1.5, epsilon = 1e-3);
        assert_abs_diff_eq!(result.fun, 0.5, epsilon = 1e-3);
        assert!(result.x[0] <= 2.0 + 1e-5, "x0={}", result.x[0]);
    }

    #[test]
    fn test_sqp_3d_equality() {
        // min |x|^2  s.t.  x0 + 2 x1 + 3 x2 = 6
        // Optimum x* = (6/14)(1, 2, 3) = (3/7, 6/7, 9/7), f* = 18/7.
        // (The start point already has f = 3, so only checking f <= 3 would not
        // detect a solver that never moves.)
        let f = |x: &[f64]| x[0].powi(2) + x[1].powi(2) + x[2].powi(2);
        let h = |x: &[f64]| x[0] + 2.0 * x[1] + 3.0 * x[2] - 6.0;

        let solver = SQPSolver {
            max_iter: 500,
            ..Default::default()
        };
        let result = solver
            .solve(
                f,
                None::<fn(&[f64]) -> Vec<f64>>,
                &[h],
                &[] as &[fn(&[f64]) -> f64],
                &[1.0, 1.0, 1.0],
            )
            .expect("solve failed");

        assert_abs_diff_eq!(result.x[0], 3.0 / 7.0, epsilon = 1e-3);
        assert_abs_diff_eq!(result.x[1], 6.0 / 7.0, epsilon = 1e-3);
        assert_abs_diff_eq!(result.x[2], 9.0 / 7.0, epsilon = 1e-3);
        assert_abs_diff_eq!(result.fun, 18.0 / 7.0, epsilon = 1e-3);
    }
}