use crate::error::{OptimizeError, OptimizeResult};

use super::implicit_diff::solve_implicit_system;
use super::types::{DiffOptGrad, DiffOptResult, DiffOptStatus};

/// Configuration for the differentiable LP layer.
#[derive(Debug, Clone)]
pub struct LpLayerConfig {
    /// Entropic regularization strength used to smooth the LP solution map.
    pub epsilon: f64,
    /// Maximum number of projected dual iterations.
    pub max_iter: usize,
    /// Convergence tolerance on the dual update.
    pub tol: f64,
    /// Tolerance for identifying basic variables.
    pub basis_tol: f64,
}

impl Default for LpLayerConfig {
    fn default() -> Self {
        Self {
            epsilon: 1e-3,
            max_iter: 500,
            tol: 1e-8,
            basis_tol: 1e-6,
        }
    }
}

/// Numerically stable softmax: shifts by the maximum before exponentiating.
fn softmax(v: &[f64]) -> Vec<f64> {
    let max_v = v.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
    let exp_v: Vec<f64> = v.iter().map(|&vi| (vi - max_v).exp()).collect();
    let sum: f64 = exp_v.iter().sum();
    if sum < 1e-300 {
        // Degenerate case: fall back to the uniform distribution.
        vec![1.0 / v.len() as f64; v.len()]
    } else {
        exp_v.iter().map(|&e| e / sum).collect()
    }
}

/// Numerically stable log-sum-exp of a slice.
fn logsumexp(v: &[f64]) -> f64 {
    let max_v = v.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
    let sum: f64 = v.iter().map(|&vi| (vi - max_v).exp()).sum();
    max_v + sum.ln()
}

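/// Entropy-regularized LP solver for `min c^T x` subject to `A x <= b`, `x >= 0`.
///
/// The dual variables are updated by a projected gradient scheme; the primal
/// iterate is recovered as a scaled softmax of `(-c - A^T lambda) / epsilon`,
/// so its entries are nonnegative and sum to a budget derived from the
/// positive entries of `b` (with a floor of 1). The smoothing makes the
/// solution map differentiable at the cost of a bias that grows with `epsilon`.
///
/// Minimal usage sketch (not compiled as a doctest; how the function is
/// re-exported by the crate is an assumption here):
///
/// ```ignore
/// // Maximize x + y (i.e. minimize -x - y) subject to x + y <= 1, x, y >= 0.
/// let c = vec![-1.0, -1.0];
/// let a = vec![vec![1.0, 1.0], vec![-1.0, 0.0], vec![0.0, -1.0]];
/// let b = vec![1.0, 0.0, 0.0];
/// let x = lp_perturbed(&c, &a, &b, 0.1, 1000, 1e-8)?;
/// // By symmetry the regularized solution is close to (0.5, 0.5).
/// assert!((x[0] - x[1]).abs() < 1e-3 && (x[0] + x[1] - 1.0).abs() < 1e-3);
/// ```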
pub fn lp_perturbed(
    c: &[f64],
    a: &[Vec<f64>],
    b: &[f64],
    epsilon: f64,
    max_iter: usize,
    tol: f64,
) -> OptimizeResult<Vec<f64>> {
    let n = c.len();
    let m = b.len();

    if a.len() != m {
        return Err(OptimizeError::InvalidInput(format!(
            "A has {} rows but b has length {}",
            a.len(),
            m
        )));
    }

    if epsilon <= 0.0 {
        return Err(OptimizeError::InvalidInput(
            "epsilon must be positive for entropic regularization".to_string(),
        ));
    }

    // Dual variables for the inequality constraints A x <= b.
    let mut lambda = vec![0.0_f64; m];

    // The softmax-based primal iterate is scaled to this budget.
    let budget: f64 = b.iter().filter(|&&bi| bi > 0.0).sum::<f64>().max(1.0);

    // Recover the primal point from the duals:
    // x = budget * softmax((-c - A^T lambda) / epsilon).
    let primal_from_dual = |lam: &[f64]| -> Vec<f64> {
        let scores: Vec<f64> = (0..n)
            .map(|i| {
                let atl: f64 = (0..m)
                    .map(|k| {
                        let a_ki = if k < a.len() && i < a[k].len() {
                            a[k][i]
                        } else {
                            0.0
                        };
                        lam[k] * a_ki
                    })
                    .sum();
                (-c[i] - atl) / epsilon
            })
            .collect();
        let sm = softmax(&scores);
        sm.iter().map(|&si| si * budget).collect()
    };

    // Step size for the projected dual update, shrunk for large budgets.
    let mut step = 1.0_f64 / (1.0 + budget * budget);

    for iter in 0..max_iter {
        let x = primal_from_dual(&lambda);

        // Constraint values A x at the current primal iterate.
        let ax: Vec<f64> = (0..m)
            .map(|k| {
                (0..n)
                    .map(|i| {
                        let a_ki = if k < a.len() && i < a[k].len() {
                            a[k][i]
                        } else {
                            0.0
                        };
                        a_ki * x[i]
                    })
                    .sum::<f64>()
            })
            .collect();

        // Projected gradient ascent on the dual: lambda <- max(0, lambda + step * (A x - b)),
        // so multipliers grow on violated constraints and shrink on slack ones.
        let lambda_new: Vec<f64> = (0..m)
            .map(|k| (lambda[k] + step * (ax[k] - b[k])).max(0.0))
            .collect();

        let delta: f64 = lambda_new
            .iter()
            .zip(lambda.iter())
            .map(|(a, b)| (a - b).powi(2))
            .sum::<f64>()
            .sqrt();

        lambda = lambda_new;

        if delta < tol {
            break;
        }

        // Mild step-size decay every 50 iterations to damp oscillations.
        if iter % 50 == 49 {
            step *= 0.9;
        }
    }

    let x = primal_from_dual(&lambda);
    Ok(x)
}

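/// Backward pass for `lp_perturbed` with respect to the cost vector `c`.
///
/// Holding the dual variables fixed, the regularized solution is a scaled
/// softmax, so the Jacobian is
/// `dx_i/dc_j = -(1/epsilon) * x_i * (delta_ij - x_j / sum(x))`
/// and the returned vector is `dL/dc = J^T dL/dx`. The dependence of the
/// duals on `c` is ignored, which makes this a first-order approximation.
///
/// Minimal sketch of chaining a loss gradient through the solver (not
/// compiled as a doctest; the downstream loss here is hypothetical):
///
/// ```ignore
/// let x_star = lp_perturbed(&c, &a, &b, 0.1, 500, 1e-8)?;
/// // Suppose the downstream loss is L(x) = sum(x^2), so dL/dx = 2 * x.
/// let dl_dx: Vec<f64> = x_star.iter().map(|xi| 2.0 * xi).collect();
/// let dl_dc = lp_gradient(&c, &a, &b, &x_star, &dl_dx, 0.1)?;
/// assert_eq!(dl_dc.len(), c.len());
/// ```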
pub fn lp_gradient(
    c: &[f64],
    a: &[Vec<f64>],
    b: &[f64],
    x_star: &[f64],
    dl_dx: &[f64],
    epsilon: f64,
) -> OptimizeResult<Vec<f64>> {
    let n = c.len();
    if x_star.len() != n || dl_dx.len() != n {
        return Err(OptimizeError::InvalidInput(
            "Dimension mismatch in lp_gradient".to_string(),
        ));
    }

    // Normalizer for the softmax Jacobian; x* sums to the budget by construction.
    let sum_x: f64 = x_star.iter().sum();
    let norm = if sum_x > 1e-15 { sum_x } else { 1.0 };

    // dL/dc_j = sum_i (dx_i/dc_j) * dL/dx_i with the softmax Jacobian
    // dx_i/dc_j = -(1/epsilon) * x_i * (delta_ij - x_j / norm).
    let mut dl_dc = vec![0.0_f64; n];
    for j in 0..n {
        for i in 0..n {
            let delta_ij = if i == j { 1.0 } else { 0.0 };
            let j_ij = -(1.0 / epsilon) * x_star[i] * (delta_ij - x_star[j] / norm);
            dl_dc[j] += j_ij * dl_dx[i];
        }
    }

    // The dependence of the duals on (A, b) is not differentiated here.
    let _ = (a, b);

    Ok(dl_dc)
}

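/// Differentiable LP layer that caches the problem data from `forward` so
/// gradients can be formed in `backward`.
///
/// Minimal forward/backward sketch (not compiled as a doctest; how the type
/// is re-exported by the crate is an assumption here):
///
/// ```ignore
/// let mut layer = LpLayer::new();
/// let result = layer.forward(
///     vec![-1.0, -1.0],                                       // c
///     vec![vec![1.0, 1.0], vec![-1.0, 0.0], vec![0.0, -1.0]],  // A
///     vec![1.0, 0.0, 0.0],                                     // b
/// )?;
/// // Propagate an upstream gradient dL/dx back to the cost vector.
/// let grad = layer.backward(&vec![1.0; result.x.len()])?;
/// assert_eq!(grad.dl_dc.len(), result.x.len());
/// ```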
#[derive(Debug, Clone)]
pub struct LpLayer {
    config: LpLayerConfig,
    // Problem data and solution cached by `forward` for use in `backward`.
    last_x: Option<Vec<f64>>,
    last_c: Option<Vec<f64>>,
    last_a: Option<Vec<Vec<f64>>>,
    last_b: Option<Vec<f64>>,
}

impl LpLayer {
    /// Create a layer with the default configuration.
    pub fn new() -> Self {
        Self {
            config: LpLayerConfig::default(),
            last_x: None,
            last_c: None,
            last_a: None,
            last_b: None,
        }
    }

    /// Create a layer with a custom configuration.
    pub fn with_config(config: LpLayerConfig) -> Self {
        Self {
            config,
            last_x: None,
            last_c: None,
            last_a: None,
            last_b: None,
        }
    }

    /// Solve the LP with the configured regularization and cache the problem
    /// data for a later `backward` call.
    pub fn forward(
        &mut self,
        c: Vec<f64>,
        a: Vec<Vec<f64>>,
        b: Vec<f64>,
    ) -> OptimizeResult<DiffOptResult> {
        let n = c.len();
        let m = b.len();

        let x = lp_perturbed(
            &c,
            &a,
            &b,
            self.config.epsilon,
            self.config.max_iter,
            self.config.tol,
        )?;

        // Feasibility check: x >= 0 and A x <= b, both up to a small tolerance.
        let feasible = x.iter().all(|&xi| xi >= -1e-6)
            && (0..m).all(|k| {
                let ax_k: f64 = (0..n)
                    .map(|i| {
                        let a_ki = if k < a.len() && i < a[k].len() {
                            a[k][i]
                        } else {
                            0.0
                        };
                        a_ki * x[i]
                    })
                    .sum();
                ax_k <= b[k] + 1e-4
            });

        let status = if feasible {
            DiffOptStatus::Optimal
        } else {
            DiffOptStatus::MaxIterations
        };

        let objective: f64 = c.iter().zip(x.iter()).map(|(&ci, &xi)| ci * xi).sum();

        // Cache the problem data so `backward` can form gradients.
        self.last_x = Some(x.clone());
        self.last_c = Some(c.clone());
        self.last_a = Some(a.clone());
        self.last_b = Some(b.clone());

        Ok(DiffOptResult {
            x,
            lambda: vec![0.0; m],
            nu: vec![],
            objective,
            status,
            iterations: self.config.max_iter,
        })
    }

    /// Backward pass: map an upstream `dL/dx` to gradients with respect to
    /// the cached problem data. Only `dL/dc` is computed; the remaining
    /// fields are returned empty or as zeros of the appropriate shape.
    pub fn backward(&self, dl_dx: &[f64]) -> OptimizeResult<DiffOptGrad> {
        let x_star = self.last_x.as_ref().ok_or_else(|| {
            OptimizeError::ComputationError("LpLayer::backward called before forward".to_string())
        })?;
        let c = self
            .last_c
            .as_ref()
            .ok_or_else(|| OptimizeError::ComputationError("No cached c".to_string()))?;
        let a = self
            .last_a
            .as_ref()
            .ok_or_else(|| OptimizeError::ComputationError("No cached A".to_string()))?;
        let b = self
            .last_b
            .as_ref()
            .ok_or_else(|| OptimizeError::ComputationError("No cached b".to_string()))?;

        let dl_dc = lp_gradient(c, a, b, x_star, dl_dx, self.config.epsilon)?;
        let m = b.len();
        let n = c.len();

        Ok(DiffOptGrad {
            dl_dq: None,
            dl_dc,
            dl_da: None,
            dl_db: vec![],
            dl_dg: Some(vec![vec![0.0_f64; n]; m]),
            dl_dh: vec![0.0_f64; m],
        })
    }
}

impl Default for LpLayer {
    fn default() -> Self {
        Self::new()
    }
}

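/// Exact sensitivity information for an LP at a known optimal basis.
///
/// Given the basis columns `B` of `A`, the basic solution is `x_B = B^{-1} b`
/// and the dual prices are `y = (B^{-1})^T c_B`, so local derivatives with
/// respect to `b` and `c` follow directly from `B^{-1}`.
///
/// Minimal sketch (not compiled as a doctest; the basis is assumed to be
/// known from an external simplex solve):
///
/// ```ignore
/// // One constraint 2*x0 = 4 with basis {x0}: x0 = 2 and dx0/db = 0.5.
/// let a = vec![vec![2.0, 1.0]];
/// let b = vec![4.0];
/// let c = vec![1.0, 3.0];
/// let sens = LpSensitivity::new(&a, &b, &c, vec![0])?;
/// let dl_db = sens.dl_db(&[1.0, 0.0]);
/// assert!((dl_db[0] - 0.5).abs() < 1e-12);
/// ```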
#[derive(Debug, Clone)]
pub struct LpSensitivity {
    /// Indices of the basic variables.
    pub basis: Vec<usize>,
    /// Values of the basic variables, `x_B = B^{-1} b`.
    pub x_basic: Vec<f64>,
    /// Dual prices, `y = (B^{-1})^T c_B`.
    pub dual: Vec<f64>,
    /// Inverse of the basis matrix.
    pub b_inv: Vec<Vec<f64>>,
}

impl LpSensitivity {
    /// Build sensitivity data from the constraint matrix, right-hand side,
    /// costs, and a basis of column indices (one per constraint).
    pub fn new(
        a: &[Vec<f64>],
        b_rhs: &[f64],
        c: &[f64],
        basis: Vec<usize>,
    ) -> OptimizeResult<Self> {
        let p = b_rhs.len();
        let n = c.len();

        if basis.len() != p {
            return Err(OptimizeError::InvalidInput(format!(
                "Basis size {} != number of constraints {}",
                basis.len(),
                p
            )));
        }

        // Basis matrix B: the columns of A selected by `basis`.
        let b_mat: Vec<Vec<f64>> = (0..p)
            .map(|i| {
                basis
                    .iter()
                    .map(|&j| {
                        if i < a.len() && j < a[i].len() {
                            a[i][j]
                        } else {
                            0.0
                        }
                    })
                    .collect()
            })
            .collect();

        let b_inv = invert_matrix(&b_mat)?;

        // Basic solution x_B = B^{-1} b.
        let x_basic: Vec<f64> = (0..p)
            .map(|i| {
                (0..p)
                    .map(|j| b_inv[i][j] * if j < b_rhs.len() { b_rhs[j] } else { 0.0 })
                    .sum()
            })
            .collect();

        // Dual prices y = (B^{-1})^T c_B.
        let c_basic: Vec<f64> = basis
            .iter()
            .map(|&j| if j < c.len() { c[j] } else { 0.0 })
            .collect();
        let dual: Vec<f64> = (0..p)
            .map(|i| {
                (0..p)
                    .map(|j| b_inv[j][i] * if j < c_basic.len() { c_basic[j] } else { 0.0 })
                    .sum()
            })
            .collect();

        let _ = n;
        Ok(Self {
            basis,
            x_basic,
            dual,
            b_inv,
        })
    }

    /// Loss gradient with respect to the cost coefficients of the basic
    /// variables: applies the transpose of the basis inverse to the loss
    /// gradient on the basic variables and negates it. Non-basic entries
    /// stay zero.
    pub fn dl_dc(&self, n: usize, dl_dx: &[f64]) -> Vec<f64> {
        let p = self.basis.len();
        let mut grad = vec![0.0_f64; n];

        for (j, &bj) in self.basis.iter().enumerate() {
            if bj < n {
                let sum: f64 = (0..p)
                    .map(|i| {
                        let dl_xi = if self.basis[i] < dl_dx.len() {
                            dl_dx[self.basis[i]]
                        } else {
                            0.0
                        };
                        self.b_inv[i][j] * dl_xi
                    })
                    .sum();
                grad[bj] = -sum;
            }
        }
        grad
    }

    /// Loss gradient with respect to the right-hand side `b`: since
    /// `x_B = B^{-1} b`, `dL/db_k = sum_i (B^{-1})_{ik} * dL/dx_{B_i}`.
    pub fn dl_db(&self, dl_dx: &[f64]) -> Vec<f64> {
        let p = self.basis.len();

        (0..p)
            .map(|k| {
                (0..p)
                    .map(|i| {
                        let dl_xi = if self.basis[i] < dl_dx.len() {
                            dl_dx[self.basis[i]]
                        } else {
                            0.0
                        };
                        self.b_inv[i][k] * dl_xi
                    })
                    .sum()
            })
            .collect()
    }

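    /// Reduced costs `c_j - y^T a_j` for the non-basic variables; basic
    /// entries are left at zero.
    ///
    /// Small worked sketch (not compiled as a doctest): with `A = [[1, 1]]`,
    /// `b = [1]`, `c = [1, 2]`, and basis `{x0}`, the dual price is `y = [1]`
    /// and the non-basic reduced cost is `c_1 - y^T a_1 = 2 - 1 = 1`.
    ///
    /// ```ignore
    /// let a = vec![vec![1.0, 1.0]];
    /// let sens = LpSensitivity::new(&a, &[1.0], &[1.0, 2.0], vec![0])?;
    /// let rc = sens.reduced_costs(&a, &[1.0, 2.0]);
    /// assert!((rc[1] - 1.0).abs() < 1e-12);
    /// ```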
    pub fn reduced_costs(&self, a: &[Vec<f64>], c: &[f64]) -> Vec<f64> {
        let n = c.len();
        let nonbasic: Vec<usize> = (0..n).filter(|i| !self.basis.contains(i)).collect();
        let p = self.basis.len();

        let mut rc = vec![0.0_f64; n];
        for &j in &nonbasic {
            let a_j: Vec<f64> = (0..p)
                .map(|i| {
                    if i < a.len() && j < a[i].len() {
                        a[i][j]
                    } else {
                        0.0
                    }
                })
                .collect();
            let ytaj: f64 = self
                .dual
                .iter()
                .zip(a_j.iter())
                .map(|(yi, aij)| yi * aij)
                .sum();
            rc[j] = if j < c.len() { c[j] } else { 0.0 } - ytaj;
        }
        rc
    }
}

/// Invert a square matrix by Gauss-Jordan elimination with partial pivoting.
fn invert_matrix(a: &[Vec<f64>]) -> OptimizeResult<Vec<Vec<f64>>> {
    let n = a.len();
    if n == 0 {
        return Ok(vec![]);
    }

    // Build the augmented matrix [A | I].
    let mut aug: Vec<Vec<f64>> = a
        .iter()
        .enumerate()
        .map(|(i, row)| {
            let mut r = row.clone();
            for j in 0..n {
                r.push(if i == j { 1.0 } else { 0.0 });
            }
            r
        })
        .collect();

    for col in 0..n {
        // Partial pivoting: pick the largest remaining entry in this column.
        let mut max_val = aug[col][col].abs();
        let mut max_row = col;
        for row in (col + 1)..n {
            let v = aug[row][col].abs();
            if v > max_val {
                max_val = v;
                max_row = row;
            }
        }
        if max_val < 1e-12 {
            return Err(OptimizeError::ComputationError(
                "Singular matrix in LP basis inversion".to_string(),
            ));
        }
        if max_row != col {
            aug.swap(col, max_row);
        }

        // Normalize the pivot row and eliminate the column from all other rows.
        let pivot = aug[col][col];
        for j in col..2 * n {
            aug[col][j] /= pivot;
        }
        for row in 0..n {
            if row != col {
                let factor = aug[row][col];
                for j in col..2 * n {
                    let aug_col_j = aug[col][j];
                    aug[row][j] -= factor * aug_col_j;
                }
            }
        }
    }

    // The right half of the augmented matrix is now A^{-1}.
    Ok((0..n).map(|i| aug[i][n..2 * n].to_vec()).collect())
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_lp_layer_config_default() {
        let cfg = LpLayerConfig::default();
        assert!(cfg.epsilon > 0.0);
        assert!(cfg.max_iter > 0);
        assert!(cfg.tol > 0.0);
    }

    #[test]
    fn test_lp_perturbed_feasibility() {
        // Maximize x + y subject to x + y <= 1, x >= 0, y >= 0.
        let c = vec![-1.0, -1.0];
        let a = vec![
            vec![1.0, 1.0],
            vec![-1.0, 0.0],
            vec![0.0, -1.0],
        ];
        let b = vec![1.0, 0.0, 0.0];

        let x = lp_perturbed(&c, &a, &b, 0.1, 1000, 1e-8).expect("LP failed");

        for xi in &x {
            assert!(*xi >= -1e-6, "xi < 0: {}", xi);
        }

        let sum: f64 = x.iter().sum();
        assert!(sum <= 1.0 + 1e-3, "x+y = {} > 1", sum);
    }

    #[test]
    fn test_lp_layer_forward_feasibility() {
        let mut layer = LpLayer::new();
        let c = vec![-1.0, -1.0];
        let a = vec![vec![1.0, 1.0], vec![-1.0, 0.0], vec![0.0, -1.0]];
        let b = vec![1.0, 0.0, 0.0];

        let result = layer.forward(c, a, b).expect("LP forward failed");

        for xi in &result.x {
            assert!(*xi >= -1e-5, "xi < 0: {}", xi);
        }
    }

    #[test]
    fn test_lp_layer_backward_shape() {
        let mut layer = LpLayer::new();
        let c = vec![-1.0, -1.0];
        let a = vec![vec![1.0, 1.0]];
        let b = vec![1.0];

        let result = layer.forward(c, a, b).expect("LP forward failed");
        let _ = result;

        let grad = layer.backward(&[1.0, 1.0]).expect("LP backward failed");
        assert_eq!(grad.dl_dc.len(), 2, "dl/dc should have length 2");
        for gi in &grad.dl_dc {
            assert!(gi.is_finite(), "dl/dc not finite");
        }
    }

    #[test]
    fn test_lp_layer_backward_no_forward_error() {
        let layer = LpLayer::new();
        let result = layer.backward(&[1.0]);
        assert!(result.is_err(), "Should error without forward pass");
    }

    #[test]
    fn test_lp_gradient_direction() {
        // Unconstrained problem: the solution is a softmax of -c / epsilon.
        let c = vec![0.0, 0.0];
        let a: Vec<Vec<f64>> = vec![];
        let b: Vec<f64> = vec![];
        let epsilon = 0.1;
        let n_iter = 100;
        let tol = 1e-8;

        let x_base = lp_perturbed(&c, &a, &b, epsilon, n_iter, tol).expect("LP failed");

        let c_plus = vec![0.1, 0.0];
        let x_plus = lp_perturbed(&c_plus, &a, &b, epsilon, n_iter, tol).expect("LP failed");

        assert!(
            x_plus[0] <= x_base[0] + 1e-3,
            "x[0] should not increase when c[0] increases: base={}, plus={}",
            x_base[0],
            x_plus[0]
        );
    }

    #[test]
    fn test_invert_matrix_identity() {
        let a = vec![vec![1.0, 0.0], vec![0.0, 1.0]];
        let a_inv = invert_matrix(&a).expect("Inversion failed");
        assert!((a_inv[0][0] - 1.0).abs() < 1e-10);
        assert!((a_inv[1][1] - 1.0).abs() < 1e-10);
        assert!(a_inv[0][1].abs() < 1e-10);
        assert!(a_inv[1][0].abs() < 1e-10);
    }

    #[test]
    fn test_invert_matrix_2x2() {
        // det = 2*3 - 1*1 = 5, so the inverse is [[0.6, -0.2], [-0.2, 0.4]].
        let a = vec![vec![2.0, 1.0], vec![1.0, 3.0]];
        let a_inv = invert_matrix(&a).expect("Inversion failed");
        assert!((a_inv[0][0] - 0.6).abs() < 1e-10);
        assert!((a_inv[0][1] - (-0.2)).abs() < 1e-10);
        assert!((a_inv[1][0] - (-0.2)).abs() < 1e-10);
        assert!((a_inv[1][1] - 0.4).abs() < 1e-10);
    }

    #[test]
    fn test_lp_sensitivity_simple() {
        // Single constraint x0 + x1 = 1 with basis {x0}, so x0 = 1.
        let a = vec![vec![1.0, 1.0]];
        let b_rhs = vec![1.0];
        let c = vec![1.0, 2.0];
        let basis = vec![0];
        let sens = LpSensitivity::new(&a, &b_rhs, &c, basis).expect("LpSensitivity failed");

        assert!((sens.x_basic[0] - 1.0).abs() < 1e-10);

        let dl_dc = sens.dl_dc(2, &[1.0, 0.0]);
        assert_eq!(dl_dc.len(), 2);
    }

    #[test]
    fn test_lp_sensitivity_b_inv_correctness() {
        // Basis {x0, x1} gives B = [[2, 0], [0, 3]], so B^{-1} = [[0.5, 0], [0, 1/3]].
        let a = vec![vec![2.0, 0.0, 1.0], vec![0.0, 3.0, 1.0]];
        let b_rhs = vec![4.0, 6.0];
        let c = vec![1.0, 1.0, 0.0];
        let basis = vec![0, 1];
        let sens = LpSensitivity::new(&a, &b_rhs, &c, basis).expect("LpSensitivity failed");

        assert!((sens.b_inv[0][0] - 0.5).abs() < 1e-10);
        assert!((sens.b_inv[1][1] - 1.0 / 3.0).abs() < 1e-10);

        assert!((sens.x_basic[0] - 2.0).abs() < 1e-10);
        assert!((sens.x_basic[1] - 2.0).abs() < 1e-10);
    }

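    // Added check (sketch): dl_db should apply B^{-1} transposed to the loss
    // gradient on the basic variables, i.e. dL/db_k = sum_i (B^{-1})_{ik} dL/dx_{B_i}.
    #[test]
    fn test_lp_sensitivity_dl_db_diagonal_basis() {
        let a = vec![vec![2.0, 0.0, 1.0], vec![0.0, 3.0, 1.0]];
        let b_rhs = vec![4.0, 6.0];
        let c = vec![1.0, 1.0, 0.0];
        let sens = LpSensitivity::new(&a, &b_rhs, &c, vec![0, 1]).expect("LpSensitivity failed");

        // With dL/dx = [1, 0, 0] and the diagonal basis matrix above, only the
        // first component survives: dL/db = [0.5, 0.0].
        let dl_db = sens.dl_db(&[1.0, 0.0, 0.0]);
        assert!((dl_db[0] - 0.5).abs() < 1e-10);
        assert!(dl_db[1].abs() < 1e-10);
    }
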
    #[test]
    fn test_softmax_properties() {
        let v = vec![1.0, 2.0, 3.0];
        let s = softmax(&v);
        let sum: f64 = s.iter().sum();
        assert!((sum - 1.0).abs() < 1e-10);
        for si in &s {
            assert!(*si > 0.0);
        }
        // Softmax is monotone in its inputs.
        assert!(s[2] > s[1] && s[1] > s[0]);
    }

    #[test]
    fn test_logsumexp() {
        let v = vec![0.0, 0.0, 0.0];
        let lse = logsumexp(&v);
        // logsumexp of three zeros is ln(3).
        assert!((lse - 3.0_f64.ln()).abs() < 1e-10);
    }
}