// scirs2_optimize/differentiable_optimization/qp_layer.rs
//! ADMM-based differentiable QP layer with warm-start and active-set backward.
//!
//! Solves the QP:
//!
//!   min  ½ xᵀQx + cᵀx
//!   s.t. A_eq x = b_eq    (equality)
//!        G_ineq x ≤ h_ineq (inequality)
//!
//! The forward pass uses an OSQP-style ADMM iteration:
//!
//!   x-update: (Q + ρ Cᵀ C)⁻¹ (ρ Cᵀ (z - u) - c)   where C = [A_eq; G_ineq]
//!   z-update: projection onto {Ax=b} × {Gx ≤ h}
//!   u-update: u += C x - z
//!
//! The backward pass uses KKT sensitivity on the active constraints.

use crate::error::{OptimizeError, OptimizeResult};

use super::implicit_diff::identify_active_constraints;
use super::kkt_sensitivity::{kkt_sensitivity, regularize_q};
use super::types::{DiffOptGrad, DiffOptParams, DiffOptResult, DiffOptStatus};
// ─────────────────────────────────────────────────────────────────────────────
// Configuration
// ─────────────────────────────────────────────────────────────────────────────

/// Configuration for the ADMM-based QP layer.
///
/// Controls both the forward ADMM solve (`max_iter`, `tol`, `rho`,
/// `regularization`) and the backward active-set identification
/// (`active_tol`).
#[derive(Debug, Clone)]
pub struct QpLayerConfig {
    /// Maximum number of ADMM iterations.
    pub max_iter: usize,
    /// Primal and dual residual tolerance for convergence (both must pass).
    pub tol: f64,
    /// ADMM penalty parameter ρ (also scales the recovered dual variables).
    pub rho: f64,
    /// Tikhonov regularization on Q for numerical stability of the x-update.
    pub regularization: f64,
    /// Tolerance for identifying active inequality constraints in backward pass.
    pub active_tol: f64,
    /// Whether to print convergence information.
    pub verbose: bool,
}

43impl Default for QpLayerConfig {
44    fn default() -> Self {
45        Self {
46            max_iter: 100,
47            tol: 1e-8,
48            rho: 1.0,
49            regularization: 1e-7,
50            active_tol: 1e-6,
51            verbose: false,
52        }
53    }
54}
55
// ─────────────────────────────────────────────────────────────────────────────
// Cholesky factorization (simplified LDLᵀ for symmetric PD matrices)
// ─────────────────────────────────────────────────────────────────────────────

60/// Cholesky decomposition: returns lower triangular L such that A = L Lᵀ.
61/// Uses the standard Cholesky-Banachiewicz algorithm.
62fn cholesky(a: &[Vec<f64>]) -> OptimizeResult<Vec<Vec<f64>>> {
63    let n = a.len();
64    let mut l = vec![vec![0.0_f64; n]; n];
65
66    for i in 0..n {
67        for j in 0..=i {
68            let mut sum = 0.0_f64;
69            for k in 0..j {
70                sum += l[i][k] * l[j][k];
71            }
72            if i == j {
73                let diag = a[i][i] - sum;
74                if diag <= 0.0 {
75                    return Err(OptimizeError::ComputationError(format!(
76                        "Cholesky failed: non-positive diagonal at index {}. diag = {diag}",
77                        i
78                    )));
79                }
80                l[i][j] = diag.sqrt();
81            } else {
82                let l_jj = l[j][j];
83                if l_jj.abs() < 1e-30 {
84                    return Err(OptimizeError::ComputationError(
85                        "Cholesky failed: zero diagonal element".to_string(),
86                    ));
87                }
88                l[i][j] = (a[i][j] - sum) / l_jj;
89            }
90        }
91    }
92    Ok(l)
93}
94
/// Forward substitution: solve L y = b where L is lower triangular.
fn forward_sub(l: &[Vec<f64>], b: &[f64]) -> Vec<f64> {
    let n = b.len();
    let mut y = vec![0.0_f64; n];
    for i in 0..n {
        // Subtract the contribution of the already-solved components.
        let dot: f64 = (0..i).map(|j| l[i][j] * y[j]).sum();
        let rhs = b[i] - dot;
        let diag = l[i][i];
        // A (near-)zero diagonal would divide by zero; treat that component
        // as zero rather than producing inf/NaN.
        y[i] = if diag.abs() < 1e-30 { 0.0 } else { rhs / diag };
    }
    y
}

/// Backward substitution: solve Lᵀ x = y where L is lower triangular.
fn backward_sub(l: &[Vec<f64>], y: &[f64]) -> Vec<f64> {
    let n = y.len();
    let mut x = vec![0.0_f64; n];
    for i in (0..n).rev() {
        // Lᵀ[i][j] == l[j][i], so accumulate over the rows below i.
        let dot: f64 = ((i + 1)..n).map(|j| l[j][i] * x[j]).sum();
        let rhs = y[i] - dot;
        let diag = l[i][i];
        // Mirror forward_sub's convention for a (near-)singular diagonal.
        x[i] = if diag.abs() < 1e-30 { 0.0 } else { rhs / diag };
    }
    x
}

125/// Solve the symmetric positive definite system Ax = b via Cholesky factorization.
126/// Falls back to Gaussian elimination if Cholesky fails.
127fn cholesky_solve(a: &[Vec<f64>], b: &[f64]) -> OptimizeResult<Vec<f64>> {
128    match cholesky(a) {
129        Ok(l) => {
130            let y = forward_sub(&l, b);
131            Ok(backward_sub(&l, &y))
132        }
133        Err(_) => {
134            // Fall back to implicit_diff solver
135            super::implicit_diff::solve_implicit_system(a, b)
136        }
137    }
138}
139
// ─────────────────────────────────────────────────────────────────────────────
// QP layer
// ─────────────────────────────────────────────────────────────────────────────

/// An ADMM-based differentiable QP layer.
///
/// Stores problem data and the last forward-pass solution for use in
/// the backward pass. Warm-start vectors from the previous solve are
/// reused on the next `forward` call whenever the dimensions still match.
#[derive(Debug, Clone)]
pub struct QpLayer {
    config: QpLayerConfig,
    /// Cached warm-start primal iterate x.
    warm_x: Option<Vec<f64>>,
    /// Cached warm-start auxiliary iterate z (projected constraint values).
    warm_z: Option<Vec<f64>>,
    /// Cached warm-start scaled dual iterate u.
    warm_u: Option<Vec<f64>>,
    /// Last forward result (needed for backward).
    last_result: Option<QpForwardCache>,
}

/// Cached data from the forward pass needed for gradient computation.
///
/// Holds the solution, the recovered dual variables, and the problem data
/// exactly as passed to `forward` (Q unregularized).
#[derive(Debug, Clone)]
struct QpForwardCache {
    x: Vec<f64>,
    lambda: Vec<f64>, // inequality duals (λ ≥ 0, one per row of G)
    nu: Vec<f64>,     // equality duals (one per row of A_eq)
    q: Vec<Vec<f64>>,
    c: Vec<f64>,
    a_eq: Vec<Vec<f64>>,
    b_eq: Vec<f64>,
    g_ineq: Vec<Vec<f64>>,
    h_ineq: Vec<f64>,
}

impl QpLayer {
176    /// Create a new QP layer with default configuration.
177    pub fn new() -> Self {
178        Self {
179            config: QpLayerConfig::default(),
180            warm_x: None,
181            warm_z: None,
182            warm_u: None,
183            last_result: None,
184        }
185    }
186
187    /// Create a new QP layer with custom configuration.
188    pub fn with_config(config: QpLayerConfig) -> Self {
189        Self {
190            config,
191            warm_x: None,
192            warm_z: None,
193            warm_u: None,
194            last_result: None,
195        }
196    }
197
198    /// Solve the QP (forward pass).
199    ///
200    /// Uses ADMM with warm-start. The constraint matrix C = [A_eq; G_ineq] is
201    /// stacked, and z is projected onto the feasible set:
202    ///
203    ///   z_eq   = b_eq                     (equality: exact satisfaction)
204    ///   z_ineq = min(z_ineq_raw, h_ineq)  (inequality: clamp to ≤ h)
205    ///
206    /// # Arguments
207    /// * `q`      – n×n cost matrix (symmetric PSD).
208    /// * `c`      – n linear cost vector.
209    /// * `a_eq`   – p×n equality constraint matrix.
210    /// * `b_eq`   – p equality rhs.
211    /// * `g_ineq` – m×n inequality constraint matrix.
212    /// * `h_ineq` – m inequality rhs.
213    pub fn forward(
214        &mut self,
215        q: Vec<Vec<f64>>,
216        c: Vec<f64>,
217        a_eq: Vec<Vec<f64>>,
218        b_eq: Vec<f64>,
219        g_ineq: Vec<Vec<f64>>,
220        h_ineq: Vec<f64>,
221    ) -> OptimizeResult<DiffOptResult> {
222        let n = c.len();
223        let p = b_eq.len();
224        let m = h_ineq.len();
225        let nc = p + m; // total constraints
226
227        // ── Validate dimensions ────────────────────────────────────────────
228        if q.len() != n {
229            return Err(OptimizeError::InvalidInput(format!(
230                "Q rows ({}) != n ({})",
231                q.len(),
232                n
233            )));
234        }
235        if a_eq.len() != p {
236            return Err(OptimizeError::InvalidInput(format!(
237                "A_eq rows ({}) != p ({})",
238                a_eq.len(),
239                p
240            )));
241        }
242        if g_ineq.len() != m {
243            return Err(OptimizeError::InvalidInput(format!(
244                "G_ineq rows ({}) != m ({})",
245                g_ineq.len(),
246                m
247            )));
248        }
249
250        // ── Regularize Q ───────────────────────────────────────────────────
251        let q_reg = regularize_q(&q, self.config.regularization);
252        let rho = self.config.rho;
253
254        // ── Build C = [A_eq; G_ineq] (nc × n) ────────────────────────────
255        let c_mat: Vec<Vec<f64>> = a_eq.iter().cloned().chain(g_ineq.iter().cloned()).collect();
256
257        // ── Build M = Q_reg + ρ CᵀC (n×n) ────────────────────────────────
258        let mut m_mat = q_reg.clone();
259        for row in &c_mat {
260            for i in 0..n {
261                for j in 0..n {
262                    let ci = if i < row.len() { row[i] } else { 0.0 };
263                    let cj = if j < row.len() { row[j] } else { 0.0 };
264                    m_mat[i][j] += rho * ci * cj;
265                }
266            }
267        }
268
269        // ── Initialise from warm-start or zero ────────────────────────────
270        let mut x = self
271            .warm_x
272            .as_ref()
273            .filter(|wx| wx.len() == n)
274            .cloned()
275            .unwrap_or_else(|| vec![0.0_f64; n]);
276
277        let mut z = self
278            .warm_z
279            .as_ref()
280            .filter(|wz| wz.len() == nc)
281            .cloned()
282            .unwrap_or_else(|| {
283                // z_eq = b_eq, z_ineq = h_ineq / 2
284                let mut z0 = Vec::with_capacity(nc);
285                z0.extend_from_slice(&b_eq);
286                z0.extend(h_ineq.iter().map(|&hi| hi / 2.0));
287                z0
288            });
289
290        let mut u = self
291            .warm_u
292            .as_ref()
293            .filter(|wu| wu.len() == nc)
294            .cloned()
295            .unwrap_or_else(|| vec![0.0_f64; nc]);
296
297        let mut converged = false;
298        let mut iterations = 0_usize;
299
300        for iter in 0..self.config.max_iter {
301            iterations = iter + 1;
302
303            // ── x-update: solve M x_new = ρ Cᵀ(z - u) - c ──────────────
304            let mut rhs_x = c.iter().map(|&ci| -ci).collect::<Vec<_>>();
305            for (k, row) in c_mat.iter().enumerate() {
306                let zu_k =
307                    if k < z.len() { z[k] } else { 0.0 } - if k < u.len() { u[k] } else { 0.0 };
308                for j in 0..n {
309                    let ckj = if j < row.len() { row[j] } else { 0.0 };
310                    rhs_x[j] += rho * ckj * zu_k;
311                }
312            }
313
314            let x_new = cholesky_solve(&m_mat, &rhs_x)?;
315
316            // ── z-update: project (C x_new + u) onto feasible set ────────
317            let mut cx = vec![0.0_f64; nc];
318            for (k, row) in c_mat.iter().enumerate() {
319                for j in 0..n {
320                    let ckj = if j < row.len() { row[j] } else { 0.0 };
321                    cx[k] += ckj * x_new[j];
322                }
323            }
324
325            let mut z_new = vec![0.0_f64; nc];
326            // Equality block: project onto Ax = b → z_k = b_k
327            for k in 0..p {
328                z_new[k] = if k < b_eq.len() { b_eq[k] } else { 0.0 };
329            }
330            // Inequality block: project onto Gx ≤ h → z_k = min(cx[p+k] + u[p+k], h_k)
331            for k in 0..m {
332                let raw = cx[p + k] + u[p + k];
333                let h_k = if k < h_ineq.len() { h_ineq[k] } else { 0.0 };
334                z_new[p + k] = raw.min(h_k);
335            }
336
337            // ── u-update: u += Cx - z ─────────────────────────────────────
338            let mut u_new = vec![0.0_f64; nc];
339            for k in 0..nc {
340                u_new[k] = u[k] + cx[k] - z_new[k];
341            }
342
343            // ── Compute residuals ─────────────────────────────────────────
344            let primal_res: f64 = cx
345                .iter()
346                .zip(z_new.iter())
347                .map(|(a, b)| (a - b).powi(2))
348                .sum::<f64>()
349                .sqrt();
350            let dual_res: f64 = {
351                // rho * Cᵀ (z - z_old)
352                let mut dr = 0.0_f64;
353                for k in 0..nc {
354                    let dz = z_new[k] - z[k];
355                    for j in 0..n {
356                        let ckj = if j < c_mat[k].len() { c_mat[k][j] } else { 0.0 };
357                        dr += (rho * ckj * dz).powi(2);
358                    }
359                }
360                dr.sqrt()
361            };
362
363            if self.config.verbose {
364                eprintln!(
365                    "iter {}: primal_res={:.2e}, dual_res={:.2e}",
366                    iter, primal_res, dual_res
367                );
368            }
369
370            x = x_new;
371            z = z_new;
372            u = u_new;
373
374            if primal_res < self.config.tol && dual_res < self.config.tol {
375                converged = true;
376                break;
377            }
378        }
379
380        // ── Extract dual variables ─────────────────────────────────────────
381        // In ADMM, the dual variable for the k-th constraint is ρ u[k].
382        let nu: Vec<f64> = u[..p].iter().map(|&ui| rho * ui).collect();
383        let lambda: Vec<f64> = u[p..].iter().map(|&ui| rho * ui.max(0.0)).collect();
384
385        // ── Compute objective ──────────────────────────────────────────────
386        let mut obj = 0.0_f64;
387        for i in 0..n {
388            obj += c[i] * x[i];
389            for j in 0..n {
390                let q_ij = if i < q.len() && j < q[i].len() {
391                    q[i][j]
392                } else {
393                    0.0
394                };
395                obj += 0.5 * q_ij * x[i] * x[j];
396            }
397        }
398
399        let status = if converged {
400            DiffOptStatus::Optimal
401        } else {
402            DiffOptStatus::MaxIterations
403        };
404
405        // ── Update warm-start cache ────────────────────────────────────────
406        self.warm_x = Some(x.clone());
407        self.warm_z = Some(z);
408        self.warm_u = Some(u);
409
410        // ── Cache for backward ────────────────────────────────────────────
411        self.last_result = Some(QpForwardCache {
412            x: x.clone(),
413            lambda: lambda.clone(),
414            nu: nu.clone(),
415            q: q.clone(),
416            c: c.clone(),
417            a_eq: a_eq.clone(),
418            b_eq: b_eq.clone(),
419            g_ineq: g_ineq.clone(),
420            h_ineq: h_ineq.clone(),
421        });
422
423        Ok(DiffOptResult {
424            x,
425            lambda,
426            nu,
427            objective: obj,
428            status,
429            iterations,
430        })
431    }
432
    /// Backward pass: compute parameter gradients via KKT sensitivity.
    ///
    /// Uses the active-set at the solution to identify binding inequality
    /// constraints, stacks them with equality constraints, and calls
    /// `kkt_sensitivity` on the resulting system. Inactive inequality rows
    /// receive zero gradients: perturbing a non-binding constraint does not
    /// move the optimum (strict complementarity is implicitly assumed —
    /// NOTE(review): weakly-active constraints near `active_tol` may make the
    /// gradient discontinuous).
    ///
    /// # Arguments
    /// * `dl_dx` – upstream gradient dL/dx (length n).
    ///
    /// # Errors
    /// Returns `OptimizeError::ComputationError` if no forward pass has been
    /// run, or if the KKT system is singular.
    pub fn backward(&self, dl_dx: &[f64]) -> OptimizeResult<DiffOptGrad> {
        let cache = self.last_result.as_ref().ok_or_else(|| {
            OptimizeError::ComputationError("QpLayer::backward called before forward".to_string())
        })?;

        let n = cache.x.len();
        if dl_dx.len() != n {
            return Err(OptimizeError::InvalidInput(format!(
                "dl_dx length {} != n {}",
                dl_dx.len(),
                n
            )));
        }

        // ── Identify active inequality constraints ─────────────────────────
        // A row g_i is treated as active when g_i·x ≈ h_i within active_tol.
        let active_idx = identify_active_constraints(
            &cache.g_ineq,
            &cache.h_ineq,
            &cache.x,
            self.config.active_tol,
        );

        // Stack equality constraints and active inequality rows into one
        // augmented equality system — at the optimum, active inequalities
        // behave exactly like equalities.
        let mut a_aug: Vec<Vec<f64>> = cache.a_eq.clone();
        let mut b_aug: Vec<f64> = cache.b_eq.clone();
        let mut nu_aug: Vec<f64> = cache.nu.clone();

        for &ai in &active_idx {
            if ai < cache.g_ineq.len() {
                a_aug.push(cache.g_ineq[ai].clone());
                b_aug.push(cache.h_ineq.get(ai).copied().unwrap_or(0.0));
                nu_aug.push(cache.lambda.get(ai).copied().unwrap_or(0.0));
            }
        }

        // ── Regularize Q ──────────────────────────────────────────────────
        // Same Tikhonov shift as the forward pass, so the KKT matrix used
        // for sensitivity matches the one that produced the solution.
        let q_reg = regularize_q(&cache.q, self.config.regularization);

        // ── Call KKT sensitivity on augmented equality system ─────────────
        let kkt_grad = kkt_sensitivity(&q_reg, &a_aug, &cache.x, &nu_aug, dl_dx)?;

        // Split dl_da back into dl_da_eq and dl_dg (active rows only):
        // the first p rows of the augmented system are the equalities, the
        // remaining rows are the active inequalities in `active_idx` order.
        let p = cache.a_eq.len();
        let m_full = cache.g_ineq.len();

        let dl_da_eq: Option<Vec<Vec<f64>>> = if p > 0 {
            Some(kkt_grad.dl_da[..p].to_vec())
        } else {
            None
        };

        let dl_db_eq = kkt_grad.dl_db[..p].to_vec();

        // Expand active gradients to full G dimension; inactive rows keep
        // their zero gradients.
        let mut dl_dg = vec![vec![0.0_f64; n]; m_full];
        let mut dl_dh = vec![0.0_f64; m_full];
        for (idx, &ai) in active_idx.iter().enumerate() {
            let aug_idx = p + idx;
            if ai < m_full && aug_idx < kkt_grad.dl_da.len() {
                dl_dg[ai] = kkt_grad.dl_da[aug_idx].clone();
                dl_dh[ai] = kkt_grad.dl_db.get(aug_idx).copied().unwrap_or(0.0);
            }
        }

        Ok(DiffOptGrad {
            dl_dq: Some(kkt_grad.dl_dq),
            dl_dc: kkt_grad.dl_dc,
            dl_da: dl_da_eq,
            dl_db: dl_db_eq,
            dl_dg: Some(dl_dg),
            dl_dh,
        })
    }

519    /// Access the cached solution from the last forward pass.
520    pub fn last_solution(&self) -> Option<&[f64]> {
521        self.last_result.as_ref().map(|r| r.x.as_slice())
522    }
523
524    /// Reset warm-start cache.
525    pub fn reset_warm_start(&mut self) {
526        self.warm_x = None;
527        self.warm_z = None;
528        self.warm_u = None;
529    }
}

532impl Default for QpLayer {
533    fn default() -> Self {
534        Self::new()
535    }
536}
537
// ─────────────────────────────────────────────────────────────────────────────
// Tests
// ─────────────────────────────────────────────────────────────────────────────

#[cfg(test)]
mod tests {
    use super::*;

    /// Build (Q, c) = (2I, 0) for min ½ xᵀ(2I)x; the unconstrained
    /// optimum is x* = -Q⁻¹c = -0.5 c.
    fn make_identity_qp(n: usize) -> (Vec<Vec<f64>>, Vec<f64>) {
        let q = (0..n)
            .map(|i| {
                let mut row = vec![0.0_f64; n];
                row[i] = 2.0; // 2I so x* = -Q^{-1}c = -0.5 c
                row
            })
            .collect();
        let c = vec![0.0_f64; n];
        (q, c)
    }

    #[test]
    fn test_qp_layer_config_default() {
        let cfg = QpLayerConfig::default();
        assert_eq!(cfg.max_iter, 100);
        assert!((cfg.tol - 1e-8).abs() < 1e-15);
        assert!(!cfg.verbose);
        assert!((cfg.rho - 1.0).abs() < 1e-15);
    }

    #[test]
    fn test_qp_layer_identity_q_zero_c() {
        // min ½||x||² s.t. x[0] + x[1] = 0
        // x* = [0, 0] with equality b=0
        let mut layer = QpLayer::new();
        let (q, c) = make_identity_qp(2);
        let a_eq = vec![vec![1.0, 1.0]];
        let b_eq = vec![0.0];

        let result = layer
            .forward(q, c, a_eq, b_eq, vec![], vec![])
            .expect("Forward failed");

        assert!(
            result.x[0].abs() < 1e-4,
            "x[0] = {} (expected 0)",
            result.x[0]
        );
        assert!(
            result.x[1].abs() < 1e-4,
            "x[1] = {} (expected 0)",
            result.x[1]
        );
    }

    #[test]
    fn test_qp_layer_forward_unconstrained() {
        // min x^2 + y^2 + x + 2y → x* = [-0.5, -1.0]
        let mut layer = QpLayer::new();
        let q = vec![vec![2.0, 0.0], vec![0.0, 2.0]];
        let c = vec![1.0, 2.0];

        let result = layer
            .forward(q, c, vec![], vec![], vec![], vec![])
            .expect("Forward failed");

        assert!(
            (result.x[0] - (-0.5)).abs() < 1e-3,
            "x[0] = {} (expected -0.5)",
            result.x[0]
        );
        assert!(
            (result.x[1] - (-1.0)).abs() < 1e-3,
            "x[1] = {} (expected -1.0)",
            result.x[1]
        );
    }

    #[test]
    fn test_qp_layer_forward_with_equality() {
        // min x^2 + y^2 s.t. x + y = 1 → x* = [0.5, 0.5]
        let mut layer = QpLayer::new();
        let q = vec![vec![2.0, 0.0], vec![0.0, 2.0]];
        let c = vec![0.0, 0.0];
        let a_eq = vec![vec![1.0, 1.0]];
        let b_eq = vec![1.0];

        let result = layer
            .forward(q, c, a_eq, b_eq, vec![], vec![])
            .expect("Forward failed");

        assert!(
            (result.x[0] - 0.5).abs() < 1e-3,
            "x[0] = {} (expected 0.5)",
            result.x[0]
        );
        assert!(
            (result.x[1] - 0.5).abs() < 1e-3,
            "x[1] = {} (expected 0.5)",
            result.x[1]
        );
    }

    #[test]
    fn test_qp_layer_forward_with_inequality() {
        // min x^2 + y^2 s.t. -x - y <= -1 (i.e. x + y >= 1) → x* = [0.5, 0.5]
        let mut layer = QpLayer::new();
        let q = vec![vec![2.0, 0.0], vec![0.0, 2.0]];
        let c = vec![0.0, 0.0];
        let g = vec![vec![-1.0, -1.0]];
        let h = vec![-1.0];

        let result = layer
            .forward(q, c, vec![], vec![], g, h)
            .expect("Forward failed");

        // x + y should be >= 1
        let sum = result.x[0] + result.x[1];
        assert!(sum >= 1.0 - 1e-3, "x + y = {} (should be >= 1)", sum);
    }

    #[test]
    fn test_qp_layer_backward_no_forward_error() {
        let layer = QpLayer::new();
        let result = layer.backward(&[1.0, 0.0]);
        assert!(result.is_err(), "Should error without forward pass");
    }

    #[test]
    fn test_qp_layer_backward_dl_dc_finite() {
        let mut layer = QpLayer::new();
        let q = vec![vec![2.0, 0.0], vec![0.0, 2.0]];
        let c = vec![1.0, 2.0];

        let result = layer
            .forward(q, c, vec![], vec![], vec![], vec![])
            .expect("Forward failed");
        let _ = result;

        let grad = layer.backward(&[1.0, 0.0]).expect("Backward failed");
        assert_eq!(grad.dl_dc.len(), 2);
        assert!(grad.dl_dc[0].is_finite(), "dl/dc[0] not finite");
        assert!(grad.dl_dc[1].is_finite(), "dl/dc[1] not finite");
    }

    #[test]
    fn test_qp_layer_backward_gradient_check() {
        // Verify dl/dc via finite differences
        // min x^2 + y^2 + c[0]*x + c[1]*y (unconstrained)
        // x* = [-c[0]/2, -c[1]/2]
        // Loss L = 0.5 * ||x*||^2 = 0.5*(c[0]^2/4 + c[1]^2/4)
        // dL/dc[0] = c[0]/4 = 0.25 for c=[1,0]

        let eps = 1e-5_f64;
        let c_base = vec![1.0_f64, 0.0];
        let q = vec![vec![2.0, 0.0], vec![0.0, 2.0]];

        // Helper: solve with a fresh layer (no warm-start contamination)
        // and return the downstream loss 0.5·||x*||².
        let solve_and_loss = |c_vec: Vec<f64>| -> f64 {
            let mut layer = QpLayer::new();
            let res = layer
                .forward(q.clone(), c_vec, vec![], vec![], vec![], vec![])
                .expect("Forward failed");
            res.x.iter().map(|&xi| 0.5 * xi * xi).sum::<f64>()
        };

        // Forward + backward for analytical gradient
        let mut layer = QpLayer::new();
        let res = layer
            .forward(q.clone(), c_base.clone(), vec![], vec![], vec![], vec![])
            .expect("Forward failed");
        let dl_dx = res.x.clone(); // dL/dx = x* for L = 0.5 ||x*||^2
        let grad = layer.backward(&dl_dx).expect("Backward failed");

        // Finite difference for dc[0] (central difference, O(eps²) error)
        let mut c_plus = c_base.clone();
        c_plus[0] += eps;
        let mut c_minus = c_base.clone();
        c_minus[0] -= eps;
        let fd_dc0 = (solve_and_loss(c_plus) - solve_and_loss(c_minus)) / (2.0 * eps);

        assert!(
            (grad.dl_dc[0] - fd_dc0).abs() < 1e-3,
            "dl/dc[0] analytical={} vs FD={}",
            grad.dl_dc[0],
            fd_dc0
        );
    }

    #[test]
    fn test_qp_layer_active_set_identification() {
        // min x^2 + y^2 s.t. x >= 0, y >= 0, x+y >= 0.5
        // At x* = [0.25, 0.25], x+y=0.5 is active, x>=0 and y>=0 are inactive
        let mut layer = QpLayer::new();
        let q = vec![vec![2.0, 0.0], vec![0.0, 2.0]];
        let c = vec![0.0, 0.0];
        let g = vec![
            vec![-1.0, 0.0],  // -x <= 0
            vec![0.0, -1.0],  // -y <= 0
            vec![-1.0, -1.0], // -x - y <= -0.5
        ];
        let h = vec![0.0, 0.0, -0.5];

        let result = layer
            .forward(q, c, vec![], vec![], g, h)
            .expect("Forward failed");

        // x + y should be >= 0.5
        let sum = result.x[0] + result.x[1];
        assert!(sum >= 0.5 - 1e-3, "x + y = {} (should be >= 0.5)", sum);
    }

    #[test]
    fn test_qp_layer_warm_start() {
        // Two consecutive solves with same problem — should warm start
        let mut layer = QpLayer::new();
        let q = vec![vec![2.0, 0.0], vec![0.0, 2.0]];
        let c = vec![1.0, 1.0];

        let res1 = layer
            .forward(q.clone(), c.clone(), vec![], vec![], vec![], vec![])
            .expect("Forward 1 failed");

        let res2 = layer
            .forward(q, c, vec![], vec![], vec![], vec![])
            .expect("Forward 2 failed");

        // Both should give same result
        assert!(
            (res1.x[0] - res2.x[0]).abs() < 1e-6,
            "Warm-start inconsistency"
        );
    }

    #[test]
    fn test_qp_layer_last_solution() {
        let mut layer = QpLayer::new();
        let q = vec![vec![2.0, 0.0], vec![0.0, 2.0]];
        let c = vec![1.0, 0.0];

        assert!(layer.last_solution().is_none());
        layer
            .forward(q, c, vec![], vec![], vec![], vec![])
            .expect("Forward failed");
        assert!(layer.last_solution().is_some());
    }

    #[test]
    fn test_cholesky_solve_identity() {
        let a = vec![vec![4.0, 0.0], vec![0.0, 9.0]];
        let b = vec![8.0, 18.0];
        let x = cholesky_solve(&a, &b).expect("Cholesky solve failed");
        assert!((x[0] - 2.0).abs() < 1e-10, "x[0] = {}", x[0]);
        assert!((x[1] - 2.0).abs() < 1e-10, "x[1] = {}", x[1]);
    }
}
791}