// scirs2_optimize/differentiable_optimization/diff_qp.rs
1//! Differentiable Quadratic Programming (OptNet-style).
2//!
3//! Solves the QP:
4//!
5//!   min  ½ x'Qx + c'x
6//!   s.t. Gx ≤ h
7//!        Ax = b
8//!
9//! and computes gradients of the optimal solution x* w.r.t. all problem
10//! parameters (Q, c, G, h, A, b) via implicit differentiation of the KKT
11//! conditions.
12//!
13//! # References
14//! - Amos & Kolter (2017). "OptNet: Differentiable Optimization as a Layer
15//!   in Neural Networks." ICML.
16
17use super::implicit_diff;
18use super::types::{BackwardMode, DiffQPConfig, DiffQPResult, ImplicitGradient};
19use crate::error::{OptimizeError, OptimizeResult};
20
/// A differentiable QP layer.
///
/// Holds the problem data and supports forward solving and backward
/// (gradient) computation.
///
/// All matrices are stored row-major as nested `Vec`s; dimensions are
/// `n = c.len()` (variables), `m = h.len()` (inequalities), and
/// `p = b.len()` (equalities).
#[derive(Debug, Clone)]
pub struct DifferentiableQP {
    /// Quadratic cost matrix Q (n×n, symmetric positive semi-definite).
    pub q: Vec<Vec<f64>>,
    /// Linear cost vector c (n).
    pub c: Vec<f64>,
    /// Inequality constraint matrix G (m×n): Gx ≤ h.
    pub g: Vec<Vec<f64>>,
    /// Inequality constraint rhs h (m).
    pub h: Vec<f64>,
    /// Equality constraint matrix A (p×n): Ax = b.
    pub a: Vec<Vec<f64>>,
    /// Equality constraint rhs b (p).
    pub b: Vec<f64>,
}
40
41impl DifferentiableQP {
42    /// Create a new differentiable QP.
43    ///
44    /// # Arguments
45    /// * `q` – n×n cost matrix (must be symmetric PSD).
46    /// * `c` – n-dimensional linear cost.
47    /// * `g` – m×n inequality constraint matrix.
48    /// * `h` – m-dimensional inequality rhs.
49    /// * `a` – p×n equality constraint matrix.
50    /// * `b` – p-dimensional equality rhs.
51    pub fn new(
52        q: Vec<Vec<f64>>,
53        c: Vec<f64>,
54        g: Vec<Vec<f64>>,
55        h: Vec<f64>,
56        a: Vec<Vec<f64>>,
57        b: Vec<f64>,
58    ) -> OptimizeResult<Self> {
59        let n = c.len();
60        if q.len() != n {
61            return Err(OptimizeError::InvalidInput(format!(
62                "Q has {} rows but c has length {}",
63                q.len(),
64                n
65            )));
66        }
67        for (i, row) in q.iter().enumerate() {
68            if row.len() != n {
69                return Err(OptimizeError::InvalidInput(format!(
70                    "Q row {} has length {} but expected {}",
71                    i,
72                    row.len(),
73                    n
74                )));
75            }
76        }
77        for (i, row) in g.iter().enumerate() {
78            if row.len() != n {
79                return Err(OptimizeError::InvalidInput(format!(
80                    "G row {} has length {} but expected {}",
81                    i,
82                    row.len(),
83                    n
84                )));
85            }
86        }
87        if g.len() != h.len() {
88            return Err(OptimizeError::InvalidInput(format!(
89                "G has {} rows but h has length {}",
90                g.len(),
91                h.len()
92            )));
93        }
94        for (i, row) in a.iter().enumerate() {
95            if row.len() != n {
96                return Err(OptimizeError::InvalidInput(format!(
97                    "A row {} has length {} but expected {}",
98                    i,
99                    row.len(),
100                    n
101                )));
102            }
103        }
104        if a.len() != b.len() {
105            return Err(OptimizeError::InvalidInput(format!(
106                "A has {} rows but b has length {}",
107                a.len(),
108                b.len()
109            )));
110        }
111
112        Ok(Self { q, c, g, h, a, b })
113    }
114
115    /// Number of primal variables.
116    pub fn n(&self) -> usize {
117        self.c.len()
118    }
119
120    /// Number of inequality constraints.
121    pub fn m(&self) -> usize {
122        self.h.len()
123    }
124
125    /// Number of equality constraints.
126    pub fn p(&self) -> usize {
127        self.b.len()
128    }
129
130    /// Solve the QP (forward pass).
131    ///
132    /// Uses a primal-dual interior-point method with Mehrotra predictor-
133    /// corrector steps.
134    pub fn forward(&self, config: &DiffQPConfig) -> OptimizeResult<DiffQPResult> {
135        let n = self.n();
136        let m = self.m();
137        let p = self.p();
138
139        // ── Build regularised Q ────────────────────────────────────────
140        let mut q_reg = self.q.clone();
141        for i in 0..n {
142            q_reg[i][i] += config.regularization;
143        }
144
145        // ── Initialisation ─────────────────────────────────────────────
146        let mut x = vec![0.0; n];
147        let mut lam = vec![1.0; m]; // inequality duals > 0
148        let mut nu = vec![0.0; p]; // equality duals
149        let mut s = vec![1.0; m]; // slacks s = h - Gx > 0
150
151        // Compute initial slacks
152        for i in 0..m {
153            let mut gx_i = 0.0;
154            for j in 0..n {
155                gx_i += self.g[i][j] * x[j];
156            }
157            s[i] = self.h[i] - gx_i;
158            if s[i] <= 0.0 {
159                s[i] = 1.0; // ensure positivity
160            }
161        }
162
163        let mut converged = false;
164        let mut iterations = 0;
165
166        for iter in 0..config.max_iterations {
167            iterations = iter + 1;
168
169            // ── Compute residuals ──────────────────────────────────────
170            // r_stat = Qx + c + G'λ + A'ν  (stationarity)
171            let mut r_stat = vec![0.0; n];
172            for i in 0..n {
173                let mut qx_i = 0.0;
174                for j in 0..n {
175                    qx_i += q_reg[i][j] * x[j];
176                }
177                r_stat[i] = qx_i + self.c[i];
178            }
179            for k in 0..m {
180                for i in 0..n {
181                    r_stat[i] += self.g[k][i] * lam[k];
182                }
183            }
184            for k in 0..p {
185                for i in 0..n {
186                    r_stat[i] += self.a[k][i] * nu[k];
187                }
188            }
189
190            // r_eq = Ax - b  (primal equality)
191            let mut r_eq = vec![0.0; p];
192            for i in 0..p {
193                for j in 0..n {
194                    r_eq[i] += self.a[i][j] * x[j];
195                }
196                r_eq[i] -= self.b[i];
197            }
198
199            // r_ineq = s + Gx - h  (slack definition)
200            let mut r_ineq = vec![0.0; m];
201            for i in 0..m {
202                let mut gx_i = 0.0;
203                for j in 0..n {
204                    gx_i += self.g[i][j] * x[j];
205                }
206                r_ineq[i] = s[i] + gx_i - self.h[i];
207            }
208
209            // r_comp = diag(λ) s  (complementarity, want → 0)
210            let mu: f64 = if m > 0 {
211                lam.iter()
212                    .zip(s.iter())
213                    .map(|(&li, &si)| li * si)
214                    .sum::<f64>()
215                    / m as f64
216            } else {
217                0.0
218            };
219
220            // Check convergence
221            let res_stat: f64 = r_stat.iter().map(|v| v.abs()).fold(0.0, f64::max);
222            let res_eq: f64 = r_eq.iter().map(|v| v.abs()).fold(0.0, f64::max);
223            let res_ineq: f64 = r_ineq.iter().map(|v| v.abs()).fold(0.0, f64::max);
224            let max_res = res_stat.max(res_eq).max(res_ineq).max(mu);
225
226            if max_res < config.tolerance {
227                converged = true;
228                break;
229            }
230
231            // ── Build and solve the KKT system for Newton direction ────
232            // We solve the reduced system by eliminating s.
233            // Variables: (dx, dlam, dnu)
234            let dim = n + m + p;
235            let mut kkt = vec![vec![0.0; dim]; dim];
236            let mut rhs = vec![0.0; dim];
237
238            // Block row 0 (stationarity): Q dx + G' dlam + A' dnu = -r_stat
239            for i in 0..n {
240                for j in 0..n {
241                    kkt[i][j] = q_reg[i][j];
242                }
243                for k in 0..m {
244                    kkt[i][n + k] = self.g[k][i];
245                }
246                for k in 0..p {
247                    kkt[i][n + m + k] = self.a[k][i];
248                }
249                rhs[i] = -r_stat[i];
250            }
251
252            // Block row 1 (complementarity + slack elimination):
253            // diag(s) dlam + diag(λ) ds = -diag(λ)s + σμe
254            // ds = -r_ineq - G dx   (from slack row)
255            // → diag(s) dlam + diag(λ)(-r_ineq - G dx) = -diag(λ)s + σμe
256            // → -diag(λ)G dx + diag(s) dlam = -diag(λ)s + σμe + diag(λ) r_ineq
257            let sigma = 0.1_f64; // centering parameter
258            for i in 0..m {
259                let li = lam[i];
260                let si = s[i];
261                for j in 0..n {
262                    kkt[n + i][j] = -li * self.g[i][j];
263                }
264                kkt[n + i][n + i] = si;
265                rhs[n + i] = -li * si + sigma * mu + li * r_ineq[i];
266            }
267
268            // Block row 2 (equality): A dx = -r_eq
269            for i in 0..p {
270                for j in 0..n {
271                    kkt[n + m + i][j] = self.a[i][j];
272                }
273                rhs[n + m + i] = -r_eq[i];
274            }
275
276            let dir = match implicit_diff::solve_implicit_system(&kkt, &rhs) {
277                Ok(d) => d,
278                Err(_) => break, // singular system, stop
279            };
280
281            let dx = &dir[..n];
282            let dlam = &dir[n..n + m];
283            let dnu = &dir[n + m..];
284
285            // Recover ds
286            let mut ds = vec![0.0; m];
287            for i in 0..m {
288                let mut gx_i = 0.0;
289                for j in 0..n {
290                    gx_i += self.g[i][j] * dx[j];
291                }
292                ds[i] = -r_ineq[i] - gx_i;
293            }
294
295            // ── Step size (fraction-to-boundary) ───────────────────────
296            let tau = 0.995;
297            let mut alpha_p = 1.0_f64;
298            let mut alpha_d = 1.0_f64;
299
300            for i in 0..m {
301                if ds[i] < 0.0 {
302                    let ratio = -tau * s[i] / ds[i];
303                    if ratio < alpha_p {
304                        alpha_p = ratio;
305                    }
306                }
307                if dlam[i] < 0.0 {
308                    let ratio = -tau * lam[i] / dlam[i];
309                    if ratio < alpha_d {
310                        alpha_d = ratio;
311                    }
312                }
313            }
314
315            alpha_p = alpha_p.min(1.0).max(1e-12);
316            alpha_d = alpha_d.min(1.0).max(1e-12);
317
318            // ── Update ─────────────────────────────────────────────────
319            for i in 0..n {
320                x[i] += alpha_p * dx[i];
321            }
322            for i in 0..m {
323                s[i] += alpha_p * ds[i];
324                lam[i] += alpha_d * dlam[i];
325                // Safety: keep positive
326                if s[i] < 1e-14 {
327                    s[i] = 1e-14;
328                }
329                if lam[i] < 1e-14 {
330                    lam[i] = 1e-14;
331                }
332            }
333            for i in 0..p {
334                nu[i] += alpha_d * dnu[i];
335            }
336        }
337
338        // ── Compute objective ──────────────────────────────────────────
339        let mut obj = 0.0;
340        for i in 0..n {
341            obj += self.c[i] * x[i];
342            for j in 0..n {
343                obj += 0.5 * self.q[i][j] * x[i] * x[j];
344            }
345        }
346
347        Ok(DiffQPResult {
348            optimal_x: x,
349            optimal_lambda: lam,
350            optimal_nu: nu,
351            objective: obj,
352            converged,
353            iterations,
354        })
355    }
356
357    /// Backward pass: compute gradients of loss w.r.t. QP parameters.
358    ///
359    /// Given the upstream gradient dl/dx*, returns the implicit gradients
360    /// dl/d{Q, c, G, h, A, b}.
361    pub fn backward(
362        &self,
363        result: &DiffQPResult,
364        dl_dx: &[f64],
365        config: &DiffQPConfig,
366    ) -> OptimizeResult<ImplicitGradient> {
367        let n = self.n();
368        if dl_dx.len() != n {
369            return Err(OptimizeError::InvalidInput(format!(
370                "dl_dx length {} != n {}",
371                dl_dx.len(),
372                n
373            )));
374        }
375
376        // Add regularization to Q for the backward pass as well
377        let mut q_reg = self.q.clone();
378        for i in 0..n {
379            q_reg[i][i] += config.regularization;
380        }
381
382        match config.backward_mode {
383            BackwardMode::FullDifferentiation => implicit_diff::compute_full_implicit_gradient(
384                &q_reg,
385                &self.g,
386                &self.h,
387                &self.a,
388                &result.optimal_x,
389                &result.optimal_lambda,
390                &result.optimal_nu,
391                dl_dx,
392            ),
393            BackwardMode::ActiveSetOnly => {
394                implicit_diff::compute_active_set_implicit_gradient(
395                    &q_reg,
396                    &self.g,
397                    &self.h,
398                    &self.a,
399                    &result.optimal_x,
400                    &result.optimal_lambda,
401                    &result.optimal_nu,
402                    dl_dx,
403                    config.tolerance * 100.0, // slightly relaxed for active set
404                )
405            }
406            _ => Err(OptimizeError::NotImplementedError(
407                "Unknown backward mode".to_string(),
408            )),
409        }
410    }
411
412    /// Solve multiple QPs with the same structure but different parameters.
413    ///
414    /// This is a convenience method; each QP is solved independently.
415    pub fn batched_forward(
416        params_list: &[DifferentiableQP],
417        config: &DiffQPConfig,
418    ) -> OptimizeResult<Vec<DiffQPResult>> {
419        params_list.iter().map(|qp| qp.forward(config)).collect()
420    }
421}
422
#[cfg(test)]
mod tests {
    use super::*;

    /// Simple 2-variable unconstrained QP:
    ///   min x^2 + y^2 + x + 2y
    ///   → optimal at x = -0.5, y = -1.0
    #[test]
    fn test_qp_forward_unconstrained() {
        let qp = DifferentiableQP::new(
            vec![vec![2.0, 0.0], vec![0.0, 2.0]],
            vec![1.0, 2.0],
            vec![],
            vec![],
            vec![],
            vec![],
        )
        .expect("QP creation failed");

        let config = DiffQPConfig::default();
        let result = qp.forward(&config).expect("Forward solve failed");

        assert!(result.converged, "QP should converge");
        assert!(
            (result.optimal_x[0] - (-0.5)).abs() < 1e-4,
            "x[0] = {} (expected -0.5)",
            result.optimal_x[0]
        );
        assert!(
            (result.optimal_x[1] - (-1.0)).abs() < 1e-4,
            "x[1] = {} (expected -1.0)",
            result.optimal_x[1]
        );
    }

    /// 2-variable QP with one inequality constraint:
    ///   min x^2 + y^2
    ///   s.t. x + y >= 1   →  -x - y <= -1
    ///   optimal: x = 0.5, y = 0.5
    #[test]
    fn test_qp_forward_with_inequality() {
        let qp = DifferentiableQP::new(
            vec![vec![2.0, 0.0], vec![0.0, 2.0]],
            vec![0.0, 0.0],
            vec![vec![-1.0, -1.0]], // -x - y <= -1
            vec![-1.0],
            vec![],
            vec![],
        )
        .expect("QP creation failed");

        let config = DiffQPConfig::default();
        let result = qp.forward(&config).expect("Forward solve failed");

        assert!(result.converged);
        assert!(
            (result.optimal_x[0] - 0.5).abs() < 1e-3,
            "x[0] = {} (expected 0.5)",
            result.optimal_x[0]
        );
        assert!(
            (result.optimal_x[1] - 0.5).abs() < 1e-3,
            "x[1] = {} (expected 0.5)",
            result.optimal_x[1]
        );
    }

    /// For an unconstrained QP:  min ½ x'Qx + c'x
    /// x* = -Q⁻¹ c, and dl/dc = dx*/dc · dl/dx = -Q⁻¹ · dl/dx.
    /// When dl/dx = I (unit upstream), dl/dc = -Q⁻¹.
    /// For Q = 2I, dl/dc_i with dl/dx = e_i should give -0.5 * e_i.
    #[test]
    fn test_backward_gradient_dl_dc() {
        let qp = DifferentiableQP::new(
            vec![vec![2.0, 0.0], vec![0.0, 2.0]],
            vec![1.0, 2.0],
            vec![],
            vec![],
            vec![],
            vec![],
        )
        .expect("QP creation failed");

        let config = DiffQPConfig::default();
        let result = qp.forward(&config).expect("Forward solve failed");

        // dl/dx = [1, 0] (gradient of loss w.r.t. x)
        let dl_dx = vec![1.0, 0.0];
        let grad = qp
            .backward(&result, &dl_dx, &config)
            .expect("Backward failed");

        // For unconstrained: dl/dc = -Q^{-1} dl/dx = -0.5 * [1, 0]
        // But the implicit differentiation through KKT gives dl/dc = dx
        // where dx solves Q dx = -dl/dx, so dx = -Q^{-1} dl/dx = [-0.5, 0]
        assert!(
            (grad.dl_dc[0] - (-0.5)).abs() < 1e-3,
            "dl/dc[0] = {} (expected -0.5)",
            grad.dl_dc[0]
        );
        assert!(
            grad.dl_dc[1].abs() < 1e-3,
            "dl/dc[1] = {} (expected 0)",
            grad.dl_dc[1]
        );
    }

    /// Finite-difference check for dl/dc.
    #[test]
    fn test_backward_finite_difference_c() {
        let eps = 1e-5;
        let config = DiffQPConfig::default();

        let q = vec![vec![4.0, 1.0], vec![1.0, 3.0]];
        let c_base = vec![1.0, -1.0];
        let g = vec![vec![-1.0, 0.0], vec![0.0, -1.0]]; // x >= 0
        let h = vec![0.0, 0.0];

        let qp0 = DifferentiableQP::new(
            q.clone(),
            c_base.clone(),
            g.clone(),
            h.clone(),
            vec![],
            vec![],
        )
        .expect("QP creation failed");
        let res0 = qp0.forward(&config).expect("Forward failed");

        // dl/dx = x* (so loss = 0.5 * ||x*||^2)
        let dl_dx = res0.optimal_x.clone();
        let grad = qp0
            .backward(&res0, &dl_dx, &config)
            .expect("Backward failed");

        // Finite difference for c[0]
        let mut c_plus = c_base.clone();
        c_plus[0] += eps;
        let qp_plus =
            DifferentiableQP::new(q.clone(), c_plus, g.clone(), h.clone(), vec![], vec![])
                .expect("QP+ creation failed");
        let res_plus = qp_plus.forward(&config).expect("Forward+ failed");

        let mut c_minus = c_base.clone();
        c_minus[0] -= eps;
        let qp_minus =
            DifferentiableQP::new(q.clone(), c_minus, g.clone(), h.clone(), vec![], vec![])
                .expect("QP- creation failed");
        let res_minus = qp_minus.forward(&config).expect("Forward- failed");

        // loss = 0.5 * ||x*||^2
        let loss_plus: f64 = res_plus.optimal_x.iter().map(|v| 0.5 * v * v).sum();
        let loss_minus: f64 = res_minus.optimal_x.iter().map(|v| 0.5 * v * v).sum();
        let fd_grad = (loss_plus - loss_minus) / (2.0 * eps);

        assert!(
            (grad.dl_dc[0] - fd_grad).abs() < 1e-3,
            "dl/dc[0] analytical={} vs fd={}",
            grad.dl_dc[0],
            fd_grad
        );
    }

    /// Finite-difference check for dl/dh (inequality rhs).
    #[test]
    fn test_backward_finite_difference_h() {
        let eps = 1e-5;
        let config = DiffQPConfig::default();

        let q = vec![vec![2.0, 0.0], vec![0.0, 2.0]];
        let c = vec![0.0, 0.0];
        let g = vec![vec![-1.0, -1.0]]; // -x-y <= h[0]
        let h_base = vec![-1.0]; // x+y >= 1

        let qp0 = DifferentiableQP::new(
            q.clone(),
            c.clone(),
            g.clone(),
            h_base.clone(),
            vec![],
            vec![],
        )
        .expect("QP creation failed");
        let res0 = qp0.forward(&config).expect("Forward failed");

        let dl_dx = res0.optimal_x.clone();
        let grad = qp0
            .backward(&res0, &dl_dx, &config)
            .expect("Backward failed");

        // Perturb h[0]
        let mut h_plus = h_base.clone();
        h_plus[0] += eps;
        let qp_plus =
            DifferentiableQP::new(q.clone(), c.clone(), g.clone(), h_plus, vec![], vec![])
                .expect("QP+ creation failed");
        let res_plus = qp_plus.forward(&config).expect("Forward+ failed");

        let mut h_minus = h_base.clone();
        h_minus[0] -= eps;
        let qp_minus =
            DifferentiableQP::new(q.clone(), c.clone(), g.clone(), h_minus, vec![], vec![])
                .expect("QP- creation failed");
        let res_minus = qp_minus.forward(&config).expect("Forward- failed");

        let loss_plus: f64 = res_plus.optimal_x.iter().map(|v| 0.5 * v * v).sum();
        let loss_minus: f64 = res_minus.optimal_x.iter().map(|v| 0.5 * v * v).sum();
        let fd_grad = (loss_plus - loss_minus) / (2.0 * eps);

        // Allow somewhat loose tolerance since IP method + implicit diff can have some error
        assert!(
            (grad.dl_dh[0] - fd_grad).abs() < 0.1,
            "dl/dh[0] analytical={} vs fd={}",
            grad.dl_dh[0],
            fd_grad
        );
    }

    #[test]
    fn test_qp_with_equality_constraint() {
        // min x^2 + y^2 s.t. x + y = 1
        // optimal: x = 0.5, y = 0.5
        let qp = DifferentiableQP::new(
            vec![vec![2.0, 0.0], vec![0.0, 2.0]],
            vec![0.0, 0.0],
            vec![],
            vec![],
            vec![vec![1.0, 1.0]],
            vec![1.0],
        )
        .expect("QP creation failed");

        let config = DiffQPConfig::default();
        let result = qp.forward(&config).expect("Forward failed");

        assert!(result.converged);
        assert!(
            (result.optimal_x[0] - 0.5).abs() < 1e-3,
            "x[0] = {}",
            result.optimal_x[0]
        );
        assert!(
            (result.optimal_x[1] - 0.5).abs() < 1e-3,
            "x[1] = {}",
            result.optimal_x[1]
        );
    }

    #[test]
    fn test_batched_forward_consistency() {
        let qp1 = DifferentiableQP::new(
            vec![vec![2.0, 0.0], vec![0.0, 2.0]],
            vec![1.0, 0.0],
            vec![],
            vec![],
            vec![],
            vec![],
        )
        .expect("QP1 creation failed");
        let qp2 = DifferentiableQP::new(
            vec![vec![2.0, 0.0], vec![0.0, 2.0]],
            vec![0.0, 1.0],
            vec![],
            vec![],
            vec![],
            vec![],
        )
        .expect("QP2 creation failed");

        let config = DiffQPConfig::default();
        let batch_results = DifferentiableQP::batched_forward(&[qp1.clone(), qp2.clone()], &config)
            .expect("Batch failed");

        let r1 = qp1.forward(&config).expect("Single 1 failed");
        let r2 = qp2.forward(&config).expect("Single 2 failed");

        for i in 0..2 {
            assert!(
                (batch_results[0].optimal_x[i] - r1.optimal_x[i]).abs() < 1e-10,
                "Batch[0].x[{}] differs",
                i
            );
            assert!(
                (batch_results[1].optimal_x[i] - r2.optimal_x[i]).abs() < 1e-10,
                "Batch[1].x[{}] differs",
                i
            );
        }
    }

    #[test]
    fn test_qp_empty_constraints() {
        let qp = DifferentiableQP::new(vec![vec![2.0]], vec![4.0], vec![], vec![], vec![], vec![])
            .expect("QP creation failed");

        let config = DiffQPConfig::default();
        let result = qp.forward(&config).expect("Forward failed");
        assert!(result.converged);
        // min x^2 + 4x → x* = -2
        assert!(
            (result.optimal_x[0] - (-2.0)).abs() < 1e-3,
            "x = {}",
            result.optimal_x[0]
        );
    }

    #[test]
    fn test_qp_dimension_validation() {
        // Q is 2x2 but c is length 3 → error
        let result = DifferentiableQP::new(
            vec![vec![1.0, 0.0], vec![0.0, 1.0]],
            vec![1.0, 2.0, 3.0],
            vec![],
            vec![],
            vec![],
            vec![],
        );
        assert!(result.is_err());
    }

    #[test]
    fn test_qp_degenerate_active_constraints() {
        // Two active constraints at the same point
        // min x^2 + y^2 s.t. x >= 1, y >= 1, x+y >= 2
        // At optimal (1,1) all three constraints are active
        let qp = DifferentiableQP::new(
            vec![vec![2.0, 0.0], vec![0.0, 2.0]],
            vec![0.0, 0.0],
            vec![
                vec![-1.0, 0.0],  // -x <= -1
                vec![0.0, -1.0],  // -y <= -1
                vec![-1.0, -1.0], // -x-y <= -2
            ],
            vec![-1.0, -1.0, -2.0],
            vec![],
            vec![],
        )
        .expect("QP creation failed");

        let config = DiffQPConfig::default();
        let result = qp.forward(&config).expect("Forward failed");

        assert!(result.converged);
        assert!(
            (result.optimal_x[0] - 1.0).abs() < 1e-2,
            "x[0] = {} (expected 1.0)",
            result.optimal_x[0]
        );
        assert!(
            (result.optimal_x[1] - 1.0).abs() < 1e-2,
            "x[1] = {} (expected 1.0)",
            result.optimal_x[1]
        );
    }
}