// scirs2_optimize/differentiable_optimization/types.rs

//! Types for differentiable optimization (OptNet-style LP/QP layers).
//!
//! Provides configuration, result, and gradient structures for differentiable
//! quadratic and linear programming.

6/// Configuration for differentiable QP solving.
7#[derive(Debug, Clone)]
8pub struct DiffQPConfig {
9 /// Convergence tolerance for the interior-point solver.
10 pub tolerance: f64,
11 /// Maximum number of interior-point iterations.
12 pub max_iterations: usize,
13 /// Tikhonov regularization added to Q diagonal for numerical stability.
14 pub regularization: f64,
15 /// Backward differentiation mode.
16 pub backward_mode: BackwardMode,
17}
18
19impl Default for DiffQPConfig {
20 fn default() -> Self {
21 Self {
22 tolerance: 1e-8,
23 max_iterations: 100,
24 regularization: 1e-7,
25 backward_mode: BackwardMode::FullDifferentiation,
26 }
27 }
28}
29
/// Settings that control the differentiable LP solver.
#[derive(Debug, Clone)]
pub struct DiffLPConfig {
    /// Interior-point convergence tolerance.
    pub tolerance: f64,
    /// Upper bound on interior-point iterations.
    pub max_iterations: usize,
    /// Threshold below which an inequality constraint is treated as active.
    pub active_constraint_tol: f64,
    /// Tikhonov term used to keep the linear solves well-conditioned.
    pub regularization: f64,
}

43impl Default for DiffLPConfig {
44 fn default() -> Self {
45 Self {
46 tolerance: 1e-8,
47 max_iterations: 100,
48 active_constraint_tol: 1e-6,
49 regularization: 1e-7,
50 }
51 }
52}
53
/// Output of a differentiable QP forward solve.
#[derive(Debug, Clone)]
pub struct DiffQPResult {
    /// Primal minimizer x*.
    pub optimal_x: Vec<f64>,
    /// Inequality-constraint multipliers (lambda) at the solution.
    pub optimal_lambda: Vec<f64>,
    /// Equality-constraint multipliers (nu) at the solution.
    pub optimal_nu: Vec<f64>,
    /// Objective value 0.5 x' Q x + c' x at x*.
    pub objective: f64,
    /// True if the solver met its tolerance.
    pub converged: bool,
    /// Iteration count at termination.
    pub iterations: usize,
}

/// Output of a differentiable LP forward solve.
#[derive(Debug, Clone)]
pub struct DiffLPResult {
    /// Primal minimizer x*.
    pub optimal_x: Vec<f64>,
    /// Inequality-constraint multipliers (lambda) at the solution.
    pub optimal_lambda: Vec<f64>,
    /// Equality-constraint multipliers (nu) at the solution.
    pub optimal_nu: Vec<f64>,
    /// Objective value c' x at x*.
    pub objective: f64,
    /// True if the solver met its tolerance.
    pub converged: bool,
    /// Iteration count at termination.
    pub iterations: usize,
}

/// Implicit gradients of a downstream loss w.r.t. the problem data.
///
/// For a loss L(x*(θ)) these hold dL/dθ, obtained by implicitly
/// differentiating the KKT optimality conditions.
#[derive(Debug, Clone)]
pub struct ImplicitGradient {
    /// dL/dQ (n x n); `None` when the problem has no quadratic term.
    pub dl_dq: Option<Vec<Vec<f64>>>,
    /// dL/dc (length n).
    pub dl_dc: Vec<f64>,
    /// dL/dG (m x n); `None` when there are no inequality constraints.
    pub dl_dg: Option<Vec<Vec<f64>>>,
    /// dL/dh (length m).
    pub dl_dh: Vec<f64>,
    /// dL/dA (p x n); `None` when there are no equality constraints.
    pub dl_da: Option<Vec<Vec<f64>>>,
    /// dL/db (length p).
    pub dl_db: Vec<f64>,
}

// ─────────────────────────────────────────────────────────────────────────────
// New unified layer types (DiffOptLayer trait, DiffOptParams, etc.)
// ─────────────────────────────────────────────────────────────────────────────

/// Problem data for a generic QP/LP optimization layer.
///
/// Describes `min ½xᵀQx + cᵀx s.t. Ax=b, Gx≤h`.
#[derive(Debug, Clone, Default)]
pub struct DiffOptParams {
    /// Quadratic cost Q (n×n); leave empty for an LP (Q = 0).
    pub q: Vec<Vec<f64>>,
    /// Linear cost c (length n).
    pub c: Vec<f64>,
    /// Equality matrix A (p×n).
    pub a: Vec<Vec<f64>>,
    /// Equality right-hand side b (length p).
    pub b: Vec<f64>,
    /// Inequality matrix G (m×n) in Gx ≤ h.
    pub g: Vec<Vec<f64>>,
    /// Inequality right-hand side h (length m).
    pub h: Vec<f64>,
}

131/// Result of a generic optimization layer forward pass.
132#[derive(Debug, Clone)]
133pub struct DiffOptResult {
134 /// Optimal primal solution x*.
135 pub x: Vec<f64>,
136 /// Dual variables for inequality constraints λ.
137 pub lambda: Vec<f64>,
138 /// Dual variables for equality constraints ν.
139 pub nu: Vec<f64>,
140 /// Optimal objective value.
141 pub objective: f64,
142 /// Solver status.
143 pub status: DiffOptStatus,
144 /// Number of iterations taken.
145 pub iterations: usize,
146}
147
/// Termination status of an optimization-layer solve.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[non_exhaustive]
pub enum DiffOptStatus {
    /// Converged to an optimal point within tolerance.
    Optimal,
    /// Iteration budget exhausted before convergence.
    MaxIterations,
    /// Problem has no feasible point.
    Infeasible,
    /// Objective is unbounded below on the feasible set.
    Unbounded,
}

162impl Default for DiffOptStatus {
163 fn default() -> Self {
164 DiffOptStatus::Optimal
165 }
166}
167
/// Gradients of a loss w.r.t. every optimization-layer parameter.
#[derive(Debug, Clone)]
pub struct DiffOptGrad {
    /// dL/dQ (n×n), when a quadratic term is present.
    pub dl_dq: Option<Vec<Vec<f64>>>,
    /// dL/dc (length n).
    pub dl_dc: Vec<f64>,
    /// dL/dA (p×n), when equality constraints are present.
    pub dl_da: Option<Vec<Vec<f64>>>,
    /// dL/db (length p).
    pub dl_db: Vec<f64>,
    /// dL/dG (m×n), when inequality constraints are present.
    pub dl_dg: Option<Vec<Vec<f64>>>,
    /// dL/dh (length m).
    pub dl_dh: Vec<f64>,
}

/// Strategy for differentiating through the optimization layer.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[non_exhaustive]
pub enum BackwardMode {
    /// Implicitly differentiate the complete set of KKT conditions.
    FullDifferentiation,
    /// Differentiate only through the active inequality constraints and
    /// ignore inactive ones. Cheaper, but an approximation when constraints
    /// sit near the boundary.
    ActiveSetOnly,
}

/// Residuals of the KKT optimality system, used to monitor convergence.
#[derive(Debug, Clone)]
pub struct KKTSystem {
    /// Stationarity: Qx + c + G'λ + A'ν.
    pub stationarity: Vec<f64>,
    /// Equality feasibility: Ax - b.
    pub primal_eq: Vec<f64>,
    /// Inequality feasibility: Gx - h (nonpositive when feasible).
    pub primal_ineq: Vec<f64>,
    /// Complementary slackness: λ_i · (Gx - h)_i.
    pub complementarity: Vec<f64>,
    /// Largest absolute residual over all conditions.
    pub max_residual: f64,
}

212impl KKTSystem {
213 /// Check whether all KKT residuals are below the given tolerance.
214 pub fn is_satisfied(&self, tol: f64) -> bool {
215 self.max_residual < tol
216 }
217}
218
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_diff_qp_config_default() {
        let cfg = DiffQPConfig::default();
        assert!((cfg.tolerance - 1e-8).abs() < 1e-15);
        assert_eq!(cfg.max_iterations, 100);
        assert!((cfg.regularization - 1e-7).abs() < 1e-15);
        assert_eq!(cfg.backward_mode, BackwardMode::FullDifferentiation);
    }

    #[test]
    fn test_diff_lp_config_default() {
        let cfg = DiffLPConfig::default();
        assert!((cfg.tolerance - 1e-8).abs() < 1e-15);
        assert_eq!(cfg.max_iterations, 100);
        assert!((cfg.active_constraint_tol - 1e-6).abs() < 1e-15);
        // Previously untested default.
        assert!((cfg.regularization - 1e-7).abs() < 1e-15);
    }

    #[test]
    fn test_diff_opt_status_default() {
        // Previously untested: a fresh status defaults to Optimal.
        assert_eq!(DiffOptStatus::default(), DiffOptStatus::Optimal);
    }

    #[test]
    fn test_kkt_system_satisfied() {
        let kkt = KKTSystem {
            stationarity: vec![1e-10],
            primal_eq: vec![1e-10],
            primal_ineq: vec![-0.5],
            complementarity: vec![1e-12],
            max_residual: 1e-10,
        };
        assert!(kkt.is_satisfied(1e-8));
        assert!(!kkt.is_satisfied(1e-12));
    }

    #[test]
    fn test_backward_mode_non_exhaustive() {
        // `#[non_exhaustive]` has no effect inside the defining crate, so the
        // match below is already exhaustive; the previous trailing `_ => {}`
        // arm was dead code flagged by the `unreachable_patterns` lint.
        let mode = BackwardMode::ActiveSetOnly;
        match mode {
            BackwardMode::FullDifferentiation => panic!("wrong variant"),
            BackwardMode::ActiveSetOnly => {}
        }
        assert_ne!(mode, BackwardMode::FullDifferentiation);
    }
}