// scirs2_optimize/differentiable_optimization/layer.rs
1//! OptNet layer abstraction for embedding differentiable optimization
2//! in neural network pipelines.
3//!
4//! The `OptNetLayer` trait provides a uniform interface for forward (solve)
5//! and backward (gradient) passes, mirroring the pattern used in deep
6//! learning frameworks.
7
8use super::diff_qp::DifferentiableQP;
9use super::types::{DiffQPConfig, DiffQPResult, ImplicitGradient};
10use crate::error::OptimizeResult;
11
/// Trait for a differentiable optimization layer.
///
/// Implementations wrap a parametric optimization problem and expose
/// forward/backward methods suitable for integration into gradient-based
/// training pipelines, mirroring the layer interface of deep learning
/// frameworks: `forward` produces the layer output (the optimizer's
/// solution) and `backward` maps an upstream gradient to parameter
/// gradients.
pub trait OptNetLayer {
    /// The result type returned by the forward pass.
    type ForwardResult;

    /// Solve the optimization problem (forward pass).
    ///
    /// # Errors
    /// Returns an error if the underlying solve fails.
    fn forward(&self) -> OptimizeResult<Self::ForwardResult>;

    /// Compute parameter gradients (backward pass).
    ///
    /// # Arguments
    /// * `result` – the result from a preceding `forward()` call.
    /// * `dl_dx` – upstream gradient of the loss w.r.t. the optimal solution.
    ///
    /// # Errors
    /// Returns an error if the gradient computation fails.
    fn backward(
        &self,
        result: &Self::ForwardResult,
        dl_dx: &[f64],
    ) -> OptimizeResult<ImplicitGradient>;
}
35
/// A standard OptNet layer wrapping a differentiable QP.
///
/// Bundles a [`DifferentiableQP`] problem with the [`DiffQPConfig`] used for
/// both the forward solve and the backward gradient computation. Both fields
/// are public, so callers may rebuild or tweak the problem between passes.
#[derive(Debug, Clone)]
pub struct StandardOptNetLayer {
    /// The underlying differentiable QP.
    pub qp: DifferentiableQP,
    /// Configuration for forward/backward passes.
    pub config: DiffQPConfig,
}
44
45impl StandardOptNetLayer {
46    /// Create a new OptNet layer from a differentiable QP and config.
47    pub fn new(qp: DifferentiableQP, config: DiffQPConfig) -> Self {
48        Self { qp, config }
49    }
50
51    /// Solve a batch of QPs sharing the same config.
52    pub fn forward_batch(
53        qps: &[DifferentiableQP],
54        config: &DiffQPConfig,
55    ) -> OptimizeResult<Vec<DiffQPResult>> {
56        DifferentiableQP::batched_forward(qps, config)
57    }
58}
59
60impl OptNetLayer for StandardOptNetLayer {
61    type ForwardResult = DiffQPResult;
62
63    fn forward(&self) -> OptimizeResult<DiffQPResult> {
64        self.qp.forward(&self.config)
65    }
66
67    fn backward(&self, result: &DiffQPResult, dl_dx: &[f64]) -> OptimizeResult<ImplicitGradient> {
68        self.qp.backward(result, dl_dx, &self.config)
69    }
70}
71
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_layer_trait_dispatch() {
        // Unconstrained 2-D QP with a diagonal quadratic term.
        let problem = DifferentiableQP::new(
            vec![vec![2.0, 0.0], vec![0.0, 2.0]],
            vec![1.0, 2.0],
            vec![],
            vec![],
            vec![],
            vec![],
        )
        .expect("QP creation failed");

        let layer = StandardOptNetLayer::new(problem, DiffQPConfig::default());

        // Forward pass must converge before gradients are meaningful.
        let solution = layer.forward().expect("Forward failed");
        assert!(solution.converged);

        // Backward pass: gradient of c must match the problem dimension.
        let upstream = vec![1.0, 0.0];
        let gradients = layer
            .backward(&solution, &upstream)
            .expect("Backward failed");
        assert_eq!(gradients.dl_dc.len(), 2);
    }

    #[test]
    fn test_layer_batch_interface() {
        let first =
            DifferentiableQP::new(vec![vec![2.0]], vec![1.0], vec![], vec![], vec![], vec![])
                .expect("QP1 creation failed");
        let second =
            DifferentiableQP::new(vec![vec![4.0]], vec![2.0], vec![], vec![], vec![], vec![])
                .expect("QP2 creation failed");

        let shared_config = DiffQPConfig::default();
        let outcomes = StandardOptNetLayer::forward_batch(&[first, second], &shared_config)
            .expect("Batch failed");

        assert_eq!(outcomes.len(), 2);
        // QP1: min x^2 + x → x* = -0.5
        assert!(
            (outcomes[0].optimal_x[0] + 0.5).abs() < 1e-3,
            "batch[0].x = {}",
            outcomes[0].optimal_x[0]
        );
        // QP2: min 2x^2 + 2x → x* = -0.5
        assert!(
            (outcomes[1].optimal_x[0] + 0.5).abs() < 1e-2,
            "batch[1].x = {}",
            outcomes[1].optimal_x[0]
        );
    }
}