scirs2_optimize/differentiable_optimization/
layer.rs1use super::diff_qp::DifferentiableQP;
9use super::types::{DiffQPConfig, DiffQPResult, ImplicitGradient};
10use crate::error::OptimizeResult;
11
12pub trait OptNetLayer {
18 type ForwardResult;
20
21 fn forward(&self) -> OptimizeResult<Self::ForwardResult>;
23
24 fn backward(
30 &self,
31 result: &Self::ForwardResult,
32 dl_dx: &[f64],
33 ) -> OptimizeResult<ImplicitGradient>;
34}
35
36#[derive(Debug, Clone)]
38pub struct StandardOptNetLayer {
39 pub qp: DifferentiableQP,
41 pub config: DiffQPConfig,
43}
44
45impl StandardOptNetLayer {
46 pub fn new(qp: DifferentiableQP, config: DiffQPConfig) -> Self {
48 Self { qp, config }
49 }
50
51 pub fn forward_batch(
53 qps: &[DifferentiableQP],
54 config: &DiffQPConfig,
55 ) -> OptimizeResult<Vec<DiffQPResult>> {
56 DifferentiableQP::batched_forward(qps, config)
57 }
58}
59
60impl OptNetLayer for StandardOptNetLayer {
61 type ForwardResult = DiffQPResult;
62
63 fn forward(&self) -> OptimizeResult<DiffQPResult> {
64 self.qp.forward(&self.config)
65 }
66
67 fn backward(&self, result: &DiffQPResult, dl_dx: &[f64]) -> OptimizeResult<ImplicitGradient> {
68 self.qp.backward(result, dl_dx, &self.config)
69 }
70}
71
#[cfg(test)]
mod tests {
    use super::*;

    /// Drives a `StandardOptNetLayer` through the `OptNetLayer` trait
    /// methods: forward must report convergence and backward must return a
    /// `dl_dc` gradient with one entry per variable.
    #[test]
    fn test_layer_trait_dispatch() {
        // Two-variable problem; the four trailing arguments are left empty.
        let q = vec![vec![2.0, 0.0], vec![0.0, 2.0]];
        let c = vec![1.0, 2.0];
        let problem = DifferentiableQP::new(q, c, vec![], vec![], vec![], vec![])
            .expect("QP creation failed");

        let layer = StandardOptNetLayer::new(problem, DiffQPConfig::default());

        let forward_out = layer.forward().expect("Forward failed");
        assert!(forward_out.converged);

        let upstream = vec![1.0, 0.0];
        let gradient = layer
            .backward(&forward_out, &upstream)
            .expect("Backward failed");
        assert_eq!(gradient.dl_dc.len(), 2);
    }

    /// Solves two one-variable problems through `forward_batch` and checks
    /// each solution against -0.5.
    ///
    /// NOTE(review): -0.5 matches -c/Q for both problems (-1/2 and -2/4),
    /// assuming the objective is 0.5*x'Qx + c'x; confirm against
    /// `DifferentiableQP`.
    #[test]
    fn test_layer_batch_interface() {
        let first = DifferentiableQP::new(vec![vec![2.0]], vec![1.0], vec![], vec![], vec![], vec![])
            .expect("QP1 creation failed");
        let second = DifferentiableQP::new(vec![vec![4.0]], vec![2.0], vec![], vec![], vec![], vec![])
            .expect("QP2 creation failed");

        let config = DiffQPConfig::default();
        let problems = [first, second];
        let results = StandardOptNetLayer::forward_batch(&problems, &config).expect("Batch failed");

        assert_eq!(results.len(), 2);

        let expected = -0.5;

        let x0 = results[0].optimal_x[0];
        assert!((x0 - expected).abs() < 1e-3, "batch[0].x = {}", x0);

        // The second problem is checked with a looser tolerance than the first.
        let x1 = results[1].optimal_x[0];
        assert!((x1 - expected).abs() < 1e-2, "batch[1].x = {}", x1);
    }
}