1use super::implicit_diff::identify_active_constraints;
18use super::kkt_sensitivity::{kkt_sensitivity, regularize_q};
19use super::types::{DiffOptGrad, DiffOptParams, DiffOptResult, DiffOptStatus};
20use crate::error::{OptimizeError, OptimizeResult};
21
/// Tuning knobs for the ADMM-based differentiable QP layer (`QpLayer`).
#[derive(Debug, Clone)]
pub struct QpLayerConfig {
    /// Maximum number of ADMM iterations in the forward solve.
    pub max_iter: usize,
    /// Convergence tolerance applied to both primal and dual residuals.
    pub tol: f64,
    /// ADMM penalty parameter (scaled-dual step size).
    pub rho: f64,
    /// Diagonal regularization added to Q before factorization/sensitivity.
    pub regularization: f64,
    /// Slack threshold for classifying inequality constraints as active
    /// in the backward pass.
    pub active_tol: f64,
    /// When true, prints per-iteration residuals to stderr.
    pub verbose: bool,
}
42
43impl Default for QpLayerConfig {
44 fn default() -> Self {
45 Self {
46 max_iter: 100,
47 tol: 1e-8,
48 rho: 1.0,
49 regularization: 1e-7,
50 active_tol: 1e-6,
51 verbose: false,
52 }
53 }
54}
55
56fn cholesky(a: &[Vec<f64>]) -> OptimizeResult<Vec<Vec<f64>>> {
63 let n = a.len();
64 let mut l = vec![vec![0.0_f64; n]; n];
65
66 for i in 0..n {
67 for j in 0..=i {
68 let mut sum = 0.0_f64;
69 for k in 0..j {
70 sum += l[i][k] * l[j][k];
71 }
72 if i == j {
73 let diag = a[i][i] - sum;
74 if diag <= 0.0 {
75 return Err(OptimizeError::ComputationError(format!(
76 "Cholesky failed: non-positive diagonal at index {}. diag = {diag}",
77 i
78 )));
79 }
80 l[i][j] = diag.sqrt();
81 } else {
82 let l_jj = l[j][j];
83 if l_jj.abs() < 1e-30 {
84 return Err(OptimizeError::ComputationError(
85 "Cholesky failed: zero diagonal element".to_string(),
86 ));
87 }
88 l[i][j] = (a[i][j] - sum) / l_jj;
89 }
90 }
91 }
92 Ok(l)
93}
94
/// Solves the lower-triangular system `L y = b` by forward substitution.
///
/// Rows with a near-zero diagonal entry yield a 0.0 component instead of
/// dividing, matching the defensive convention used throughout this module.
fn forward_sub(l: &[Vec<f64>], b: &[f64]) -> Vec<f64> {
    let mut y: Vec<f64> = Vec::with_capacity(b.len());
    for (i, &bi) in b.iter().enumerate() {
        // Subtract the contribution of the already-solved components.
        let partial: f64 = (0..i).map(|j| l[i][j] * y[j]).sum();
        let pivot = l[i][i];
        let yi = if pivot.abs() < 1e-30 {
            0.0
        } else {
            (bi - partial) / pivot
        };
        y.push(yi);
    }
    y
}
109
/// Solves the upper-triangular system `Lᵀ x = y` by backward substitution.
///
/// The transpose is accessed implicitly via `l[j][i]`, so `l` is still the
/// lower-triangular factor from `cholesky`. Near-zero pivots yield 0.0
/// instead of dividing, mirroring `forward_sub`.
fn backward_sub(l: &[Vec<f64>], y: &[f64]) -> Vec<f64> {
    let n = y.len();
    let mut x = vec![0.0_f64; n];
    for i in (0..n).rev() {
        // Contribution of the already-solved trailing components.
        let tail: f64 = ((i + 1)..n).map(|j| l[j][i] * x[j]).sum();
        let pivot = l[i][i];
        x[i] = if pivot.abs() >= 1e-30 {
            (y[i] - tail) / pivot
        } else {
            0.0
        };
    }
    x
}
124
125fn cholesky_solve(a: &[Vec<f64>], b: &[f64]) -> OptimizeResult<Vec<f64>> {
128 match cholesky(a) {
129 Ok(l) => {
130 let y = forward_sub(&l, b);
131 Ok(backward_sub(&l, &y))
132 }
133 Err(_) => {
134 super::implicit_diff::solve_implicit_system(a, b)
136 }
137 }
138}
139
/// Differentiable QP layer.
///
/// `forward` solves
/// `min_x 0.5 xᵀQx + cᵀx  s.t.  A_eq x = b_eq, G_ineq x <= h_ineq`
/// with ADMM; `backward` differentiates the cached solution with respect to
/// the problem data via KKT sensitivity on the active constraint set.
#[derive(Debug, Clone)]
pub struct QpLayer {
    // Solver configuration (tolerances, penalty, iteration cap).
    config: QpLayerConfig,
    // Warm-start primal iterate, reused across `forward` calls when sizes match.
    warm_x: Option<Vec<f64>>,
    // Warm-start splitting variable (constraint-space copy of Cx).
    warm_z: Option<Vec<f64>>,
    // Warm-start scaled dual variable.
    warm_u: Option<Vec<f64>>,
    // Problem data and multipliers from the last `forward`, consumed by `backward`.
    last_result: Option<QpForwardCache>,
}
160
/// Snapshot of the most recent forward solve, retained for the backward pass.
#[derive(Debug, Clone)]
struct QpForwardCache {
    // Primal solution x*.
    x: Vec<f64>,
    // Inequality multipliers (clamped non-negative in `forward`).
    lambda: Vec<f64>,
    // Equality multipliers.
    nu: Vec<f64>,
    // Quadratic cost matrix Q (unregularized, as passed by the caller).
    q: Vec<Vec<f64>>,
    // Linear cost vector c.
    c: Vec<f64>,
    // Equality constraint matrix A_eq.
    a_eq: Vec<Vec<f64>>,
    // Equality right-hand side b_eq.
    b_eq: Vec<f64>,
    // Inequality constraint matrix G_ineq.
    g_ineq: Vec<Vec<f64>>,
    // Inequality right-hand side h_ineq.
    h_ineq: Vec<f64>,
}
174
impl QpLayer {
    /// Creates a layer with the default configuration.
    pub fn new() -> Self {
        Self {
            config: QpLayerConfig::default(),
            warm_x: None,
            warm_z: None,
            warm_u: None,
            last_result: None,
        }
    }

    /// Creates a layer with an explicit configuration.
    pub fn with_config(config: QpLayerConfig) -> Self {
        Self {
            config,
            warm_x: None,
            warm_z: None,
            warm_u: None,
            last_result: None,
        }
    }

    /// Solves the QP
    /// `min_x 0.5 xᵀQx + cᵀx  s.t.  A_eq x = b_eq, G_ineq x <= h_ineq`
    /// with ADMM, caching data for `backward`.
    ///
    /// All constraints are stacked into `C = [A_eq; G_ineq]` and handled via
    /// the splitting `Cx = z`, where `z` is projected every iteration onto
    /// `{z_eq = b_eq, z_ineq <= h_ineq}`. Warm starts from the previous call
    /// are reused when the dimensions still match.
    ///
    /// # Errors
    ///
    /// Returns `OptimizeError::InvalidInput` when the row counts of `q`,
    /// `a_eq`, or `g_ineq` disagree with `c`, `b_eq`, `h_ineq`, and
    /// propagates linear-solver failures from the x-update.
    pub fn forward(
        &mut self,
        q: Vec<Vec<f64>>,
        c: Vec<f64>,
        a_eq: Vec<Vec<f64>>,
        b_eq: Vec<f64>,
        g_ineq: Vec<Vec<f64>>,
        h_ineq: Vec<f64>,
    ) -> OptimizeResult<DiffOptResult> {
        let n = c.len();
        let p = b_eq.len();
        let m = h_ineq.len();
        let nc = p + m; // total number of stacked constraint rows
        if q.len() != n {
            return Err(OptimizeError::InvalidInput(format!(
                "Q rows ({}) != n ({})",
                q.len(),
                n
            )));
        }
        if a_eq.len() != p {
            return Err(OptimizeError::InvalidInput(format!(
                "A_eq rows ({}) != p ({})",
                a_eq.len(),
                p
            )));
        }
        if g_ineq.len() != m {
            return Err(OptimizeError::InvalidInput(format!(
                "G_ineq rows ({}) != m ({})",
                g_ineq.len(),
                m
            )));
        }

        // Regularize Q so the x-update system matrix is positive definite.
        let q_reg = regularize_q(&q, self.config.regularization);
        let rho = self.config.rho;

        // Stacked constraint matrix C = [A_eq; G_ineq].
        let c_mat: Vec<Vec<f64>> = a_eq.iter().cloned().chain(g_ineq.iter().cloned()).collect();

        // M = Q_reg + rho * CᵀC — constant across iterations.
        let mut m_mat = q_reg.clone();
        for row in &c_mat {
            for i in 0..n {
                for j in 0..n {
                    // Short constraint rows are implicitly zero-padded.
                    let ci = if i < row.len() { row[i] } else { 0.0 };
                    let cj = if j < row.len() { row[j] } else { 0.0 };
                    m_mat[i][j] += rho * ci * cj;
                }
            }
        }

        // Warm starts are accepted only when their lengths still match.
        let mut x = self
            .warm_x
            .as_ref()
            .filter(|wx| wx.len() == n)
            .cloned()
            .unwrap_or_else(|| vec![0.0_f64; n]);

        let mut z = self
            .warm_z
            .as_ref()
            .filter(|wz| wz.len() == nc)
            .cloned()
            .unwrap_or_else(|| {
                // Cold start: equality block at b_eq, inequality block at
                // h/2 (a simple interior heuristic).
                let mut z0 = Vec::with_capacity(nc);
                z0.extend_from_slice(&b_eq);
                z0.extend(h_ineq.iter().map(|&hi| hi / 2.0));
                z0
            });

        let mut u = self
            .warm_u
            .as_ref()
            .filter(|wu| wu.len() == nc)
            .cloned()
            .unwrap_or_else(|| vec![0.0_f64; nc]);

        let mut converged = false;
        let mut iterations = 0_usize;

        for iter in 0..self.config.max_iter {
            iterations = iter + 1;

            // x-update RHS: -c + rho * Cᵀ(z - u).
            let mut rhs_x = c.iter().map(|&ci| -ci).collect::<Vec<_>>();
            for (k, row) in c_mat.iter().enumerate() {
                let zu_k =
                    if k < z.len() { z[k] } else { 0.0 } - if k < u.len() { u[k] } else { 0.0 };
                for j in 0..n {
                    let ckj = if j < row.len() { row[j] } else { 0.0 };
                    rhs_x[j] += rho * ckj * zu_k;
                }
            }

            // x-update: solve M x = rhs (Cholesky with implicit fallback).
            let x_new = cholesky_solve(&m_mat, &rhs_x)?;

            // Constraint image Cx of the new iterate.
            let mut cx = vec![0.0_f64; nc];
            for (k, row) in c_mat.iter().enumerate() {
                for j in 0..n {
                    let ckj = if j < row.len() { row[j] } else { 0.0 };
                    cx[k] += ckj * x_new[j];
                }
            }

            // z-update: project onto the constraint set.
            let mut z_new = vec![0.0_f64; nc];
            for k in 0..p {
                // Equality rows are pinned to b_eq.
                z_new[k] = if k < b_eq.len() { b_eq[k] } else { 0.0 };
            }
            for k in 0..m {
                // Inequality rows are clipped at h (enforcing Gx <= h).
                let raw = cx[p + k] + u[p + k];
                let h_k = if k < h_ineq.len() { h_ineq[k] } else { 0.0 };
                z_new[p + k] = raw.min(h_k);
            }

            // Scaled-dual update: u += Cx - z.
            let mut u_new = vec![0.0_f64; nc];
            for k in 0..nc {
                u_new[k] = u[k] + cx[k] - z_new[k];
            }

            // Primal residual ||Cx - z||₂.
            let primal_res: f64 = cx
                .iter()
                .zip(z_new.iter())
                .map(|(a, b)| (a - b).powi(2))
                .sum::<f64>()
                .sqrt();
            // Elementwise surrogate for the dual residual rho·Cᵀ(z⁺ - z):
            // sums squares of the individual rho*C[k][j]*dz[k] terms.
            let dual_res: f64 = {
                let mut dr = 0.0_f64;
                for k in 0..nc {
                    let dz = z_new[k] - z[k];
                    for j in 0..n {
                        let ckj = if j < c_mat[k].len() { c_mat[k][j] } else { 0.0 };
                        dr += (rho * ckj * dz).powi(2);
                    }
                }
                dr.sqrt()
            };

            if self.config.verbose {
                eprintln!(
                    "iter {}: primal_res={:.2e}, dual_res={:.2e}",
                    iter, primal_res, dual_res
                );
            }

            // Commit the iterate before the convergence test so the
            // returned solution always reflects the last completed step.
            x = x_new;
            z = z_new;
            u = u_new;

            if primal_res < self.config.tol && dual_res < self.config.tol {
                converged = true;
                break;
            }
        }

        // Recover multipliers from the scaled duals: y = rho * u.
        let nu: Vec<f64> = u[..p].iter().map(|&ui| rho * ui).collect();
        // Inequality multipliers are clamped non-negative before scaling.
        let lambda: Vec<f64> = u[p..].iter().map(|&ui| rho * ui.max(0.0)).collect();

        // Objective 0.5 xᵀQx + cᵀx, evaluated with the unregularized Q.
        let mut obj = 0.0_f64;
        for i in 0..n {
            obj += c[i] * x[i];
            for j in 0..n {
                let q_ij = if i < q.len() && j < q[i].len() {
                    q[i][j]
                } else {
                    0.0
                };
                obj += 0.5 * q_ij * x[i] * x[j];
            }
        }

        let status = if converged {
            DiffOptStatus::Optimal
        } else {
            DiffOptStatus::MaxIterations
        };

        // Persist warm starts for the next call and cache data for backward.
        self.warm_x = Some(x.clone());
        self.warm_z = Some(z);
        self.warm_u = Some(u);

        self.last_result = Some(QpForwardCache {
            x: x.clone(),
            lambda: lambda.clone(),
            nu: nu.clone(),
            q: q.clone(),
            c: c.clone(),
            a_eq: a_eq.clone(),
            b_eq: b_eq.clone(),
            g_ineq: g_ineq.clone(),
            h_ineq: h_ineq.clone(),
        });

        Ok(DiffOptResult {
            x,
            lambda,
            nu,
            objective: obj,
            status,
            iterations,
        })
    }

    /// Backward pass: given dL/dx at the cached solution, returns gradients
    /// of the loss with respect to the QP data (Q, c, A_eq, b_eq, G, h) via
    /// KKT sensitivity restricted to the active constraint set.
    ///
    /// # Errors
    ///
    /// Fails when called before `forward`, when `dl_dx` has the wrong
    /// length, or when the KKT sensitivity solve fails.
    pub fn backward(&self, dl_dx: &[f64]) -> OptimizeResult<DiffOptGrad> {
        let cache = self.last_result.as_ref().ok_or_else(|| {
            OptimizeError::ComputationError("QpLayer::backward called before forward".to_string())
        })?;

        let n = cache.x.len();
        if dl_dx.len() != n {
            return Err(OptimizeError::InvalidInput(format!(
                "dl_dx length {} != n {}",
                dl_dx.len(),
                n
            )));
        }

        // Inequalities with slack below `active_tol` are treated as
        // equalities for the sensitivity analysis.
        let active_idx = identify_active_constraints(
            &cache.g_ineq,
            &cache.h_ineq,
            &cache.x,
            self.config.active_tol,
        );

        // Augment the equality system with the active inequality rows.
        let mut a_aug: Vec<Vec<f64>> = cache.a_eq.clone();
        let mut b_aug: Vec<f64> = cache.b_eq.clone();
        let mut nu_aug: Vec<f64> = cache.nu.clone();

        for &ai in &active_idx {
            if ai < cache.g_ineq.len() {
                a_aug.push(cache.g_ineq[ai].clone());
                b_aug.push(cache.h_ineq.get(ai).copied().unwrap_or(0.0));
                nu_aug.push(cache.lambda.get(ai).copied().unwrap_or(0.0));
            }
        }

        // Use the same regularized Q as the forward solve for consistency.
        let q_reg = regularize_q(&cache.q, self.config.regularization);

        let kkt_grad = kkt_sensitivity(&q_reg, &a_aug, &cache.x, &nu_aug, dl_dx)?;

        let p = cache.a_eq.len();
        let m_full = cache.g_ineq.len();

        // Split the augmented gradients back into equality / inequality parts:
        // the first p rows belong to A_eq, the rest to the active G rows.
        let dl_da_eq: Option<Vec<Vec<f64>>> = if p > 0 {
            Some(kkt_grad.dl_da[..p].to_vec())
        } else {
            None
        };

        let dl_db_eq = kkt_grad.dl_db[..p].to_vec();

        // Scatter active-row gradients back to full inequality shape;
        // inactive rows keep zero gradient.
        let mut dl_dg = vec![vec![0.0_f64; n]; m_full];
        let mut dl_dh = vec![0.0_f64; m_full];
        for (idx, &ai) in active_idx.iter().enumerate() {
            let aug_idx = p + idx;
            if ai < m_full && aug_idx < kkt_grad.dl_da.len() {
                dl_dg[ai] = kkt_grad.dl_da[aug_idx].clone();
                dl_dh[ai] = kkt_grad.dl_db.get(aug_idx).copied().unwrap_or(0.0);
            }
        }

        Ok(DiffOptGrad {
            dl_dq: Some(kkt_grad.dl_dq),
            dl_dc: kkt_grad.dl_dc,
            dl_da: dl_da_eq,
            dl_db: dl_db_eq,
            dl_dg: Some(dl_dg),
            dl_dh,
        })
    }

    /// Returns the primal solution from the last `forward`, if any.
    pub fn last_solution(&self) -> Option<&[f64]> {
        self.last_result.as_ref().map(|r| r.x.as_slice())
    }

    /// Clears all warm-start state so the next `forward` cold-starts.
    /// Does not clear the backward-pass cache.
    pub fn reset_warm_start(&mut self) {
        self.warm_x = None;
        self.warm_z = None;
        self.warm_u = None;
    }
}
531
532impl Default for QpLayer {
533 fn default() -> Self {
534 Self::new()
535 }
536}
537
#[cfg(test)]
mod tests {
    use super::*;

    // Builds Q = 2·I (so 0.5 xᵀQx = ‖x‖²) and c = 0 for an n-dim problem.
    fn make_identity_qp(n: usize) -> (Vec<Vec<f64>>, Vec<f64>) {
        let q = (0..n)
            .map(|i| {
                let mut row = vec![0.0_f64; n];
                row[i] = 2.0;
                row
            })
            .collect();
        let c = vec![0.0_f64; n];
        (q, c)
    }

    #[test]
    fn test_qp_layer_config_default() {
        // Pins the documented default configuration values.
        let cfg = QpLayerConfig::default();
        assert_eq!(cfg.max_iter, 100);
        assert!((cfg.tol - 1e-8).abs() < 1e-15);
        assert!(!cfg.verbose);
        assert!((cfg.rho - 1.0).abs() < 1e-15);
    }

    #[test]
    fn test_qp_layer_identity_q_zero_c() {
        // min ‖x‖² s.t. x0 + x1 = 0 has the unique solution x = 0.
        let mut layer = QpLayer::new();
        let (q, c) = make_identity_qp(2);
        let a_eq = vec![vec![1.0, 1.0]];
        let b_eq = vec![0.0];

        let result = layer
            .forward(q, c, a_eq, b_eq, vec![], vec![])
            .expect("Forward failed");

        assert!(
            result.x[0].abs() < 1e-4,
            "x[0] = {} (expected 0)",
            result.x[0]
        );
        assert!(
            result.x[1].abs() < 1e-4,
            "x[1] = {} (expected 0)",
            result.x[1]
        );
    }

    #[test]
    fn test_qp_layer_forward_unconstrained() {
        // Unconstrained minimum of xᵀx + [1,2]ᵀx is x = -Q⁻¹c = (-0.5, -1).
        let mut layer = QpLayer::new();
        let q = vec![vec![2.0, 0.0], vec![0.0, 2.0]];
        let c = vec![1.0, 2.0];

        let result = layer
            .forward(q, c, vec![], vec![], vec![], vec![])
            .expect("Forward failed");

        assert!(
            (result.x[0] - (-0.5)).abs() < 1e-3,
            "x[0] = {} (expected -0.5)",
            result.x[0]
        );
        assert!(
            (result.x[1] - (-1.0)).abs() < 1e-3,
            "x[1] = {} (expected -1.0)",
            result.x[1]
        );
    }

    #[test]
    fn test_qp_layer_forward_with_equality() {
        // min ‖x‖² s.t. x0 + x1 = 1 gives the symmetric point (0.5, 0.5).
        let mut layer = QpLayer::new();
        let q = vec![vec![2.0, 0.0], vec![0.0, 2.0]];
        let c = vec![0.0, 0.0];
        let a_eq = vec![vec![1.0, 1.0]];
        let b_eq = vec![1.0];

        let result = layer
            .forward(q, c, a_eq, b_eq, vec![], vec![])
            .expect("Forward failed");

        assert!(
            (result.x[0] - 0.5).abs() < 1e-3,
            "x[0] = {} (expected 0.5)",
            result.x[0]
        );
        assert!(
            (result.x[1] - 0.5).abs() < 1e-3,
            "x[1] = {} (expected 0.5)",
            result.x[1]
        );
    }

    #[test]
    fn test_qp_layer_forward_with_inequality() {
        // -x0 - x1 <= -1 encodes x0 + x1 >= 1; the returned point must be
        // (approximately) feasible.
        let mut layer = QpLayer::new();
        let q = vec![vec![2.0, 0.0], vec![0.0, 2.0]];
        let c = vec![0.0, 0.0];
        let g = vec![vec![-1.0, -1.0]];
        let h = vec![-1.0];

        let result = layer
            .forward(q, c, vec![], vec![], g, h)
            .expect("Forward failed");

        let sum = result.x[0] + result.x[1];
        assert!(sum >= 1.0 - 1e-3, "x + y = {} (should be >= 1)", sum);
    }

    #[test]
    fn test_qp_layer_backward_no_forward_error() {
        // backward() before any forward() must fail, not panic.
        let layer = QpLayer::new();
        let result = layer.backward(&[1.0, 0.0]);
        assert!(result.is_err(), "Should error without forward pass");
    }

    #[test]
    fn test_qp_layer_backward_dl_dc_finite() {
        // Sanity: backward produces a finite dL/dc of the right length.
        let mut layer = QpLayer::new();
        let q = vec![vec![2.0, 0.0], vec![0.0, 2.0]];
        let c = vec![1.0, 2.0];

        let result = layer
            .forward(q, c, vec![], vec![], vec![], vec![])
            .expect("Forward failed");
        let _ = result;

        let grad = layer.backward(&[1.0, 0.0]).expect("Backward failed");
        assert_eq!(grad.dl_dc.len(), 2);
        assert!(grad.dl_dc[0].is_finite(), "dl/dc[0] not finite");
        assert!(grad.dl_dc[1].is_finite(), "dl/dc[1] not finite");
    }

    #[test]
    fn test_qp_layer_backward_gradient_check() {
        // Central finite differences on L(c) = 0.5‖x*(c)‖² should match the
        // analytical dL/dc from the backward pass (dL/dx = x*).
        let eps = 1e-5_f64;
        let c_base = vec![1.0_f64, 0.0];
        let q = vec![vec![2.0, 0.0], vec![0.0, 2.0]];

        let solve_and_loss = |c_vec: Vec<f64>| -> f64 {
            let mut layer = QpLayer::new();
            let res = layer
                .forward(q.clone(), c_vec, vec![], vec![], vec![], vec![])
                .expect("Forward failed");
            res.x.iter().map(|&xi| 0.5 * xi * xi).sum::<f64>()
        };

        let mut layer = QpLayer::new();
        let res = layer
            .forward(q.clone(), c_base.clone(), vec![], vec![], vec![], vec![])
            .expect("Forward failed");
        let dl_dx = res.x.clone();
        let grad = layer.backward(&dl_dx).expect("Backward failed");

        let mut c_plus = c_base.clone();
        c_plus[0] += eps;
        let mut c_minus = c_base.clone();
        c_minus[0] -= eps;
        let fd_dc0 = (solve_and_loss(c_plus) - solve_and_loss(c_minus)) / (2.0 * eps);

        assert!(
            (grad.dl_dc[0] - fd_dc0).abs() < 1e-3,
            "dl/dc[0] analytical={} vs FD={}",
            grad.dl_dc[0],
            fd_dc0
        );
    }

    #[test]
    fn test_qp_layer_active_set_identification() {
        // Box-like constraints: x >= 0, y >= 0, x + y >= 0.5; the solution
        // must satisfy the binding sum constraint.
        let mut layer = QpLayer::new();
        let q = vec![vec![2.0, 0.0], vec![0.0, 2.0]];
        let c = vec![0.0, 0.0];
        let g = vec![
            vec![-1.0, 0.0],
            vec![0.0, -1.0],
            vec![-1.0, -1.0],
        ];
        let h = vec![0.0, 0.0, -0.5];

        let result = layer
            .forward(q, c, vec![], vec![], g, h)
            .expect("Forward failed");

        let sum = result.x[0] + result.x[1];
        assert!(sum >= 0.5 - 1e-3, "x + y = {} (should be >= 0.5)", sum);
    }

    #[test]
    fn test_qp_layer_warm_start() {
        // A second solve of the same problem (warm-started) must land on
        // the same solution.
        let mut layer = QpLayer::new();
        let q = vec![vec![2.0, 0.0], vec![0.0, 2.0]];
        let c = vec![1.0, 1.0];

        let res1 = layer
            .forward(q.clone(), c.clone(), vec![], vec![], vec![], vec![])
            .expect("Forward 1 failed");

        let res2 = layer
            .forward(q, c, vec![], vec![], vec![], vec![])
            .expect("Forward 2 failed");

        assert!(
            (res1.x[0] - res2.x[0]).abs() < 1e-6,
            "Warm-start inconsistency"
        );
    }

    #[test]
    fn test_qp_layer_last_solution() {
        // last_solution() is None before forward and Some after.
        let mut layer = QpLayer::new();
        let q = vec![vec![2.0, 0.0], vec![0.0, 2.0]];
        let c = vec![1.0, 0.0];

        assert!(layer.last_solution().is_none());
        layer
            .forward(q, c, vec![], vec![], vec![], vec![])
            .expect("Forward failed");
        assert!(layer.last_solution().is_some());
    }

    #[test]
    fn test_cholesky_solve_identity() {
        // Diagonal SPD system: diag(4, 9) x = (8, 18) ⇒ x = (2, 2).
        let a = vec![vec![4.0, 0.0], vec![0.0, 9.0]];
        let b = vec![8.0, 18.0];
        let x = cholesky_solve(&a, &b).expect("Cholesky solve failed");
        assert!((x[0] - 2.0).abs() < 1e-10, "x[0] = {}", x[0]);
        assert!((x[1] - 2.0).abs() < 1e-10, "x[1] = {}", x[1]);
    }
}