// scirs2_optimize/differentiable_optimization/types.rs
//! Types for differentiable optimization (OptNet-style LP/QP layers).
//!
//! Provides configuration, result, and gradient structures for differentiable
//! quadratic and linear programming.
6/// Configuration for differentiable QP solving.
7#[derive(Debug, Clone)]
8pub struct DiffQPConfig {
9    /// Convergence tolerance for the interior-point solver.
10    pub tolerance: f64,
11    /// Maximum number of interior-point iterations.
12    pub max_iterations: usize,
13    /// Tikhonov regularization added to Q diagonal for numerical stability.
14    pub regularization: f64,
15    /// Backward differentiation mode.
16    pub backward_mode: BackwardMode,
17}
18
19impl Default for DiffQPConfig {
20    fn default() -> Self {
21        Self {
22            tolerance: 1e-8,
23            max_iterations: 100,
24            regularization: 1e-7,
25            backward_mode: BackwardMode::FullDifferentiation,
26        }
27    }
28}
29
/// Configuration for differentiable LP solving.
#[derive(Debug, Clone)]
pub struct DiffLPConfig {
    /// Convergence tolerance.
    pub tolerance: f64,
    /// Maximum number of interior-point iterations.
    pub max_iterations: usize,
    /// Tolerance for identifying active inequality constraints.
    pub active_constraint_tol: f64,
    /// Tikhonov regularization for numerical stability.
    pub regularization: f64,
}

impl Default for DiffLPConfig {
    fn default() -> Self {
        Self {
            tolerance: 1e-8,
            max_iterations: 100,
            active_constraint_tol: 1e-6,
            regularization: 1e-7,
        }
    }
}
/// Result of a differentiable QP forward solve.
#[derive(Debug, Clone)]
pub struct DiffQPResult {
    /// Optimal primal variable x*.
    pub optimal_x: Vec<f64>,
    /// Dual variables for inequality constraints (lambda).
    pub optimal_lambda: Vec<f64>,
    /// Dual variables for equality constraints (nu).
    pub optimal_nu: Vec<f64>,
    /// Optimal objective value: 0.5 x' Q x + c' x.
    pub objective: f64,
    /// Whether the solver converged within tolerance.
    pub converged: bool,
    /// Number of iterations taken.
    pub iterations: usize,
}
/// Result of a differentiable LP forward solve.
#[derive(Debug, Clone)]
pub struct DiffLPResult {
    /// Optimal primal variable x*.
    pub optimal_x: Vec<f64>,
    /// Dual variables for inequality constraints (lambda).
    pub optimal_lambda: Vec<f64>,
    /// Dual variables for equality constraints (nu).
    pub optimal_nu: Vec<f64>,
    /// Optimal objective value: c' x.
    pub objective: f64,
    /// Whether the solver converged within tolerance.
    pub converged: bool,
    /// Number of iterations taken.
    pub iterations: usize,
}
/// Implicit gradients of the loss w.r.t. problem parameters.
///
/// Given a downstream loss L(x*(θ)), these are the gradients dL/dθ
/// computed via implicit differentiation of the KKT conditions.
#[derive(Debug, Clone)]
pub struct ImplicitGradient {
    /// Gradient w.r.t. the quadratic cost matrix Q (n x n), if applicable.
    pub dl_dq: Option<Vec<Vec<f64>>>,
    /// Gradient w.r.t. the linear cost vector c (n).
    pub dl_dc: Vec<f64>,
    /// Gradient w.r.t. the inequality constraint matrix G (m x n), if applicable.
    pub dl_dg: Option<Vec<Vec<f64>>>,
    /// Gradient w.r.t. the inequality constraint rhs h (m).
    pub dl_dh: Vec<f64>,
    /// Gradient w.r.t. the equality constraint matrix A (p x n), if applicable.
    pub dl_da: Option<Vec<Vec<f64>>>,
    /// Gradient w.r.t. the equality constraint rhs b (p).
    pub dl_db: Vec<f64>,
}
// ─────────────────────────────────────────────────────────────────────────────
// New unified layer types (DiffOptLayer trait, DiffOptParams, etc.)
// ─────────────────────────────────────────────────────────────────────────────
/// Parameters for a generic QP/LP optimization layer.
///
/// Holds the cost and constraint data for `min ½xᵀQx + cᵀx s.t. Ax=b, Gx≤h`.
#[derive(Debug, Clone, Default)]
pub struct DiffOptParams {
    /// Quadratic cost matrix Q (n×n). Empty means LP (Q=0).
    pub q: Vec<Vec<f64>>,
    /// Linear cost vector c (n).
    pub c: Vec<f64>,
    /// Equality constraint matrix A (p×n).
    pub a: Vec<Vec<f64>>,
    /// Equality rhs b (p).
    pub b: Vec<f64>,
    /// Inequality constraint matrix G (m×n): Gx ≤ h.
    pub g: Vec<Vec<f64>>,
    /// Inequality rhs h (m).
    pub h: Vec<f64>,
}
131/// Result of a generic optimization layer forward pass.
132#[derive(Debug, Clone)]
133pub struct DiffOptResult {
134    /// Optimal primal solution x*.
135    pub x: Vec<f64>,
136    /// Dual variables for inequality constraints λ.
137    pub lambda: Vec<f64>,
138    /// Dual variables for equality constraints ν.
139    pub nu: Vec<f64>,
140    /// Optimal objective value.
141    pub objective: f64,
142    /// Solver status.
143    pub status: DiffOptStatus,
144    /// Number of iterations taken.
145    pub iterations: usize,
146}
147
/// Solver status for an optimization layer.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[non_exhaustive]
pub enum DiffOptStatus {
    /// Optimal solution found within tolerance.
    Optimal,
    /// Maximum iterations reached without convergence.
    MaxIterations,
    /// Solver detected infeasibility.
    Infeasible,
    /// Solver detected unboundedness.
    Unbounded,
}

impl Default for DiffOptStatus {
    fn default() -> Self {
        DiffOptStatus::Optimal
    }
}
/// Gradient of a loss w.r.t. all optimization layer parameters.
#[derive(Debug, Clone)]
pub struct DiffOptGrad {
    /// Gradient dL/dQ (n×n).
    pub dl_dq: Option<Vec<Vec<f64>>>,
    /// Gradient dL/dc (n).
    pub dl_dc: Vec<f64>,
    /// Gradient dL/dA (p×n).
    pub dl_da: Option<Vec<Vec<f64>>>,
    /// Gradient dL/db (p).
    pub dl_db: Vec<f64>,
    /// Gradient dL/dG (m×n).
    pub dl_dg: Option<Vec<Vec<f64>>>,
    /// Gradient dL/dh (m).
    pub dl_dh: Vec<f64>,
}
/// Mode of backward differentiation through the optimization layer.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[non_exhaustive]
pub enum BackwardMode {
    /// Differentiate through all KKT conditions (full implicit differentiation).
    FullDifferentiation,
    /// Differentiate only through active inequality constraints, treating
    /// inactive constraints as absent. Faster but approximate when constraints
    /// are near the boundary.
    ActiveSetOnly,
}
/// KKT system residuals for monitoring convergence.
#[derive(Debug, Clone)]
pub struct KKTSystem {
    /// Stationarity residual: Qx + c + G'λ + A'ν.
    pub stationarity: Vec<f64>,
    /// Primal feasibility (equality): Ax - b.
    pub primal_eq: Vec<f64>,
    /// Primal feasibility (inequality): Gx - h (should be ≤ 0).
    pub primal_ineq: Vec<f64>,
    /// Complementary slackness: λ_i * (Gx - h)_i.
    pub complementarity: Vec<f64>,
    /// Maximum absolute residual across all conditions.
    pub max_residual: f64,
}

impl KKTSystem {
    /// Check whether all KKT residuals are below the given tolerance.
    ///
    /// Relies on `max_residual` already being the maximum absolute residual;
    /// the individual residual vectors are not re-scanned here.
    pub fn is_satisfied(&self, tol: f64) -> bool {
        self.max_residual < tol
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_diff_qp_config_default() {
        let cfg = DiffQPConfig::default();
        assert!((cfg.tolerance - 1e-8).abs() < 1e-15);
        assert_eq!(cfg.max_iterations, 100);
        assert!((cfg.regularization - 1e-7).abs() < 1e-15);
        assert_eq!(cfg.backward_mode, BackwardMode::FullDifferentiation);
    }

    #[test]
    fn test_diff_lp_config_default() {
        let cfg = DiffLPConfig::default();
        assert!((cfg.tolerance - 1e-8).abs() < 1e-15);
        assert_eq!(cfg.max_iterations, 100);
        assert!((cfg.active_constraint_tol - 1e-6).abs() < 1e-15);
        // Previously untested default; keep in sync with DiffLPConfig::default().
        assert!((cfg.regularization - 1e-7).abs() < 1e-15);
    }

    #[test]
    fn test_kkt_system_satisfied() {
        let kkt = KKTSystem {
            stationarity: vec![1e-10],
            primal_eq: vec![1e-10],
            primal_ineq: vec![-0.5],
            complementarity: vec![1e-12],
            max_residual: 1e-10,
        };
        assert!(kkt.is_satisfied(1e-8));
        assert!(!kkt.is_satisfied(1e-12));
    }

    #[test]
    fn test_backward_mode_non_exhaustive() {
        // Verify both modes exist and can be matched. The wildcard arm is
        // unreachable inside the defining crate (#[non_exhaustive] only
        // affects downstream crates), so silence the lint explicitly.
        let mode = BackwardMode::ActiveSetOnly;
        #[allow(unreachable_patterns)]
        match mode {
            BackwardMode::FullDifferentiation => panic!("wrong variant"),
            BackwardMode::ActiveSetOnly => {}
            _ => {}
        }
    }
}