1use crate::error::{OptimizeError, OptimizeResult};
40use scirs2_core::ndarray::{Array1, Array2, ArrayView2};
41use scirs2_linalg::{cholesky, inv, solve, LinalgError};
42
43impl From<LinalgError> for OptimizeError {
46 fn from(e: LinalgError) -> Self {
47 OptimizeError::ComputationError(format!("linalg: {}", e))
48 }
49}
50
51#[inline]
55fn mat_inner(a: &Array2<f64>, b: &Array2<f64>) -> f64 {
56 a.iter().zip(b.iter()).map(|(x, y)| x * y).sum()
57}
58
59fn sym_product(a: &Array2<f64>, b: &Array2<f64>) -> Array2<f64> {
61 let n = a.nrows();
62 let mut out = Array2::<f64>::zeros((n, n));
63 for i in 0..n {
64 for j in 0..n {
65 let mut v = 0.0_f64;
66 for k in 0..n {
67 v += a[[i, k]] * b[[k, j]] + b[[i, k]] * a[[k, j]];
68 }
69 out[[i, j]] = v * 0.5;
70 }
71 }
72 out
73}
74
75fn frobenius_norm(a: &Array2<f64>) -> f64 {
77 a.iter().map(|x| x * x).sum::<f64>().sqrt()
78}
79
80fn mat_inv(a: &Array2<f64>) -> OptimizeResult<Array2<f64>> {
82 inv(&a.view(), None).map_err(OptimizeError::from)
83}
84
85fn cholesky_lower(a: &Array2<f64>) -> OptimizeResult<Array2<f64>> {
87 let n = a.nrows();
88 let mut reg = a.clone();
90 let eps = 1e-14 * frobenius_norm(a).max(1.0);
91 for i in 0..n {
92 reg[[i, i]] += eps;
93 }
94 cholesky(®.view(), None).map_err(OptimizeError::from)
95}
96
97fn spd_inv(x: &Array2<f64>) -> OptimizeResult<Array2<f64>> {
99 mat_inv(x)
100}
101
102fn is_positive_definite(a: &Array2<f64>) -> bool {
104 cholesky_lower(a).is_ok()
105}
106
107fn regularise_pd(a: &mut Array2<f64>) {
109 let n = a.nrows();
110 let norm = frobenius_norm(a);
111 let delta = 1e-8 * norm.max(1.0);
112 for i in 0..n {
113 a[[i, i]] += delta;
114 }
115}
116
117#[derive(Debug, Clone)]
129pub struct SDPProblem {
130 pub c: Array2<f64>,
132 pub a: Vec<Array2<f64>>,
134 pub b: Array1<f64>,
136}
137
138impl SDPProblem {
139 pub fn new(c: Array2<f64>, a: Vec<Array2<f64>>, b: Array1<f64>) -> OptimizeResult<Self> {
141 let n = c.nrows();
142 if c.ncols() != n {
143 return Err(OptimizeError::ValueError(format!(
144 "C must be square, got {}×{}",
145 n,
146 c.ncols()
147 )));
148 }
149 let m = b.len();
150 if a.len() != m {
151 return Err(OptimizeError::ValueError(format!(
152 "Number of constraint matrices ({}) must equal len(b)={}",
153 a.len(),
154 m
155 )));
156 }
157 for (i, ai) in a.iter().enumerate() {
158 if ai.nrows() != n || ai.ncols() != n {
159 return Err(OptimizeError::ValueError(format!(
160 "Constraint matrix A[{}] is {}×{}, expected {}×{}",
161 i,
162 ai.nrows(),
163 ai.ncols(),
164 n,
165 n
166 )));
167 }
168 }
169 Ok(Self { c, a, b })
170 }
171
172 pub fn n(&self) -> usize {
174 self.c.nrows()
175 }
176
177 pub fn m(&self) -> usize {
179 self.b.len()
180 }
181}
182
/// Tunable parameters for [`SDPSolver`].
#[derive(Clone, Debug)]
pub struct SDPSolverConfig {
    /// Upper bound on the number of interior-point iterations.
    pub max_iter: usize,
    /// Threshold that the duality gap and both residual norms must fall below
    /// for the solver to report convergence.
    pub tol: f64,
    /// Initial barrier parameter.
    /// NOTE(review): not read by the solver code visible in this file.
    pub mu_init: f64,
    /// Factor applied to the maximal step length (the result is capped at 1).
    pub step_factor: f64,
}
197
198impl Default for SDPSolverConfig {
199 fn default() -> Self {
200 Self {
201 max_iter: 200,
202 tol: 1e-7,
203 mu_init: 1.0,
204 step_factor: 0.95,
205 }
206 }
207}
208
209#[derive(Debug, Clone)]
213pub struct SDPResult {
214 pub x: Array2<f64>,
216 pub y: Array1<f64>,
218 pub s: Array2<f64>,
220 pub primal_obj: f64,
222 pub dual_obj: f64,
224 pub gap: f64,
226 pub n_iter: usize,
228 pub converged: bool,
230 pub message: String,
232}
233
234#[derive(Debug, Clone)]
238pub struct SDPSolver {
239 config: SDPSolverConfig,
240}
241
242impl SDPSolver {
243 pub fn new() -> Self {
245 Self {
246 config: SDPSolverConfig::default(),
247 }
248 }
249
250 pub fn with_config(config: SDPSolverConfig) -> Self {
252 Self { config }
253 }
254
255 pub fn solve(&self, problem: &SDPProblem) -> OptimizeResult<SDPResult> {
257 let n = problem.n();
258 let m = problem.m();
259
260 let mut x = Array2::<f64>::eye(n);
263 let mut s = Array2::<f64>::eye(n);
264 let mut y = Array1::<f64>::zeros(m);
265
266 let mut n_iter = 0usize;
267 let mut converged = false;
268 let mut message = String::from("maximum iterations reached");
269
270 for iter in 0..self.config.max_iter {
271 n_iter = iter + 1;
272
273 let rp = primal_residual(problem, &x);
276 let rd = dual_residual(problem, &y, &s);
278 let gap = sdp_duality_gap(&x, &s);
280
281 let rp_norm = rp.iter().map(|v| v * v).sum::<f64>().sqrt();
283 let rd_norm = frobenius_norm(&rd);
284 if gap.abs() < self.config.tol && rp_norm < self.config.tol && rd_norm < self.config.tol
285 {
286 converged = true;
287 message = format!(
288 "Converged in {} iterations (gap={:.2e}, rp={:.2e}, rd={:.2e})",
289 n_iter, gap, rp_norm, rd_norm
290 );
291 break;
292 }
293
294 let mu = mat_inner(&x, &s) / n as f64;
296
297 let x_inv = spd_inv(&x)?;
301 let schur = build_schur_complement(problem, &x, &s, &x_inv)?;
302
303 let (dx_aff, dy_aff, ds_aff) =
305 solve_newton_system(problem, &schur, &rp, &rd, &x, &s, &x_inv, 0.0, mu)?;
306
307 let alpha_aff_p = max_step_length_pd(&x, &dx_aff);
309 let alpha_aff_d = max_step_length_pd(&s, &ds_aff);
310 let alpha_aff = (alpha_aff_p.min(alpha_aff_d) * self.config.step_factor).min(1.0);
311
312 let mu_aff = mat_inner(
314 &(&x + &(&dx_aff * alpha_aff)),
315 &(&s + &(&ds_aff * alpha_aff)),
316 ) / n as f64;
317 let sigma = (mu_aff / mu.max(1e-15)).powi(3).min(1.0);
318
319 let (dx, dy, ds) =
321 solve_newton_system(problem, &schur, &rp, &rd, &x, &s, &x_inv, sigma * mu, mu)?;
322
323 let alpha_p = (max_step_length_pd(&x, &dx) * self.config.step_factor).min(1.0);
325 let alpha_d = (max_step_length_pd(&s, &ds) * self.config.step_factor).min(1.0);
326
327 primal_sdp_step(&mut x, &dx, alpha_p);
329 dual_sdp_step(&mut y, &mut s, &dy, &ds, alpha_d);
330
331 if !is_positive_definite(&x) {
333 regularise_pd(&mut x);
334 }
335 if !is_positive_definite(&s) {
336 regularise_pd(&mut s);
337 }
338 }
339
340 let primal_obj = mat_inner(&problem.c, &x);
341 let dual_obj = problem.b.iter().zip(y.iter()).map(|(bi, yi)| bi * yi).sum();
342 let gap = sdp_duality_gap(&x, &s);
343
344 Ok(SDPResult {
345 x,
346 y,
347 s,
348 primal_obj,
349 dual_obj,
350 gap,
351 n_iter,
352 converged,
353 message,
354 })
355 }
356}
357
358impl Default for SDPSolver {
359 fn default() -> Self {
360 Self::new()
361 }
362}
363
364fn primal_residual(problem: &SDPProblem, x: &Array2<f64>) -> Array1<f64> {
368 let m = problem.m();
369 let mut rp = Array1::<f64>::zeros(m);
370 for i in 0..m {
371 rp[i] = problem.b[i] - mat_inner(&problem.a[i], x);
372 }
373 rp
374}
375
376fn dual_residual(problem: &SDPProblem, y: &Array1<f64>, s: &Array2<f64>) -> Array2<f64> {
378 let n = problem.n();
379 let m = problem.m();
380 let mut rd = problem.c.clone();
381 for i in 0..m {
382 rd = rd - &(&problem.a[i] * y[i]);
383 }
384 rd = rd - s;
385 rd
386}
387
388fn build_schur_complement(
395 problem: &SDPProblem,
396 x: &Array2<f64>,
397 _s: &Array2<f64>,
398 _x_inv: &Array2<f64>,
399) -> OptimizeResult<Array2<f64>> {
400 let m = problem.m();
401 let n = problem.n();
402
403 let s_inv = spd_inv(_s)?;
411
412 let mut b_mats: Vec<Array2<f64>> = Vec::with_capacity(m);
414 for i in 0..m {
415 let bi = mat_mul(&problem.a[i], x);
416 b_mats.push(bi);
417 }
418
419 let mut c_mats: Vec<Array2<f64>> = Vec::with_capacity(m);
421 for bi in &b_mats {
422 let ci = mat_mul(bi, &s_inv);
423 c_mats.push(ci);
424 }
425
426 let mut m_mat = Array2::<f64>::zeros((m, m));
429 for i in 0..m {
430 for j in i..m {
431 let mut v = 0.0_f64;
432 for r in 0..n {
433 for c in 0..n {
434 v += c_mats[i][[r, c]] * problem.a[j][[c, r]];
435 }
436 }
437 m_mat[[i, j]] = v;
438 m_mat[[j, i]] = v;
439 }
440 }
441
442 let eps = 1e-12 * frobenius_norm(&m_mat).max(1.0);
444 for i in 0..m {
445 m_mat[[i, i]] += eps;
446 }
447
448 Ok(m_mat)
449}
450
451fn solve_newton_system(
462 problem: &SDPProblem,
463 schur: &Array2<f64>,
464 rp: &Array1<f64>,
465 rd: &Array2<f64>,
466 x: &Array2<f64>,
467 s: &Array2<f64>,
468 _x_inv: &Array2<f64>,
469 sigma_mu: f64,
470 _mu: f64,
471) -> OptimizeResult<(Array2<f64>, Array1<f64>, Array2<f64>)> {
472 let n = problem.n();
473 let m = problem.m();
474
475 let s_inv = spd_inv(s)?;
476
477 let rd_sinv = mat_mul(rd, &s_inv);
484 let t = mat_mul(x, &rd_sinv);
485
486 let mut rhs = Array1::<f64>::zeros(m);
487 for i in 0..m {
488 let mut tr_t = 0.0_f64;
490 let mut tr_sinv = 0.0_f64;
492 for r in 0..n {
493 for c in 0..n {
494 tr_t += problem.a[i][[r, c]] * t[[c, r]];
495 tr_sinv += problem.a[i][[r, c]] * s_inv[[c, r]];
496 }
497 }
498 rhs[i] = problem.b[i] + tr_t - sigma_mu * tr_sinv;
499 }
500
501 let dy = solve(&schur.view(), &rhs.view(), None)?;
503
504 let mut ds = rd.clone();
506 for i in 0..m {
507 ds = ds - &(&problem.a[i] * dy[i]);
508 }
509
510 let xs = mat_mul(x, s);
513 let x_ds = mat_mul(x, &ds);
514 let n_mat = {
515 let mut nm = Array2::<f64>::zeros((n, n));
516 for i in 0..n {
517 for j in 0..n {
518 let diag = if i == j { sigma_mu } else { 0.0 };
519 nm[[i, j]] = diag - xs[[i, j]] - x_ds[[i, j]];
520 }
521 }
522 nm
523 };
524 let tmp = mat_mul(&n_mat, &s_inv);
525 let dx = {
527 let mut d = Array2::<f64>::zeros((n, n));
528 for i in 0..n {
529 for j in 0..n {
530 d[[i, j]] = (tmp[[i, j]] + tmp[[j, i]]) * 0.5;
531 }
532 }
533 d
534 };
535
536 Ok((dx, dy, ds))
537}
538
539fn mat_mul(a: &Array2<f64>, b: &Array2<f64>) -> Array2<f64> {
541 let (m, k) = (a.nrows(), a.ncols());
542 let l = b.ncols();
543 let mut c = Array2::<f64>::zeros((m, l));
544 for i in 0..m {
545 for j in 0..l {
546 let mut v = 0.0_f64;
547 for p in 0..k {
548 v += a[[i, p]] * b[[p, j]];
549 }
550 c[[i, j]] = v;
551 }
552 }
553 c
554}
555
556pub fn primal_sdp_step(x: &mut Array2<f64>, dx: &Array2<f64>, alpha: f64) {
565 let n = x.nrows();
566 for i in 0..n {
567 for j in 0..n {
568 x[[i, j]] += alpha * dx[[i, j]];
569 }
570 }
571}
572
573pub fn dual_sdp_step(
582 y: &mut Array1<f64>,
583 s: &mut Array2<f64>,
584 dy: &Array1<f64>,
585 ds: &Array2<f64>,
586 alpha: f64,
587) {
588 let m = y.len();
589 for i in 0..m {
590 y[i] += alpha * dy[i];
591 }
592 let n = s.nrows();
593 for i in 0..n {
594 for j in 0..n {
595 s[[i, j]] += alpha * ds[[i, j]];
596 }
597 }
598}
599
600pub fn sdp_duality_gap(x: &Array2<f64>, s: &Array2<f64>) -> f64 {
604 mat_inner(x, s)
605}
606
607fn max_step_length_pd(m: &Array2<f64>, dm: &Array2<f64>) -> f64 {
612 if dm.iter().all(|&v| v.abs() < 1e-15) {
614 return 1.0;
615 }
616
617 let mut lo = 0.0_f64;
620 let mut hi = 1.0_f64;
621
622 let full = m + &(dm * 1.0_f64);
624 if is_positive_definite(&full) {
625 return 1.0;
626 }
627
628 for _ in 0..30 {
630 let mid = (lo + hi) * 0.5;
631 let trial = m + &(dm * mid);
632 if is_positive_definite(&trial) {
633 lo = mid;
634 } else {
635 hi = mid;
636 }
637 }
638 lo
639}
640
641#[derive(Debug, Clone)]
645pub struct MaxCutSdpResult {
646 pub sdp_matrix: Array2<f64>,
648 pub sdp_value: f64,
650 pub cut: Vec<i8>,
652 pub cut_value: f64,
654 pub converged: bool,
656}
657
658pub fn max_cut_sdp(w: &ArrayView2<f64>) -> OptimizeResult<MaxCutSdpResult> {
678 let n = w.nrows();
679 if w.ncols() != n {
680 return Err(OptimizeError::ValueError(
681 "Weight matrix must be square".into(),
682 ));
683 }
684
685 let c = w.map(|&v| v * 0.25);
689 let b = Array1::<f64>::ones(n);
690
691 let mut a_mats: Vec<Array2<f64>> = Vec::with_capacity(n);
693 for k in 0..n {
694 let mut ak = Array2::<f64>::zeros((n, n));
695 ak[[k, k]] = 1.0;
696 a_mats.push(ak);
697 }
698
699 let problem = SDPProblem::new(c, a_mats, b)?;
700 let solver = SDPSolver::new();
701 let result = solver.solve(&problem)?;
702
703 let x_mat = &result.x;
705 let v = power_iteration(x_mat, 50);
707 let cut: Vec<i8> = v.iter().map(|&vi| if vi >= 0.0 { 1 } else { -1 }).collect();
708
709 let mut cut_value = 0.0_f64;
711 for i in 0..n {
712 for j in (i + 1)..n {
713 if cut[i] != cut[j] {
714 cut_value += w[[i, j]];
715 }
716 }
717 }
718
719 let sdp_value = result.primal_obj;
720
721 Ok(MaxCutSdpResult {
722 sdp_matrix: result.x,
723 sdp_value,
724 cut,
725 cut_value,
726 converged: result.converged,
727 })
728}
729
730fn power_iteration(a: &Array2<f64>, iters: usize) -> Array1<f64> {
737 let n = a.nrows();
738
739 let mut v = Array1::<f64>::zeros(n);
742 for (i, vi) in v.iter_mut().enumerate() {
743 *vi = 1.0 + 0.1 * i as f64;
744 }
745 let v_norm = v.iter().map(|x| x * x).sum::<f64>().sqrt().max(1e-15);
746 for vi in v.iter_mut() {
747 *vi /= v_norm;
748 }
749
750 for restart in 0..3 {
751 let mut cur = v.clone();
752
753 for _ in 0..iters {
754 let mut w = Array1::<f64>::zeros(n);
755 for i in 0..n {
756 for j in 0..n {
757 w[i] += a[[i, j]] * cur[j];
758 }
759 }
760 let w_norm = w.iter().map(|x| x * x).sum::<f64>().sqrt();
761 if w_norm < 1e-14 {
762 break;
764 }
765 for wi in w.iter_mut() {
766 *wi /= w_norm;
767 }
768 cur = w;
769 }
770
771 let cur_norm = cur.iter().map(|x| x * x).sum::<f64>().sqrt();
773 if cur_norm > 0.5 {
774 return cur;
775 }
776
777 for (i, vi) in v.iter_mut().enumerate() {
779 *vi = 1.0 + (restart as f64 + 1.0) * 0.37 + 0.13 * i as f64;
780 }
781 let vn = v.iter().map(|x| x * x).sum::<f64>().sqrt().max(1e-15);
782 for vi in v.iter_mut() {
783 *vi /= vn;
784 }
785 }
786
787 v
789}
790
791#[derive(Debug, Clone)]
795pub struct MatrixCompletionSdpResult {
796 pub completed: Array2<f64>,
798 pub sdp_value: f64,
800 pub converged: bool,
802}
803
804pub fn matrix_completion_sdp(
823 p: usize,
824 q: usize,
825 observed: &[(usize, usize, f64)],
826) -> OptimizeResult<MatrixCompletionSdpResult> {
827 let nn = p + q;
832
833 let c = Array2::<f64>::eye(nn) * 0.5;
836
837 let mut a_mats: Vec<Array2<f64>> = Vec::new();
843 let mut b_vals: Vec<f64> = Vec::new();
844
845 for &(row, col, val) in observed {
846 if row >= p || col >= q {
847 return Err(OptimizeError::ValueError(format!(
848 "Observed entry ({}, {}) out of range ({}, {})",
849 row, col, p, q
850 )));
851 }
852 let col_lifted = p + col;
853 let mut ak = Array2::<f64>::zeros((nn, nn));
854 ak[[row, col_lifted]] = 0.5;
855 ak[[col_lifted, row]] = 0.5;
856 a_mats.push(ak);
857 b_vals.push(val);
858 }
859
860 if a_mats.is_empty() {
862 return Ok(MatrixCompletionSdpResult {
863 completed: Array2::<f64>::zeros((p, q)),
864 sdp_value: 0.0,
865 converged: true,
866 });
867 }
868
869 let b = Array1::from_vec(b_vals);
870 let problem = SDPProblem::new(c, a_mats, b)?;
871
872 let mut config = SDPSolverConfig::default();
873 config.tol = 1e-5; let solver = SDPSolver::with_config(config);
875 let result = solver.solve(&problem)?;
876
877 let mut completed = Array2::<f64>::zeros((p, q));
879 for i in 0..p {
880 for j in 0..q {
881 completed[[i, j]] = result.x[[i, p + j]];
882 }
883 }
884
885 Ok(MatrixCompletionSdpResult {
886 completed,
887 sdp_value: result.primal_obj,
888 converged: result.converged,
889 })
890}
891
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;

    /// The gap against a zero slack matrix is zero.
    #[test]
    fn test_sdp_duality_gap_zero() {
        let identity = Array2::<f64>::eye(3);
        let zero = Array2::<f64>::zeros((3, 3));
        let gap = sdp_duality_gap(&identity, &zero);
        assert_abs_diff_eq!(gap, 0.0, epsilon = 1e-12);
    }

    /// ⟨I₂, I₂⟩ equals the trace of the 2×2 identity.
    #[test]
    fn test_sdp_duality_gap_positive() {
        let lhs = Array2::<f64>::eye(2);
        let rhs = Array2::<f64>::eye(2);
        let gap = sdp_duality_gap(&lhs, &rhs);
        assert_abs_diff_eq!(gap, 2.0, epsilon = 1e-12);
    }

    /// A step of 0.2 along 0.5·I moves the diagonal from 1.0 to 1.1.
    #[test]
    fn test_primal_sdp_step() {
        let mut iterate = Array2::<f64>::eye(2);
        let direction = Array2::<f64>::eye(2) * 0.5;
        primal_sdp_step(&mut iterate, &direction, 0.2);
        assert_abs_diff_eq!(iterate[[0, 0]], 1.1, epsilon = 1e-12);
    }

    /// Both the multipliers and the slack are moved by the same step length.
    #[test]
    fn test_dual_sdp_step() {
        let mut multipliers = Array1::<f64>::zeros(2);
        let mut slack = Array2::<f64>::eye(2);
        let d_multipliers = Array1::from_vec(vec![1.0, -1.0]);
        let d_slack = Array2::<f64>::eye(2) * (-0.5_f64);

        dual_sdp_step(&mut multipliers, &mut slack, &d_multipliers, &d_slack, 0.5);

        assert_abs_diff_eq!(multipliers[0], 0.5, epsilon = 1e-12);
        assert_abs_diff_eq!(multipliers[1], -0.5, epsilon = 1e-12);
        assert_abs_diff_eq!(slack[[0, 0]], 0.75, epsilon = 1e-12);
    }

    /// Scalar problem: minimise x subject to x = 1, so the objective is 1.
    #[test]
    fn test_sdp_simple_1d() {
        let objective = Array2::<f64>::eye(1);
        let mut constraint = Array2::<f64>::zeros((1, 1));
        constraint[[0, 0]] = 1.0;
        let rhs = Array1::from_vec(vec![1.0]);

        let problem = SDPProblem::new(objective, vec![constraint], rhs).expect("valid problem");
        let result = SDPSolver::new()
            .solve(&problem)
            .expect("solver should not fail");

        assert_abs_diff_eq!(result.primal_obj, 1.0, epsilon = 1e-4);
    }

    /// Unit-weight triangle: the rounded cut must separate at least one edge.
    #[test]
    fn test_max_cut_sdp_triangle() {
        let mut weights = Array2::<f64>::zeros((3, 3));
        for &(i, j) in &[(0usize, 1usize), (0, 2), (1, 2)] {
            weights[[i, j]] = 1.0;
            weights[[j, i]] = 1.0;
        }

        let result = max_cut_sdp(&weights.view()).expect("max_cut_sdp should not fail");
        assert!(result.cut_value >= 1.0, "Cut value should be at least 1");
    }

    /// Two observed diagonal entries: the completion must stay finite there.
    #[test]
    fn test_matrix_completion_simple() {
        let observed = vec![(0, 0, 1.0), (1, 1, 1.0)];
        let result =
            matrix_completion_sdp(2, 2, &observed).expect("matrix_completion_sdp should not fail");
        for k in 0..2 {
            assert!(result.completed[[k, k]].is_finite());
        }
    }
}