1use num_traits::{Float, Zero, One};
4use core::marker::PhantomData;
5use core::fmt::{Debug, LowerExp};
6use crate::solver::{SliceLike, LinAlg, Operator, Cone, SolverError};
7use crate::{splitm, splitm_mut};
8
9#[derive(Debug, Clone, PartialEq)]
13pub struct SolverParam<F: Float>
14{
15 pub max_iter: Option<usize>,
17 pub eps_acc: F,
19 pub eps_inf: F,
21 pub eps_zero: F,
23 pub log_period: usize,
25}
26
27impl<F: Float> Default for SolverParam<F>
28{
29 fn default() -> Self
30 {
31 let ten = F::from(10).unwrap();
32
33 SolverParam {
34 max_iter: None,
35 eps_acc: ten.powi(-6),
36 eps_inf: ten.powi(-6),
37 eps_zero: ten.powi(-12),
38 log_period: 10_000,
39 }
40 }
41}
42
43struct SelfDualEmbed<L, OC, OA, OB>
46where L: LinAlg, OC: Operator<L>, OA: Operator<L>, OB: Operator<L>
47{
48 ph_l: PhantomData<L>,
49 c: OC,
50 a: OA,
51 b: OB,
52}
53
54impl<L, OC, OA, OB> SelfDualEmbed<L, OC, OA, OB>
55where L: LinAlg, OC: Operator<L>, OA: Operator<L>, OB: Operator<L>
56{
57 fn c(&self) -> &OC
58 {
59 &self.c
60 }
61
62 fn a(&self) -> &OA
63 {
64 &self.a
65 }
66
67 fn b(&self) -> &OB
68 {
69 &self.b
70 }
71
72 fn norm_b(&self,
73 work_v: &mut L::Sl, work_t: &mut L::Sl) -> L::F
74 {
75 Self::fr_norm(self.b(), work_v, work_t)
76 }
77
78 fn norm_c(&self,
79 work_v: &mut L::Sl, work_t: &mut L::Sl) -> L::F
80 {
81 Self::fr_norm(self.c(), work_v, work_t)
82 }
83
84 fn fr_norm<O: Operator<L>>(
86 op: &O,
87 work_v: &mut L::Sl, work_t: &mut L::Sl) -> L::F
88 {
89 assert_eq!(work_v.len(), op.size().1);
90 assert_eq!(work_t.len(), op.size().0);
91
92 let f0 = L::F::zero();
93 let f1 = L::F::one();
94
95 L::scale(f0, work_v);
96 let mut sq_norm = f0;
97
98 for row in 0.. op.size().1 {
99 work_v.set(row, f1);
100 op.op(f1, work_v, f0, work_t);
101 let n = L::norm(work_t);
102 sq_norm = sq_norm + n * n;
103 work_v.set(row, f0);
104 }
105
106 sq_norm.sqrt()
107 }
108
109 fn op(&self, alpha: L::F, x: &L::Sl, beta: L::F, y: &mut L::Sl)
110 {
111 let (m, n) = self.a.size();
112
113 assert_eq!(x.len(), n + m + m + 1);
114 assert_eq!(y.len(), n + m + 1);
115
116 splitm!(x, (x_x; n), (x_y; m), (x_s; m), (x_tau; 1));
117
118 splitm_mut!(y, (y_n; n), (y_m; m), (y_1; 1));
119
120 let f1 = L::F::one();
121
122 self.a.trans_op(alpha, &x_y, beta, &mut y_n);
123 self.c.op(alpha, &x_tau, f1, &mut y_n);
124
125 self.a.op(-alpha, &x_x, beta, &mut y_m);
126 L::add(-alpha, &x_s, &mut y_m);
127 self.b.op(alpha, &x_tau, f1, &mut y_m);
128
129 self.c.trans_op(-alpha, &x_x, beta, &mut y_1);
130 self.b.trans_op(-alpha, &x_y, f1, &mut y_1);
131 }
132
133 fn trans_op(&self, alpha: L::F, x: &L::Sl, beta: L::F, y: &mut L::Sl)
134 {
135 let (m, n) = self.a.size();
136
137 assert_eq!(x.len(), n + m + 1);
138 assert_eq!(y.len(), n + m + m + 1);
139
140 splitm!(x, (x_n; n), (x_m; m), (x_1; 1));
141
142 splitm_mut!(y, (y_x; n), (y_y; m), (y_s; m), (y_tau; 1));
143
144 let f1 = L::F::one();
145
146 self.a.trans_op(-alpha, &x_m, beta, &mut y_x);
147 self.c.op(-alpha, &x_1, f1, &mut y_x);
148
149 self.a.op(alpha, &x_n, beta, &mut y_y);
150 self.b.op(-alpha, &x_1, f1, &mut y_y);
151
152 L::scale(beta, &mut y_s);
153 L::add(-alpha, &x_m, &mut y_s);
154
155 self.c.trans_op(alpha, &x_n, beta, &mut y_tau);
156 self.b.trans_op(alpha, &x_m, f1, &mut y_tau);
157 }
158
159 fn abssum(&self, tau: &mut L::Sl, sigma: &mut L::Sl)
160 {
161 let (m, n) = self.a.size();
162 let f0 = L::F::zero();
163 let f1 = L::F::one();
164
165 L::scale(f0, tau);
166
167 splitm_mut!(tau, (tau_x; n), (tau_y; m), (tau_s; m), (tau_tau; 1));
168
169 self.a.absadd_cols(&mut tau_x);
170 self.c.absadd_rows(&mut tau_x);
171 self.a.absadd_rows(&mut tau_y);
172 self.b.absadd_rows(&mut tau_y);
173 L::adds(f1, &mut tau_s);
174 self.c.absadd_cols(&mut tau_tau);
175 self.b.absadd_cols(&mut tau_tau);
176
177 splitm_mut!(sigma, (sigma_n; n), (sigma_m; m), (sigma_1; 1));
178
179 L::copy(&tau_x, &mut sigma_n);
180 L::copy(&tau_y, &mut sigma_m);
181 L::add(f1, &tau_s, &mut sigma_m);
182 L::copy(&tau_tau, &mut sigma_1);
183 }
184}
185
186pub struct Solver<L: LinAlg>
220{
221 pub par: SolverParam<L::F>,
223}
224
225impl<L: LinAlg> Solver<L>
226{
227 pub fn query_worklen(op_a_size: (usize, usize)) -> usize
232 {
233 let (m, n) = op_a_size;
234
235 let len_iteration =
236 n + m + m + 1 + n + m + 1 + n + m + m + 1 + n + m + 1 + n + m + m + 1 + n + m + m + 1; len_iteration
249 }
250
251 pub fn new() -> Self
255 {
256 Solver {
257 par: SolverParam::default(),
258 }
259 }
260
261 pub fn par<P>(mut self, f: P) -> Self
266 where P: FnOnce(&mut SolverParam<L::F>)
267 {
268 f(&mut self.par);
269 self
270 }
271}
272
273impl<L: LinAlg> Solver<L>
274where L::F: Float + Debug + LowerExp
275{
276 pub fn solve<OC, OA, OB, C>(self,
286 (op_c, op_a, op_b, cone, work): (OC, OA, OB, C, &mut[L::F])
287 ) -> Result<(&[L::F], &[L::F]), SolverError>
288 where OC: Operator<L>, OA: Operator<L>, OB: Operator<L>, C: Cone<L>
289 {
290 let (m, n) = op_a.size();
291
292 if op_c.size() != (n, 1) || op_b.size() != (m, 1) {
293 log::error!("Size mismatch: op_c{:?}, op_a{:?}, op_b{:?}", op_c.size(), op_a.size(), op_b.size());
294 return Err(SolverError::InvalidOp);
295 }
296
297 if Self::query_worklen((m, n)) > work.len() {
298 log::error!("Work memory length {} must be >= {}", work.len(), Self::query_worklen((m, n)));
299 return Err(SolverError::WorkShortage);
300 }
301
302 log::debug!("{:?}", self.par);
303
304 let op_k = SelfDualEmbed {
305 ph_l: PhantomData::<L>,
306 c: op_c, a: op_a, b: op_b
307 };
308
309 let core = SolverCore {
310 par: self.par,
311 op_k,
312 cone,
313 };
314
315 let rslt = core.solve(&mut L::Sl::new_mut(work));
316
317 let (sol_x, rest) = work.split_at(n);
318 let (sol_y, _) = rest.split_at(m);
319
320 rslt.map(|_| {(sol_x, sol_y)})
321 }
322}
323
324struct SolverCore<L, OC, OA, OB, C>
327where L: LinAlg, L::F: Float + Debug + LowerExp,
328 OC: Operator<L>, OA: Operator<L>, OB: Operator<L>, C: Cone<L>
329{
330 par: SolverParam<L::F>,
331
332 op_k: SelfDualEmbed<L, OC, OA, OB>,
333 cone: C,
334}
335
336impl<L, OC, OA, OB, C> SolverCore<L, OC, OA, OB, C>
337where L: LinAlg, L::F: Float + Debug + LowerExp,
338 OC: Operator<L>, OA: Operator<L>, OB: Operator<L>, C: Cone<L>
339{
340 fn solve(mut self, work: &mut L::Sl) -> Result<(), SolverError>
341 {
342 log::info!("----- Initializing");
343 let (m, n) = self.op_k.a().size();
344
345 let (norm_b, norm_c) = self.calc_norms(work);
347
348 splitm_mut!(work,
350 (x; n + m + m + 1),
351 (y; n + m + 1),
352 (dp_tau; n + m + m + 1),
353 (dp_sigma; n + m + 1),
354 (tmpw; (n + m + m + 1) * 2)
355 );
356 self.init_vecs(&mut x, &mut y);
357
358 self.calc_precond(&mut dp_tau, &mut dp_sigma);
360
361 log::info!("----- Started");
363 let mut i = 0;
364 loop {
365 let excess_iter = if let Some(max_iter) = self.par.max_iter {
366 i + 1 >= max_iter
367 } else {
368 false
369 };
370
371 let log_trig = if self.par.log_period > 0 {
372 i % self.par.log_period == 0
373 }
374 else {
375 if i == 0 && log::log_enabled!(log::Level::Debug) {
376 log::warn!("log_period == 0: no periodic log");
377 }
378 false
379 };
380
381 let val_tau = self.update_vecs(&mut x, &mut y, &dp_tau, &dp_sigma, &mut tmpw)?;
383
384 if val_tau > self.par.eps_zero {
385 let (cri_pri, cri_dual, cri_gap) = self.criteria_conv(&x, norm_c, norm_b, &mut tmpw);
387
388 let term_conv = (cri_pri <= self.par.eps_acc) && (cri_dual <= self.par.eps_acc) && (cri_gap <= self.par.eps_acc);
389
390 if log_trig || excess_iter || term_conv {
391 log::debug!("{}: pri_dual_gap {:.2e} {:.2e} {:.2e}", i, cri_pri, cri_dual, cri_gap);
392 }
393 else {
394 log::trace!("{}: pri_dual_gap {:.2e} {:.2e} {:.2e}", i, cri_pri, cri_dual, cri_gap);
395 }
396
397 if excess_iter || term_conv {
398 splitm_mut!(x, (x_x_ast; n), (x_y_ast; m));
399 L::scale(val_tau.recip(), &mut x_x_ast);
400 L::scale(val_tau.recip(), &mut x_y_ast);
401
402 log::trace!("{}: x {:?}", i, x_x_ast.get_ref());
403 log::trace!("{}: y {:?}", i, x_y_ast.get_ref());
404
405 if term_conv {
406 log::info!("----- Converged");
407
408 return Ok(());
409 }
410 else {
411 log::warn!("----- ExcessIter");
412
413 return Err(SolverError::ExcessIter);
414 }
415 }
416 }
417 else {
418 let (cri_unbdd, cri_infeas) = self.criteria_inf(&x, norm_c, norm_b, &mut tmpw);
420
421 let term_unbdd = cri_unbdd <= self.par.eps_inf;
422 let term_infeas = cri_infeas <= self.par.eps_inf;
423
424 if log_trig || excess_iter || term_unbdd || term_infeas {
425 log::debug!("{}: unbdd_infeas {:.2e} {:.2e}", i, cri_unbdd, cri_infeas);
426 }
427 else {
428 log::trace!("{}: unbdd_infeas {:.2e} {:.2e}", i, cri_unbdd, cri_infeas);
429 }
430
431 if excess_iter || term_unbdd || term_infeas {
432 splitm!(x, (x_x_cert; n), (x_y_cert; m));
433
434 log::trace!("{}: x {:?}", i, x_x_cert.get_ref());
435 log::trace!("{}: y {:?}", i, x_y_cert.get_ref());
436
437 if term_unbdd {
438 log::warn!("----- Unbounded");
439
440 return Err(SolverError::Unbounded);
441 }
442 else if term_infeas {
443 log::warn!("----- Infeasible");
444
445 return Err(SolverError::Infeasible);
446 }
447 else {
448 log::warn!("----- ExcessIter");
449
450 return Err(SolverError::ExcessIter);
451 }
452 }
453 }
454
455 i += 1;
456 assert!(!excess_iter);
457 } }
459
460 fn calc_norms(&mut self, work: &mut L::Sl)
461 -> (L::F, L::F)
462 {
463 let mut work1 = [L::F::zero()];
464 let mut work_one = L::Sl::new_mut(&mut work1);
465
466 let norm_b = {
467 let (m, _) = self.op_k.b().size();
468 splitm_mut!(work, (t; m));
469
470 self.op_k.norm_b(&mut work_one, &mut t)
471 };
472
473 let norm_c = {
474 let (n, _) = self.op_k.c().size();
475 splitm_mut!(work, (t; n));
476
477 self.op_k.norm_c(&mut work_one, &mut t)
478 };
479
480 (norm_b, norm_c)
481 }
482
483 fn init_vecs(&self, x: &mut L::Sl, y: &mut L::Sl)
484 {
485 let (m, n) = self.op_k.a().size();
486
487 let f0 = L::F::zero();
488 let f1 = L::F::one();
489
490 L::scale(f0, x);
491 L::scale(f0, y);
492
493 x.set(n + m + m, f1); }
495
496 fn calc_precond(&self, dp_tau: &mut L::Sl, dp_sigma: &mut L::Sl)
497 {
498 let (m, n) = self.op_k.a().size();
499
500 self.op_k.abssum(dp_tau, dp_sigma);
501 for tau in dp_tau.get_mut() {
502 *tau = (*tau).max(self.par.eps_zero).recip();
503 }
504 for sigma in dp_sigma.get_mut() {
505 *sigma = (*sigma).max(self.par.eps_zero).recip();
506 }
507
508 let group = |tau_group: &mut L::Sl| {
510 if tau_group.len() > 0 {
511 let tau_group_mut = tau_group.get_mut();
512 let mut min_t = tau_group_mut[0];
513 for t in tau_group_mut.iter() {
514 min_t = min_t.min(*t);
515 }
516 for t in tau_group_mut.iter_mut() {
517 *t = min_t;
518 }
519 }
520 };
521 splitm_mut!(dp_tau, (_dpt_n; n), (dpt_dual_cone; m), (dpt_cone; m), (_dpt_1; 1));
522 self.cone.product_group(&mut dpt_dual_cone, group);
523 self.cone.product_group(&mut dpt_cone, group);
524 }
525
526 fn update_vecs(&mut self, x: &mut L::Sl, y: &mut L::Sl, dp_tau: &L::Sl, dp_sigma: &L::Sl, tmpw: &mut L::Sl)
527 -> Result<L::F, SolverError>
528 {
529 let (m, n) = self.op_k.a().size();
530
531 splitm_mut!(tmpw, (rx; x.len()), (tx; x.len()));
532
533 let val_tau;
534
535 let f0 = L::F::zero();
536 let f1 = L::F::one();
537
538 L::copy(x, &mut rx); { self.op_k.trans_op(-f1, y, f0, &mut tx);
542 L::transform_di(f1, dp_tau, &tx, f1, x);
543 }
544
545 { splitm_mut!(x, (_x_x; n), (x_y; m), (x_s; m), (x_tau; 1));
547
548 self.cone.proj(true, &mut x_y).or(Err(SolverError::ConeFailure))?;
549 self.cone.proj(false, &mut x_s).or(Err(SolverError::ConeFailure))?;
550
551 val_tau = x_tau.get(0).max(f0);
552 x_tau.set(0, val_tau);
553 }
554
555 L::add(-f1-f1, x, &mut rx); { splitm_mut!(tx, (ty; y.len()));
559 self.op_k.op(-f1, &rx, f0, &mut ty);
560 L::transform_di(f1, dp_sigma, &ty, f1, y);
561 }
562
563 { splitm_mut!(y, (_y_nm; n + m), (y_1; 1));
565
566 let kappa = y_1.get(0).min(f0);
567 y_1.set(0, kappa);
568 }
569
570 Ok(val_tau)
571 }
572
573 fn criteria_conv(&self, x: &L::Sl, norm_c: L::F, norm_b: L::F, tmpw: &mut L::Sl)
574 -> (L::F, L::F, L::F)
575 {
576 let (m, n) = self.op_k.a().size();
577
578 splitm!(x, (x_x; n), (x_y; m), (x_s; m), (x_tau; 1));
579 splitm_mut!(tmpw, (p; m), (d; n));
580
581 let f0 = L::F::zero();
582 let f1 = L::F::one();
583
584 let val_tau = x_tau.get(0);
585 assert!(val_tau > f0);
586
587 let mut work1 = [f1];
588 let mut work_one = L::Sl::new_mut(&mut work1);
589
590 L::copy(&x_s, &mut p);
593 self.op_k.b().op(-f1, &work_one, val_tau.recip(), &mut p);
594 self.op_k.a().op(val_tau.recip(), &x_x, f1, &mut p);
595
596 self.op_k.c().op(f1, &work_one, f0, &mut d);
597 self.op_k.a().trans_op(val_tau.recip(), &x_y, f1, &mut d);
598
599 self.op_k.c().trans_op(val_tau.recip(), &x_x, f0, &mut work_one);
600 let g_x = work_one.get(0);
601
602 self.op_k.b().trans_op(val_tau.recip(), &x_y, f0, &mut work_one);
603 let g_y = work_one.get(0);
604
605 let g = g_x + g_y;
606
607 let cri_pri = L::norm(&p) / (f1 + norm_b);
608 let cri_dual = L::norm(&d) / (f1 + norm_c);
609 let cri_gap = g.abs() / (f1 + g_x.abs() + g_y.abs());
610
611 (cri_pri, cri_dual, cri_gap)
612 }
613
614 fn criteria_inf(&self, x: &L::Sl, norm_c: L::F, norm_b: L::F, tmpw: &mut L::Sl)
615 -> (L::F, L::F)
616 {
617 let (m, n) = self.op_k.a().size();
618
619 splitm!(x, (x_x; n), (x_y; m), (x_s; m), (_x_tau; 1));
620 splitm_mut!(tmpw, (p; m), (d; n));
621
622 let f0 = L::F::zero();
623 let f1 = L::F::one();
624 let finf = L::F::infinity();
625
626 let mut work1 = [f0];
627 let mut work_one = L::Sl::new_mut(&mut work1);
628
629 L::copy(&x_s, &mut p);
632 self.op_k.a().op(f1, &x_x, f1, &mut p);
633
634 self.op_k.a().trans_op(f1, &x_y, f0, &mut d);
635
636 self.op_k.c().trans_op(-f1, &x_x, f0, &mut work_one);
637 let m_cx = work_one.get(0);
638
639 self.op_k.b().trans_op(-f1, &x_y, f0, &mut work_one);
640 let m_by = work_one.get(0);
641
642 let cri_unbdd = if m_cx > self.par.eps_zero {
643 L::norm(&p) * norm_c / m_cx
644 }
645 else {
646 finf
647 };
648 let cri_infeas = if m_by > self.par.eps_zero {
649 L::norm(&d) * norm_b / m_by
650 }
651 else {
652 finf
653 };
654
655 (cri_unbdd, cri_infeas)
656 }
657}