Skip to main content

scirs2_optimize/differentiable_optimization/
implicit_diff.rs

1//! Core implicit differentiation engine for optimization layers.
2//!
3//! Given an optimal solution z* satisfying KKT conditions F(z*, θ) = 0,
4//! the implicit function theorem yields:
5//!
6//!   dz*/dθ = -(∂F/∂z)⁻¹ · (∂F/∂θ)
7//!
8//! This module builds the KKT Jacobian ∂F/∂z, solves the resulting linear
9//! system, and supports an active-set variant that restricts differentiation
10//! to active inequality constraints.
11
12use crate::error::{OptimizeError, OptimizeResult};
13
14// ─────────────────────────────────────────────────────────────────────────────
15// KKT Jacobian construction
16// ─────────────────────────────────────────────────────────────────────────────
17
/// Build the KKT Jacobian matrix ∂F/∂z for the QP:
///
///   min  ½ x'Qx + c'x
///   s.t. Gx ≤ h   (m inequalities)
///        Ax = b   (p equalities)
///
/// The KKT conditions (with slacks absorbed into complementarity) are:
///
///   F₁ = Qx + c + G'λ + A'ν  = 0   (stationarity, n eqs)
///   F₂ = diag(λ)(Gx - h)     = 0   (complementarity, m eqs)
///   F₃ = Ax - b              = 0   (primal equality, p eqs)
///
/// The Jacobian w.r.t. z = (x, λ, ν) is the (n+m+p) × (n+m+p) matrix:
///
///   ┌ Q          G'          A' ┐
///   │ diag(λ)G   diag(Gx-h)  0  │
///   └ A          0           0  ┘
///
/// For the complementarity row the correct linearisation is:
///   ∂F₂/∂x = diag(λ) G
///   ∂F₂/∂λ = diag(Gx - h)   (= -diag(s) where s = h - Gx ≥ 0 at optimum)
///
/// Because `h` is not passed in, the (1,1) block is stored as diag(Gx);
/// call [`adjust_complementarity_diagonal`] afterwards to subtract h_i from
/// each diagonal entry.
///
/// Out-of-range entries of `q`, `g`, `a` are read as zero (defensive
/// indexing: ragged or undersized matrices do not panic).
///
/// # Arguments
/// * `q` – n×n cost matrix (row-major `Vec<Vec<f64>>`).
/// * `g` – m×n inequality constraint matrix.
/// * `a` – p×n equality constraint matrix.
/// * `x` – optimal primal (length n).
/// * `lam` – optimal inequality duals (length m).
/// * `nu` – optimal equality duals (length p).
///
/// # Returns
/// The full KKT Jacobian as an (n+m+p) × (n+m+p) row-major matrix.
pub fn compute_kkt_jacobian(
    q: &[Vec<f64>],
    g: &[Vec<f64>],
    a: &[Vec<f64>],
    x: &[f64],
    lam: &[f64],
    nu: &[f64],
) -> Vec<Vec<f64>> {
    let n = x.len();
    let m = lam.len();
    let p = nu.len();
    let dim = n + m + p;

    // Checked matrix lookup: entries outside the stored data read as 0.0.
    let at = |mat: &[Vec<f64>], i: usize, j: usize| -> f64 {
        mat.get(i).and_then(|row| row.get(j)).copied().unwrap_or(0.0)
    };

    let mut jac = vec![vec![0.0; dim]; dim];

    // ── Block (0,0): Q  (n×n) ──────────────────────────────────────────
    for i in 0..n {
        for j in 0..n {
            jac[i][j] = at(q, i, j);
        }
    }

    // ── Block (0,1): G'  (n×m) ─────────────────────────────────────────
    // Stationarity is Qx + c + G'λ + A'ν = 0, so ∂F₁/∂λ_j is the j-th row
    // of G placed into column n+j (i.e. the transpose of G).
    for j in 0..m {
        for i in 0..n {
            jac[i][n + j] = at(g, j, i);
        }
    }

    // ── Block (0,2): A'  (n×p) ─────────────────────────────────────────
    for j in 0..p {
        for i in 0..n {
            jac[i][n + m + j] = at(a, j, i);
        }
    }

    // ── Block (1,0): diag(λ) G  (m×n) ─────────────────────────────────
    for i in 0..m {
        for j in 0..n {
            jac[n + i][j] = lam[i] * at(g, i, j);
        }
    }

    // ── Block (1,1): diag(Gx)  (m×m) ──────────────────────────────────
    // h is subtracted later by `adjust_complementarity_diagonal` so that the
    // block becomes diag(Gx - h); the Jacobian entry is the (signed) slack.
    for i in 0..m {
        let mut gx_i = 0.0;
        if let Some(row) = g.get(i) {
            // zip stops at the shorter of (row, x), matching the defensive
            // bounds handling used elsewhere.
            for (gij, xj) in row.iter().zip(x) {
                gx_i += gij * xj;
            }
        }
        jac[n + i][n + i] = gx_i;
    }

    // ── Block (2,0): A  (p×n) ──────────────────────────────────────────
    for i in 0..p {
        for j in 0..n {
            jac[n + m + i][j] = at(a, i, j);
        }
    }

    // Blocks (1,2), (2,1), (2,2) are zero.
    jac
}
146
/// Adjust the complementarity diagonal block with the rhs `h`.
///
/// `compute_kkt_jacobian` stores diag(Gx) in block (1,1); this subtracts
/// h_i from each diagonal entry so the block becomes diag(Gx - h).
///
/// Entries of `h` whose index falls outside the Jacobian are ignored rather
/// than panicking, matching the defensive indexing used elsewhere in this
/// module.
///
/// # Arguments
/// * `jac` – the (n+m+p)-square KKT Jacobian from `compute_kkt_jacobian`.
/// * `h` – inequality rhs (length m).
/// * `n` – number of primal variables (row/column offset of block (1,1)).
pub fn adjust_complementarity_diagonal(jac: &mut [Vec<f64>], h: &[f64], n: usize) {
    for (i, &h_i) in h.iter().enumerate() {
        let k = n + i;
        // Checked access: silently skip entries beyond the matrix instead of
        // indexing out of bounds when h is longer than the (1,1) block.
        if let Some(entry) = jac.get_mut(k).and_then(|row| row.get_mut(k)) {
            *entry -= h_i;
        }
    }
}
155
156// ─────────────────────────────────────────────────────────────────────────────
157// Linear system solver (Gaussian elimination with partial pivoting)
158// ─────────────────────────────────────────────────────────────────────────────
159
160/// Solve the linear system `A x = rhs` via Gaussian elimination with partial
161/// pivoting.
162///
163/// Both `mat` (square, row-major) and `rhs` are consumed / mutated.
164///
165/// # Errors
166/// Returns `OptimizeError::ComputationError` if the matrix is singular.
167pub fn solve_implicit_system(mat: &[Vec<f64>], rhs: &[f64]) -> OptimizeResult<Vec<f64>> {
168    let n = rhs.len();
169    if mat.len() != n {
170        return Err(OptimizeError::InvalidInput(format!(
171            "KKT matrix rows ({}) != rhs length ({})",
172            mat.len(),
173            n
174        )));
175    }
176
177    // Build augmented matrix
178    let mut aug: Vec<Vec<f64>> = mat
179        .iter()
180        .enumerate()
181        .map(|(i, row)| {
182            let mut r = row.clone();
183            r.push(rhs[i]);
184            r
185        })
186        .collect();
187
188    // Forward elimination with partial pivoting
189    for col in 0..n {
190        // Find pivot
191        let mut max_val = aug[col][col].abs();
192        let mut max_row = col;
193        for row in (col + 1)..n {
194            let v = aug[row][col].abs();
195            if v > max_val {
196                max_val = v;
197                max_row = row;
198            }
199        }
200
201        if max_val < 1e-30 {
202            return Err(OptimizeError::ComputationError(
203                "Singular KKT matrix in implicit differentiation".to_string(),
204            ));
205        }
206
207        if max_row != col {
208            aug.swap(col, max_row);
209        }
210
211        let pivot = aug[col][col];
212        for row in (col + 1)..n {
213            let factor = aug[row][col] / pivot;
214            for j in col..=n {
215                let val = aug[col][j];
216                aug[row][j] -= factor * val;
217            }
218        }
219    }
220
221    // Back substitution
222    let mut solution = vec![0.0; n];
223    for i in (0..n).rev() {
224        let mut sum = aug[i][n];
225        for j in (i + 1)..n {
226            sum -= aug[i][j] * solution[j];
227        }
228        let diag = aug[i][i];
229        if diag.abs() < 1e-30 {
230            return Err(OptimizeError::ComputationError(
231                "Zero diagonal in back substitution".to_string(),
232            ));
233        }
234        solution[i] = sum / diag;
235    }
236
237    Ok(solution)
238}
239
240/// Solve the system `mat * X = rhs_matrix` where rhs_matrix has `k` columns.
241/// Returns the solution matrix (n × k) in row-major form.
242pub fn solve_implicit_system_multi(
243    mat: &[Vec<f64>],
244    rhs_cols: &[Vec<f64>],
245) -> OptimizeResult<Vec<Vec<f64>>> {
246    rhs_cols
247        .iter()
248        .map(|rhs| solve_implicit_system(mat, rhs))
249        .collect()
250}
251
252// ─────────────────────────────────────────────────────────────────────────────
253// Active constraint identification
254// ─────────────────────────────────────────────────────────────────────────────
255
/// Identify active inequality constraints: those whose slack `h_i - G_i·x`
/// is within `tol` of zero.
///
/// Rows of `g` shorter than `x` (or missing entirely) contribute zero to the
/// dot product rather than panicking.
///
/// Returns the indices of the active constraints in increasing order.
pub fn identify_active_constraints(g: &[Vec<f64>], h: &[f64], x: &[f64], tol: f64) -> Vec<usize> {
    (0..h.len())
        .filter(|&i| {
            // Dot product G_i·x over the overlapping length (zip stops at
            // the shorter of the row and x); missing rows count as zero.
            let gx: f64 = g
                .get(i)
                .map(|row| row.iter().zip(x).map(|(gij, xj)| gij * xj).sum())
                .unwrap_or(0.0);
            (h[i] - gx).abs() <= tol
        })
        .collect()
}
280
/// Extract the rows of `g` and entries of `h` selected by `active`.
///
/// Indices that fall outside `g` or `h` are silently skipped, so the two
/// returned vectors may be shorter than `active`.
pub fn extract_active_constraints(
    g: &[Vec<f64>],
    h: &[f64],
    active: &[usize],
) -> (Vec<Vec<f64>>, Vec<f64>) {
    let mut g_active = Vec::with_capacity(active.len());
    let mut h_active = Vec::with_capacity(active.len());
    for &i in active {
        if let Some(row) = g.get(i) {
            g_active.push(row.clone());
        }
        if let Some(&h_i) = h.get(i) {
            h_active.push(h_i);
        }
    }
    (g_active, h_active)
}
291
292// ─────────────────────────────────────────────────────────────────────────────
293// Full implicit backward pass
294// ─────────────────────────────────────────────────────────────────────────────
295
/// Compute the full implicit gradient dz*/dθ for a QP.
///
/// Given dl/dx (the upstream gradient from the loss), compute dl/dθ for
/// θ = (Q, c, G, h, A, b) by solving the adjoint (transposed) system:
///
///   (∂F/∂z)' dz = -[dl/dx, 0, 0]'
///
/// and then contracting dl/dθ = (∂F/∂θ)' dz, where F stacks F₁ (stationarity),
/// F₂ (scaled complementarity diag(λ)(Gx-h)), F₃ (primal equality) — the same
/// residuals whose Jacobian `compute_kkt_jacobian` builds.
///
/// # Arguments
/// * `q` – n×n cost matrix.
/// * `g` – m×n inequality constraint matrix.
/// * `h` – m inequality rhs.
/// * `a` – p×n equality constraint matrix.
/// * `x` – optimal primal solution.
/// * `lam` – optimal inequality duals.
/// * `nu` – optimal equality duals.
/// * `dl_dx` – upstream gradient dl/dx (length n).
///
/// # Errors
/// Propagates failures from `solve_implicit_system`, e.g. a singular KKT
/// matrix.
///
/// # Returns
/// The implicit gradients for all parameters.
pub fn compute_full_implicit_gradient(
    q: &[Vec<f64>],
    g: &[Vec<f64>],
    h: &[f64],
    a: &[Vec<f64>],
    x: &[f64],
    lam: &[f64],
    nu: &[f64],
    dl_dx: &[f64],
) -> OptimizeResult<super::types::ImplicitGradient> {
    let n = x.len();
    let m = lam.len();
    let p = nu.len();
    let dim = n + m + p;

    // Build the KKT Jacobian ∂F/∂z; its (1,1) block still needs the -h
    // offset, which the adjust call below applies.
    let mut kkt = compute_kkt_jacobian(q, g, a, x, lam, nu);
    adjust_complementarity_diagonal(&mut kkt, h, n);

    // Transpose the KKT Jacobian: the backward pass solves the adjoint
    // system (∂F/∂z)' dz = rhs rather than the forward sensitivity system.
    let mut kkt_t = vec![vec![0.0; dim]; dim];
    for i in 0..dim {
        for j in 0..dim {
            kkt_t[i][j] = kkt[j][i];
        }
    }

    // RHS = -[dl/dx, 0_m, 0_p]: the loss depends on z only through the
    // primal block x, so the dual slots of the rhs stay zero.
    let mut rhs = vec![0.0; dim];
    for i in 0..n {
        rhs[i] = -dl_dx[i];
    }

    // Solve the adjoint system: kkt_t * dz = rhs.
    let dz = solve_implicit_system(&kkt_t, &rhs)?;

    // Split the adjoint vector into its components: dz = (dx, dlam, dnu).
    let dx = &dz[..n];
    let dlam = &dz[n..n + m];
    let dnu = &dz[n + m..];

    // ── Contract dl/dθ = (∂F/∂θ)' dz for each parameter θ ──────────────
    // dl/dc = dx  (∂F₁/∂c = I; F₂ and F₃ do not involve c).
    let dl_dc = dx.to_vec();

    // dl/dh = -dlam.
    // NOTE(review): with the scaled complementarity F₂ = diag(λ)(Gx - h),
    // the exact contraction would be dl/dh_i = -λ_i · dlam_i; this code
    // drops the λ_i factor (consistent with an unscaled/active-set
    // convention where λ_i > 0 on active rows). Confirm against the dl_dg
    // formula below, which *does* keep the λ_i factor.
    let dl_dh: Vec<f64> = dlam.iter().map(|&v| -v).collect();

    // dl/db = -dnu  (∂F₃/∂b = -I).
    let dl_db: Vec<f64> = dnu.iter().map(|&v| -v).collect();

    // dl/dQ = symmetrised outer product of dx and x: ∂F₁/∂Q_{ij} = x_j on
    // row i, contracted with dx_i; the 0.5(·+·) split shares the gradient
    // mass between Q_{ij} and Q_{ji} for symmetric Q.
    let mut dl_dq = vec![vec![0.0; n]; n];
    for i in 0..n {
        for j in 0..n {
            dl_dq[i][j] = 0.5 * (dx[i] * x[j] + dx[j] * x[i]);
        }
    }

    // dl/dG: two contributions —
    //   stationarity:     ∂(F₁)_j/∂G_{ij} = λ_i,      contracted with dx_j;
    //   complementarity:  ∂(F₂)_i/∂G_{ij} = λ_i x_j,  contracted with dlam_i.
    let mut dl_dg = vec![vec![0.0; n]; m];
    for i in 0..m {
        for j in 0..n {
            dl_dg[i][j] = dx[j] * lam[i] + dlam[i] * lam[i] * x[j];
        }
    }

    // dl/dA: analogous pair of contributions —
    //   stationarity:     ∂(F₁)_j/∂A_{ij} = ν_i,  contracted with dx_j;
    //   primal equality:  ∂(F₃)_i/∂A_{ij} = x_j,  contracted with dnu_i.
    let mut dl_da = vec![vec![0.0; n]; p];
    for i in 0..p {
        for j in 0..n {
            dl_da[i][j] = dx[j] * nu[i] + dnu[i] * x[j];
        }
    }

    Ok(super::types::ImplicitGradient {
        dl_dq: Some(dl_dq),
        dl_dc,
        dl_dg: Some(dl_dg),
        dl_dh,
        // Equality gradients are only meaningful when equalities exist.
        dl_da: if p > 0 { Some(dl_da) } else { None },
        dl_db,
    })
}
409
410/// Compute implicit gradient using only active constraints.
411///
412/// This is faster than full differentiation when many inequality constraints
413/// are inactive, since those constraints have zero dual variables and do not
414/// contribute to the gradient.
415pub fn compute_active_set_implicit_gradient(
416    q: &[Vec<f64>],
417    g: &[Vec<f64>],
418    h: &[f64],
419    a: &[Vec<f64>],
420    x: &[f64],
421    lam: &[f64],
422    nu: &[f64],
423    dl_dx: &[f64],
424    active_tol: f64,
425) -> OptimizeResult<super::types::ImplicitGradient> {
426    let m = lam.len();
427
428    // Identify active constraints
429    let active = identify_active_constraints(g, h, x, active_tol);
430    let (g_active, h_active) = extract_active_constraints(g, h, &active);
431    let lam_active: Vec<f64> = active
432        .iter()
433        .filter_map(|&i| if i < m { Some(lam[i]) } else { None })
434        .collect();
435
436    // Solve reduced system
437    let grad =
438        compute_full_implicit_gradient(q, &g_active, &h_active, a, x, &lam_active, nu, dl_dx)?;
439
440    // Expand gradient back to full dimension
441    let m_full = lam.len();
442    let n = x.len();
443
444    let mut dl_dh_full = vec![0.0; m_full];
445    for (idx, &ai) in active.iter().enumerate() {
446        if ai < m_full && idx < grad.dl_dh.len() {
447            dl_dh_full[ai] = grad.dl_dh[idx];
448        }
449    }
450
451    let dl_dg_full = if let Some(ref dg) = grad.dl_dg {
452        let mut full = vec![vec![0.0; n]; m_full];
453        for (idx, &ai) in active.iter().enumerate() {
454            if ai < m_full && idx < dg.len() {
455                full[ai] = dg[idx].clone();
456            }
457        }
458        Some(full)
459    } else {
460        None
461    };
462
463    Ok(super::types::ImplicitGradient {
464        dl_dq: grad.dl_dq,
465        dl_dc: grad.dl_dc,
466        dl_dg: dl_dg_full,
467        dl_dh: dl_dh_full,
468        dl_da: grad.dl_da,
469        dl_db: grad.dl_db,
470    })
471}
472
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute-tolerance float comparison shared by the tests below.
    fn close(a: f64, b: f64) -> bool {
        (a - b).abs() < 1e-12
    }

    #[test]
    fn test_kkt_jacobian_structure_2x2() {
        // 2-variable QP with Q = 2I, one inequality x₀ + x₁ ≤ 1, no equalities.
        let q = vec![vec![2.0, 0.0], vec![0.0, 2.0]];
        let g = vec![vec![1.0, 1.0]];
        let a: Vec<Vec<f64>> = Vec::new();
        let x = vec![0.25, 0.25];
        let lam = vec![0.5];
        let nu: Vec<f64> = Vec::new();

        let jac = compute_kkt_jacobian(&q, &g, &a, &x, &lam, &nu);

        // Dimension: n + m + p = 2 + 1 + 0 = 3.
        assert_eq!(jac.len(), 3);
        assert_eq!(jac[0].len(), 3);

        // (0,0) block holds Q.
        assert!(close(jac[0][0], 2.0));
        assert!(close(jac[1][1], 2.0));

        // (0,1) block holds G transposed.
        assert!(close(jac[0][2], 1.0));
        assert!(close(jac[1][2], 1.0));

        // (1,0) block holds diag(λ)·G.
        assert!(close(jac[2][0], 0.5));
        assert!(close(jac[2][1], 0.5));
    }

    #[test]
    fn test_active_constraint_identification() {
        // Gx = [1.0, 0.5, 1.5], h = [1.0, 2.0, 1.5] → slacks [0.0, 1.5, 0.0],
        // so constraints 0 and 2 are tight.
        let g = vec![vec![1.0, 0.0], vec![0.0, 1.0], vec![1.0, 1.0]];
        let h = vec![1.0, 2.0, 1.5];
        let x = vec![1.0, 0.5];

        assert_eq!(identify_active_constraints(&g, &h, &x, 1e-6), vec![0, 2]);
    }

    #[test]
    fn test_solve_implicit_system_simple() {
        // [[2, 1], [1, 3]] x = [5, 7] has the unique solution [8/5, 9/5].
        let mat = vec![vec![2.0, 1.0], vec![1.0, 3.0]];
        let sol = solve_implicit_system(&mat, &[5.0, 7.0]).expect("solve failed");
        assert!((sol[0] - 1.6).abs() < 1e-10);
        assert!((sol[1] - 1.8).abs() < 1e-10);
    }

    #[test]
    fn test_solve_singular_matrix() {
        // Second row is twice the first → rank deficient → must error out.
        let mat = vec![vec![1.0, 2.0], vec![2.0, 4.0]];
        assert!(solve_implicit_system(&mat, &[3.0, 6.0]).is_err());
    }

    #[test]
    fn test_extract_active_constraints() {
        let g = vec![vec![1.0], vec![2.0], vec![3.0]];
        let h = vec![10.0, 20.0, 30.0];

        let (ga, ha) = extract_active_constraints(&g, &h, &[0, 2]);
        assert_eq!(ga.len(), 2);
        assert!(close(ga[0][0], 1.0));
        assert!(close(ga[1][0], 3.0));
        assert!(close(ha[0], 10.0));
        assert!(close(ha[1], 30.0));
    }

    #[test]
    fn test_empty_constraints() {
        // With no constraints at all, the KKT Jacobian degenerates to Q.
        let q = vec![vec![2.0, 0.0], vec![0.0, 2.0]];
        let jac = compute_kkt_jacobian(&q, &[], &[], &[1.0, 2.0], &[], &[]);
        assert_eq!(jac.len(), 2);
        assert!(close(jac[0][0], 2.0));
    }
}