1use super::implicit_diff;
18use super::types::{BackwardMode, DiffQPConfig, DiffQPResult, ImplicitGradient};
19use crate::error::{OptimizeError, OptimizeResult};
20
/// A dense quadratic program:
///
/// ```text
///   minimize    (1/2) xᵀ Q x + cᵀ x
///   subject to  G x ≤ h        (m inequality constraints)
///               A x = b        (p equality constraints)
/// ```
///
/// Matrices are stored row-major as `Vec<Vec<f64>>`; dimensions are
/// validated by [`DifferentiableQP::new`].
#[derive(Debug, Clone)]
pub struct DifferentiableQP {
    /// Quadratic cost matrix Q (n×n). Only dimensions are checked;
    /// symmetry/positive-semidefiniteness is assumed — TODO confirm callers
    /// guarantee this.
    pub q: Vec<Vec<f64>>,
    /// Linear cost vector c (length n).
    pub c: Vec<f64>,
    /// Inequality constraint matrix G (m×n).
    pub g: Vec<Vec<f64>>,
    /// Inequality right-hand side h (length m).
    pub h: Vec<f64>,
    /// Equality constraint matrix A (p×n).
    pub a: Vec<Vec<f64>>,
    /// Equality right-hand side b (length p).
    pub b: Vec<f64>,
}
40
41impl DifferentiableQP {
42 pub fn new(
52 q: Vec<Vec<f64>>,
53 c: Vec<f64>,
54 g: Vec<Vec<f64>>,
55 h: Vec<f64>,
56 a: Vec<Vec<f64>>,
57 b: Vec<f64>,
58 ) -> OptimizeResult<Self> {
59 let n = c.len();
60 if q.len() != n {
61 return Err(OptimizeError::InvalidInput(format!(
62 "Q has {} rows but c has length {}",
63 q.len(),
64 n
65 )));
66 }
67 for (i, row) in q.iter().enumerate() {
68 if row.len() != n {
69 return Err(OptimizeError::InvalidInput(format!(
70 "Q row {} has length {} but expected {}",
71 i,
72 row.len(),
73 n
74 )));
75 }
76 }
77 for (i, row) in g.iter().enumerate() {
78 if row.len() != n {
79 return Err(OptimizeError::InvalidInput(format!(
80 "G row {} has length {} but expected {}",
81 i,
82 row.len(),
83 n
84 )));
85 }
86 }
87 if g.len() != h.len() {
88 return Err(OptimizeError::InvalidInput(format!(
89 "G has {} rows but h has length {}",
90 g.len(),
91 h.len()
92 )));
93 }
94 for (i, row) in a.iter().enumerate() {
95 if row.len() != n {
96 return Err(OptimizeError::InvalidInput(format!(
97 "A row {} has length {} but expected {}",
98 i,
99 row.len(),
100 n
101 )));
102 }
103 }
104 if a.len() != b.len() {
105 return Err(OptimizeError::InvalidInput(format!(
106 "A has {} rows but b has length {}",
107 a.len(),
108 b.len()
109 )));
110 }
111
112 Ok(Self { q, c, g, h, a, b })
113 }
114
    /// Number of decision variables (length of `c`).
    pub fn n(&self) -> usize {
        self.c.len()
    }
119
    /// Number of inequality constraints (length of `h`).
    pub fn m(&self) -> usize {
        self.h.len()
    }
124
    /// Number of equality constraints (length of `b`).
    pub fn p(&self) -> usize {
        self.b.len()
    }
129
    /// Solves the QP with a primal-dual interior-point method.
    ///
    /// `config.regularization` is added to the diagonal of Q before
    /// solving, so the iterates target the regularized problem; the
    /// reported `objective` is evaluated with the original Q.
    ///
    /// Terminates when the max of the stationarity / equality / inequality
    /// residual infinity-norms and the duality measure `mu` falls below
    /// `config.tolerance`, or after `config.max_iterations` iterations.
    /// If the inner KKT solve fails, iteration stops early and the current
    /// (unconverged) iterate is returned rather than an error.
    pub fn forward(&self, config: &DiffQPConfig) -> OptimizeResult<DiffQPResult> {
        let n = self.n();
        let m = self.m();
        let p = self.p();

        // Regularized Hessian Q + reg*I (keeps the KKT system nonsingular).
        let mut q_reg = self.q.clone();
        for i in 0..n {
            q_reg[i][i] += config.regularization;
        }

        // Strictly interior starting point: x = 0, lambda = 1, nu = 0, and
        // slacks s = h - G x where positive, otherwise reset to 1.0.
        let mut x = vec![0.0; n];
        let mut lam = vec![1.0; m]; let mut nu = vec![0.0; p]; let mut s = vec![1.0; m]; for i in 0..m {
            let mut gx_i = 0.0;
            for j in 0..n {
                gx_i += self.g[i][j] * x[j];
            }
            s[i] = self.h[i] - gx_i;
            if s[i] <= 0.0 {
                s[i] = 1.0; }
        }

        let mut converged = false;
        let mut iterations = 0;

        for iter in 0..config.max_iterations {
            iterations = iter + 1;

            // Stationarity residual: r_stat = Q_reg x + c + Gᵀ lam + Aᵀ nu.
            let mut r_stat = vec![0.0; n];
            for i in 0..n {
                let mut qx_i = 0.0;
                for j in 0..n {
                    qx_i += q_reg[i][j] * x[j];
                }
                r_stat[i] = qx_i + self.c[i];
            }
            for k in 0..m {
                for i in 0..n {
                    r_stat[i] += self.g[k][i] * lam[k];
                }
            }
            for k in 0..p {
                for i in 0..n {
                    r_stat[i] += self.a[k][i] * nu[k];
                }
            }

            // Equality residual: r_eq = A x - b.
            let mut r_eq = vec![0.0; p];
            for i in 0..p {
                for j in 0..n {
                    r_eq[i] += self.a[i][j] * x[j];
                }
                r_eq[i] -= self.b[i];
            }

            // Inequality (primal) residual: r_ineq = G x + s - h.
            let mut r_ineq = vec![0.0; m];
            for i in 0..m {
                let mut gx_i = 0.0;
                for j in 0..n {
                    gx_i += self.g[i][j] * x[j];
                }
                r_ineq[i] = s[i] + gx_i - self.h[i];
            }

            // Duality measure mu = lam·s / m (0 when there are no
            // inequality constraints).
            let mu: f64 = if m > 0 {
                lam.iter()
                    .zip(s.iter())
                    .map(|(&li, &si)| li * si)
                    .sum::<f64>()
                    / m as f64
            } else {
                0.0
            };

            // Convergence test on infinity norms of all residuals plus mu.
            let res_stat: f64 = r_stat.iter().map(|v| v.abs()).fold(0.0, f64::max);
            let res_eq: f64 = r_eq.iter().map(|v| v.abs()).fold(0.0, f64::max);
            let res_ineq: f64 = r_ineq.iter().map(|v| v.abs()).fold(0.0, f64::max);
            let max_res = res_stat.max(res_eq).max(res_ineq).max(mu);

            if max_res < config.tolerance {
                converged = true;
                break;
            }

            // Assemble the (n+m+p)-dimensional Newton/KKT system. Slack
            // variables are eliminated, condensing the inequality block
            // into rows n..n+m.
            let dim = n + m + p;
            let mut kkt = vec![vec![0.0; dim]; dim];
            let mut rhs = vec![0.0; dim];

            // Rows 0..n: stationarity block [Q_reg  Gᵀ  Aᵀ].
            for i in 0..n {
                for j in 0..n {
                    kkt[i][j] = q_reg[i][j];
                }
                for k in 0..m {
                    kkt[i][n + k] = self.g[k][i];
                }
                for k in 0..p {
                    kkt[i][n + m + k] = self.a[k][i];
                }
                rhs[i] = -r_stat[i];
            }

            // Rows n..n+m: linearized complementarity with fixed centering
            // parameter sigma, after eliminating ds = -r_ineq - G dx:
            //   -lam_i * G_i dx + s_i * dlam_i
            //     = -lam_i s_i + sigma*mu + lam_i r_ineq_i.
            let sigma = 0.1_f64; for i in 0..m {
                let li = lam[i];
                let si = s[i];
                for j in 0..n {
                    kkt[n + i][j] = -li * self.g[i][j];
                }
                kkt[n + i][n + i] = si;
                rhs[n + i] = -li * si + sigma * mu + li * r_ineq[i];
            }

            // Rows n+m..n+m+p: equality constraints A dx = -r_eq.
            for i in 0..p {
                for j in 0..n {
                    kkt[n + m + i][j] = self.a[i][j];
                }
                rhs[n + m + i] = -r_eq[i];
            }

            // Singular KKT system: bail out with the current iterate
            // (converged stays false; no error is surfaced to the caller).
            let dir = match implicit_diff::solve_implicit_system(&kkt, &rhs) {
                Ok(d) => d,
                Err(_) => break, };

            let dx = &dir[..n];
            let dlam = &dir[n..n + m];
            let dnu = &dir[n + m..];

            // Recover the eliminated slack direction ds = -r_ineq - G dx.
            let mut ds = vec![0.0; m];
            for i in 0..m {
                let mut gx_i = 0.0;
                for j in 0..n {
                    gx_i += self.g[i][j] * dx[j];
                }
                ds[i] = -r_ineq[i] - gx_i;
            }

            // Fraction-to-the-boundary rule: step at most tau of the way to
            // the nonnegativity boundary of s (primal) and lam (dual).
            let tau = 0.995;
            let mut alpha_p = 1.0_f64;
            let mut alpha_d = 1.0_f64;

            for i in 0..m {
                if ds[i] < 0.0 {
                    let ratio = -tau * s[i] / ds[i];
                    if ratio < alpha_p {
                        alpha_p = ratio;
                    }
                }
                if dlam[i] < 0.0 {
                    let ratio = -tau * lam[i] / dlam[i];
                    if ratio < alpha_d {
                        alpha_d = ratio;
                    }
                }
            }

            alpha_p = alpha_p.min(1.0).max(1e-12);
            alpha_d = alpha_d.min(1.0).max(1e-12);

            // Take the step; clamp s and lam away from zero so the iterate
            // stays strictly interior despite rounding.
            for i in 0..n {
                x[i] += alpha_p * dx[i];
            }
            for i in 0..m {
                s[i] += alpha_p * ds[i];
                lam[i] += alpha_d * dlam[i];
                if s[i] < 1e-14 {
                    s[i] = 1e-14;
                }
                if lam[i] < 1e-14 {
                    lam[i] = 1e-14;
                }
            }
            for i in 0..p {
                nu[i] += alpha_d * dnu[i];
            }
        }

        // Objective 0.5 xᵀQx + cᵀx with the ORIGINAL Q — regularization is
        // a solver detail and must not leak into the reported value.
        let mut obj = 0.0;
        for i in 0..n {
            obj += self.c[i] * x[i];
            for j in 0..n {
                obj += 0.5 * self.q[i][j] * x[i] * x[j];
            }
        }

        Ok(DiffQPResult {
            optimal_x: x,
            optimal_lambda: lam,
            optimal_nu: nu,
            objective: obj,
            converged,
            iterations,
        })
    }
356
357 pub fn backward(
362 &self,
363 result: &DiffQPResult,
364 dl_dx: &[f64],
365 config: &DiffQPConfig,
366 ) -> OptimizeResult<ImplicitGradient> {
367 let n = self.n();
368 if dl_dx.len() != n {
369 return Err(OptimizeError::InvalidInput(format!(
370 "dl_dx length {} != n {}",
371 dl_dx.len(),
372 n
373 )));
374 }
375
376 let mut q_reg = self.q.clone();
378 for i in 0..n {
379 q_reg[i][i] += config.regularization;
380 }
381
382 match config.backward_mode {
383 BackwardMode::FullDifferentiation => implicit_diff::compute_full_implicit_gradient(
384 &q_reg,
385 &self.g,
386 &self.h,
387 &self.a,
388 &result.optimal_x,
389 &result.optimal_lambda,
390 &result.optimal_nu,
391 dl_dx,
392 ),
393 BackwardMode::ActiveSetOnly => {
394 implicit_diff::compute_active_set_implicit_gradient(
395 &q_reg,
396 &self.g,
397 &self.h,
398 &self.a,
399 &result.optimal_x,
400 &result.optimal_lambda,
401 &result.optimal_nu,
402 dl_dx,
403 config.tolerance * 100.0, )
405 }
406 _ => Err(OptimizeError::NotImplementedError(
407 "Unknown backward mode".to_string(),
408 )),
409 }
410 }
411
412 pub fn batched_forward(
416 params_list: &[DifferentiableQP],
417 config: &DiffQPConfig,
418 ) -> OptimizeResult<Vec<DiffQPResult>> {
419 params_list.iter().map(|qp| qp.forward(config)).collect()
420 }
421}
422
423#[cfg(test)]
424mod tests {
425 use super::*;
426
    // Unconstrained QP: Q = 2I, c = (1, 2); minimizer is x = -Q⁻¹c = (-0.5, -1.0).
    #[test]
    fn test_qp_forward_unconstrained() {
        let qp = DifferentiableQP::new(
            vec![vec![2.0, 0.0], vec![0.0, 2.0]],
            vec![1.0, 2.0],
            vec![],
            vec![],
            vec![],
            vec![],
        )
        .expect("QP creation failed");

        let config = DiffQPConfig::default();
        let result = qp.forward(&config).expect("Forward solve failed");

        assert!(result.converged, "QP should converge");
        assert!(
            (result.optimal_x[0] - (-0.5)).abs() < 1e-4,
            "x[0] = {} (expected -0.5)",
            result.optimal_x[0]
        );
        assert!(
            (result.optimal_x[1] - (-1.0)).abs() < 1e-4,
            "x[1] = {} (expected -1.0)",
            result.optimal_x[1]
        );
    }
457
    // Min ||x||² s.t. -x₀ - x₁ ≤ -1 (i.e. x₀ + x₁ ≥ 1): the constraint is
    // active and the symmetric optimum is (0.5, 0.5).
    #[test]
    fn test_qp_forward_with_inequality() {
        let qp = DifferentiableQP::new(
            vec![vec![2.0, 0.0], vec![0.0, 2.0]],
            vec![0.0, 0.0],
            vec![vec![-1.0, -1.0]], vec![-1.0],
            vec![],
            vec![],
        )
        .expect("QP creation failed");

        let config = DiffQPConfig::default();
        let result = qp.forward(&config).expect("Forward solve failed");

        assert!(result.converged);
        assert!(
            (result.optimal_x[0] - 0.5).abs() < 1e-3,
            "x[0] = {} (expected 0.5)",
            result.optimal_x[0]
        );
        assert!(
            (result.optimal_x[1] - 0.5).abs() < 1e-3,
            "x[1] = {} (expected 0.5)",
            result.optimal_x[1]
        );
    }
489
    // Unconstrained QP with Q = 2I: x*(c) = -c/2, so dx₀/dc₀ = -0.5 and
    // dx₀/dc₁ = 0. With dl_dx = (1, 0) the chain rule gives
    // dl/dc = (-0.5, 0).
    #[test]
    fn test_backward_gradient_dl_dc() {
        let qp = DifferentiableQP::new(
            vec![vec![2.0, 0.0], vec![0.0, 2.0]],
            vec![1.0, 2.0],
            vec![],
            vec![],
            vec![],
            vec![],
        )
        .expect("QP creation failed");

        let config = DiffQPConfig::default();
        let result = qp.forward(&config).expect("Forward solve failed");

        let dl_dx = vec![1.0, 0.0];
        let grad = qp
            .backward(&result, &dl_dx, &config)
            .expect("Backward failed");

        assert!(
            (grad.dl_dc[0] - (-0.5)).abs() < 1e-3,
            "dl/dc[0] = {} (expected -0.5)",
            grad.dl_dc[0]
        );
        assert!(
            grad.dl_dc[1].abs() < 1e-3,
            "dl/dc[1] = {} (expected 0)",
            grad.dl_dc[1]
        );
    }
529
530 #[test]
532 fn test_backward_finite_difference_c() {
533 let eps = 1e-5;
534 let config = DiffQPConfig::default();
535
536 let q = vec![vec![4.0, 1.0], vec![1.0, 3.0]];
537 let c_base = vec![1.0, -1.0];
538 let g = vec![vec![-1.0, 0.0], vec![0.0, -1.0]]; let h = vec![0.0, 0.0];
540
541 let qp0 = DifferentiableQP::new(
542 q.clone(),
543 c_base.clone(),
544 g.clone(),
545 h.clone(),
546 vec![],
547 vec![],
548 )
549 .expect("QP creation failed");
550 let res0 = qp0.forward(&config).expect("Forward failed");
551 let obj0 = res0.objective;
552
553 let dl_dx = res0.optimal_x.clone();
555 let grad = qp0
556 .backward(&res0, &dl_dx, &config)
557 .expect("Backward failed");
558
559 let mut c_plus = c_base.clone();
561 c_plus[0] += eps;
562 let qp_plus =
563 DifferentiableQP::new(q.clone(), c_plus, g.clone(), h.clone(), vec![], vec![])
564 .expect("QP+ creation failed");
565 let res_plus = qp_plus.forward(&config).expect("Forward+ failed");
566
567 let mut c_minus = c_base.clone();
568 c_minus[0] -= eps;
569 let qp_minus =
570 DifferentiableQP::new(q.clone(), c_minus, g.clone(), h.clone(), vec![], vec![])
571 .expect("QP- creation failed");
572 let res_minus = qp_minus.forward(&config).expect("Forward- failed");
573
574 let loss_plus: f64 = res_plus.optimal_x.iter().map(|v| 0.5 * v * v).sum();
576 let loss_minus: f64 = res_minus.optimal_x.iter().map(|v| 0.5 * v * v).sum();
577 let fd_grad = (loss_plus - loss_minus) / (2.0 * eps);
578
579 assert!(
580 (grad.dl_dc[0] - fd_grad).abs() < 1e-3,
581 "dl/dc[0] analytical={} vs fd={}",
582 grad.dl_dc[0],
583 fd_grad
584 );
585 }
586
    // Checks dl/dh[0] against a central finite difference of
    // L = 0.5 * ||x*(h)||² for a QP whose single inequality is active at the
    // optimum. The 0.1 tolerance is loose because active-set gradients are
    // only piecewise smooth near the constraint boundary.
    #[test]
    fn test_backward_finite_difference_h() {
        let eps = 1e-5;
        let config = DiffQPConfig::default();

        let q = vec![vec![2.0, 0.0], vec![0.0, 2.0]];
        let c = vec![0.0, 0.0];
        let g = vec![vec![-1.0, -1.0]]; let h_base = vec![-1.0]; let qp0 = DifferentiableQP::new(
            q.clone(),
            c.clone(),
            g.clone(),
            h_base.clone(),
            vec![],
            vec![],
        )
        .expect("QP creation failed");
        let res0 = qp0.forward(&config).expect("Forward failed");

        // dl_dx = x* is the gradient of L = 0.5 * ||x||² at x*.
        let dl_dx = res0.optimal_x.clone();
        let grad = qp0
            .backward(&res0, &dl_dx, &config)
            .expect("Backward failed");

        // Central finite difference: perturb h[0] by ±eps and re-solve.
        let mut h_plus = h_base.clone();
        h_plus[0] += eps;
        let qp_plus =
            DifferentiableQP::new(q.clone(), c.clone(), g.clone(), h_plus, vec![], vec![])
                .expect("QP+ creation failed");
        let res_plus = qp_plus.forward(&config).expect("Forward+ failed");

        let mut h_minus = h_base.clone();
        h_minus[0] -= eps;
        let qp_minus =
            DifferentiableQP::new(q.clone(), c.clone(), g.clone(), h_minus, vec![], vec![])
                .expect("QP- creation failed");
        let res_minus = qp_minus.forward(&config).expect("Forward- failed");

        let loss_plus: f64 = res_plus.optimal_x.iter().map(|v| 0.5 * v * v).sum();
        let loss_minus: f64 = res_minus.optimal_x.iter().map(|v| 0.5 * v * v).sum();
        let fd_grad = (loss_plus - loss_minus) / (2.0 * eps);

        assert!(
            (grad.dl_dh[0] - fd_grad).abs() < 0.1,
            "dl/dh[0] analytical={} vs fd={}",
            grad.dl_dh[0],
            fd_grad
        );
    }
641
    // Min ||x||² s.t. x₀ + x₁ = 1: the symmetric optimum is (0.5, 0.5).
    #[test]
    fn test_qp_with_equality_constraint() {
        let qp = DifferentiableQP::new(
            vec![vec![2.0, 0.0], vec![0.0, 2.0]],
            vec![0.0, 0.0],
            vec![],
            vec![],
            vec![vec![1.0, 1.0]],
            vec![1.0],
        )
        .expect("QP creation failed");

        let config = DiffQPConfig::default();
        let result = qp.forward(&config).expect("Forward failed");

        assert!(result.converged);
        assert!(
            (result.optimal_x[0] - 0.5).abs() < 1e-3,
            "x[0] = {}",
            result.optimal_x[0]
        );
        assert!(
            (result.optimal_x[1] - 0.5).abs() < 1e-3,
            "x[1] = {}",
            result.optimal_x[1]
        );
    }
671
    // batched_forward must produce exactly the same solutions as solving
    // each QP individually (it is a sequential map, so bitwise equality is
    // expected; 1e-10 allows for nothing more than formatting of the check).
    #[test]
    fn test_batched_forward_consistency() {
        let qp1 = DifferentiableQP::new(
            vec![vec![2.0, 0.0], vec![0.0, 2.0]],
            vec![1.0, 0.0],
            vec![],
            vec![],
            vec![],
            vec![],
        )
        .expect("QP1 creation failed");
        let qp2 = DifferentiableQP::new(
            vec![vec![2.0, 0.0], vec![0.0, 2.0]],
            vec![0.0, 1.0],
            vec![],
            vec![],
            vec![],
            vec![],
        )
        .expect("QP2 creation failed");

        let config = DiffQPConfig::default();
        let batch_results = DifferentiableQP::batched_forward(&[qp1.clone(), qp2.clone()], &config)
            .expect("Batch failed");

        let r1 = qp1.forward(&config).expect("Single 1 failed");
        let r2 = qp2.forward(&config).expect("Single 2 failed");

        for i in 0..2 {
            assert!(
                (batch_results[0].optimal_x[i] - r1.optimal_x[i]).abs() < 1e-10,
                "Batch[0].x[{}] differs",
                i
            );
            assert!(
                (batch_results[1].optimal_x[i] - r2.optimal_x[i]).abs() < 1e-10,
                "Batch[1].x[{}] differs",
                i
            );
        }
    }
713
    // 1-D unconstrained QP: min x² + 4x, optimum x = -2.
    #[test]
    fn test_qp_empty_constraints() {
        let qp = DifferentiableQP::new(vec![vec![2.0]], vec![4.0], vec![], vec![], vec![], vec![])
            .expect("QP creation failed");

        let config = DiffQPConfig::default();
        let result = qp.forward(&config).expect("Forward failed");
        assert!(result.converged);
        assert!(
            (result.optimal_x[0] - (-2.0)).abs() < 1e-3,
            "x = {}",
            result.optimal_x[0]
        );
    }
729
    // Q is 2×2 but c has length 3 — construction must be rejected.
    #[test]
    fn test_qp_dimension_validation() {
        let result = DifferentiableQP::new(
            vec![vec![1.0, 0.0], vec![0.0, 1.0]],
            vec![1.0, 2.0, 3.0],
            vec![],
            vec![],
            vec![],
            vec![],
        );
        assert!(result.is_err());
    }
743
    // Degenerate case: x ≥ 1, y ≥ 1, and x + y ≥ 2 are all active at the
    // optimum (1, 1), so the third constraint is redundant with the first
    // two. The solver should still reach (1, 1) despite the degeneracy.
    #[test]
    fn test_qp_degenerate_active_constraints() {
        let qp = DifferentiableQP::new(
            vec![vec![2.0, 0.0], vec![0.0, 2.0]],
            vec![0.0, 0.0],
            vec![
                vec![-1.0, 0.0], vec![0.0, -1.0], vec![-1.0, -1.0], ],
            vec![-1.0, -1.0, -2.0],
            vec![],
            vec![],
        )
        .expect("QP creation failed");

        let config = DiffQPConfig::default();
        let result = qp.forward(&config).expect("Forward failed");

        assert!(result.converged);
        assert!(
            (result.optimal_x[0] - 1.0).abs() < 1e-2,
            "x[0] = {} (expected 1.0)",
            result.optimal_x[0]
        );
        assert!(
            (result.optimal_x[1] - 1.0).abs() < 1e-2,
            "x[1] = {} (expected 1.0)",
            result.optimal_x[1]
        );
    }
778}