scirs2_optimize/distributed_admm/
pdmm_extra.rs1use super::types::{AdmmResult, ExtraConfig, PdmmConfig};
36use crate::error::{OptimizeError, OptimizeResult};
37
/// Euclidean (L2) norm of a vector; returns 0.0 for an empty slice.
fn norm2(v: &[f64]) -> f64 {
    v.iter().fold(0.0_f64, |acc, &x| acc + x * x).sqrt()
}
45
/// Element-wise sum of two vectors, truncated to the shorter length.
fn vec_add(a: &[f64], b: &[f64]) -> Vec<f64> {
    let mut out = Vec::with_capacity(a.len().min(b.len()));
    for (&ai, &bi) in a.iter().zip(b.iter()) {
        out.push(ai + bi);
    }
    out
}
49
/// Element-wise difference `a - b`, truncated to the shorter length.
fn vec_sub(a: &[f64], b: &[f64]) -> Vec<f64> {
    let mut out = Vec::with_capacity(a.len().min(b.len()));
    for (&ai, &bi) in a.iter().zip(b.iter()) {
        out.push(ai - bi);
    }
    out
}
53
/// Multiplies every element of `a` by the scalar `s`.
fn vec_scale(a: &[f64], s: f64) -> Vec<f64> {
    let mut out = Vec::with_capacity(a.len());
    for &ai in a {
        out.push(ai * s);
    }
    out
}
57
/// Dense matrix-vector product: one dot product per row of `w`.
fn mat_vec(w: &[Vec<f64>], x: &[f64]) -> Vec<f64> {
    let mut out = Vec::with_capacity(w.len());
    for row in w {
        let dot: f64 = row.iter().zip(x.iter()).map(|(wi, xi)| wi * xi).sum();
        out.push(dot);
    }
    out
}
64
/// Returns `true` when `w` is square and every row and column sums to 1
/// within `tol` (i.e. the matrix is doubly stochastic).
fn check_doubly_stochastic(w: &[Vec<f64>], tol: f64) -> bool {
    let n = w.len();
    // Rows first: each must have length n and sum to 1 within tol. The
    // short-circuit below guarantees columns are only indexed on a
    // square matrix.
    let rows_ok = w
        .iter()
        .all(|row| row.len() == n && (row.iter().sum::<f64>() - 1.0).abs() <= tol);
    rows_ok
        && (0..n).all(|j| {
            let col_sum: f64 = w.iter().map(|row| row[j]).sum();
            (col_sum - 1.0).abs() <= tol
        })
}
87
/// Decentralised consensus solver driven by per-agent local updates over a
/// fixed communication graph (PDMM-style, judging by the module name —
/// see `solve` for the exact update scheme).
#[derive(Debug)]
pub struct PdmmSolver {
    /// Square connectivity matrix: `topology[i][j] > 0.0` marks an edge
    /// between agents `i` and `j`. Squareness is enforced by `new`.
    pub topology: Vec<Vec<f64>>,
}
101
102impl PdmmSolver {
103 pub fn new(topology: Vec<Vec<f64>>) -> OptimizeResult<Self> {
105 let n = topology.len();
106 for (i, row) in topology.iter().enumerate() {
107 if row.len() != n {
108 return Err(OptimizeError::InvalidInput(format!(
109 "Topology row {} has length {} but expected {}",
110 i,
111 row.len(),
112 n
113 )));
114 }
115 }
116 Ok(Self { topology })
117 }
118
119 pub fn solve<F>(
124 &self,
125 local_fns: &[F],
126 n_vars: usize,
127 config: &PdmmConfig,
128 ) -> OptimizeResult<AdmmResult>
129 where
130 F: Fn(&[f64], f64) -> Vec<f64>,
131 {
132 let n_agents = self.topology.len();
133 if local_fns.len() != n_agents {
134 return Err(OptimizeError::InvalidInput(format!(
135 "Expected {} local functions but got {}",
136 n_agents,
137 local_fns.len()
138 )));
139 }
140 if n_vars == 0 {
141 return Err(OptimizeError::InvalidInput("n_vars must be > 0".into()));
142 }
143
144 let rho = config.stepsize;
145
146 let mut x: Vec<Vec<f64>> = (0..n_agents).map(|_| vec![0.0; n_vars]).collect();
148 let mut lam: Vec<Vec<Vec<f64>>> = (0..n_agents)
150 .map(|_| (0..n_agents).map(|_| vec![0.0_f64; n_vars]).collect())
151 .collect();
152
153 let mut primal_history = Vec::with_capacity(config.max_iter);
154 let mut dual_history = Vec::with_capacity(config.max_iter);
155 let mut converged = false;
156 let mut iterations = 0;
157
158 for iter in 0..config.max_iter {
159 iterations = iter + 1;
160 let x_old = x.clone();
161
162 for i in 0..n_agents {
164 let mut neighbours = 0usize;
167 let mut agg = vec![0.0_f64; n_vars];
168 for j in 0..n_agents {
169 if self.topology[i][j] > 0.0 {
170 neighbours += 1;
171 for k in 0..n_vars {
172 agg[k] += lam[i][j][k] - rho * x_old[j][k];
173 }
174 }
175 }
176 let rho_eff = rho * (neighbours.max(1) as f64);
179 let prox_arg: Vec<f64> = agg.iter().map(|a| -a / rho_eff).collect();
180 x[i] = (local_fns[i])(&prox_arg, rho_eff);
181 }
182
183 for i in 0..n_agents {
185 for j in 0..n_agents {
186 if self.topology[i][j] > 0.0 {
187 for k in 0..n_vars {
188 lam[i][j][k] += rho * (x[i][k] - x[j][k]);
189 }
190 }
191 }
192 }
193
194 let mut primal_sq = 0.0_f64;
196 let mut dual_sq = 0.0_f64;
197 for i in 0..n_agents {
198 for j in 0..n_agents {
199 if self.topology[i][j] > 0.0 {
200 for k in 0..n_vars {
201 primal_sq += (x[i][k] - x[j][k]).powi(2);
202 }
203 }
204 }
205 for k in 0..n_vars {
206 dual_sq += (x[i][k] - x_old[i][k]).powi(2);
207 }
208 }
209 let primal_res = primal_sq.sqrt();
210 let dual_res = rho * dual_sq.sqrt();
211
212 primal_history.push(primal_res);
213 dual_history.push(dual_res);
214
215 if primal_res < config.tol {
216 converged = true;
217 break;
218 }
219 }
220
221 let mut x_consensus = vec![0.0_f64; n_vars];
223 let scale = 1.0 / n_agents as f64;
224 for xi in x.iter() {
225 for k in 0..n_vars {
226 x_consensus[k] += scale * xi[k];
227 }
228 }
229
230 Ok(AdmmResult {
231 x: x_consensus,
232 primal_residual: primal_history,
233 dual_residual: dual_history,
234 converged,
235 iterations,
236 })
237 }
238}
239
/// Decentralised gradient solver in the style of EXTRA, driven by a
/// doubly stochastic mixing matrix.
#[derive(Debug)]
pub struct ExtraSolver {
    /// Mixing matrix W; `new` rejects matrices that are not doubly
    /// stochastic (within 1e-6).
    pub w: Vec<Vec<f64>>,
    /// Second mixing matrix, computed by `new` as (I + W) / 2.
    pub w_tilde: Vec<Vec<f64>>,
}
259
260impl ExtraSolver {
261 pub fn new(w: Vec<Vec<f64>>) -> OptimizeResult<Self> {
263 let n = w.len();
264 if !check_doubly_stochastic(&w, 1e-6) {
265 return Err(OptimizeError::InvalidInput(
266 "W must be doubly stochastic".into(),
267 ));
268 }
269 let w_tilde: Vec<Vec<f64>> = (0..n)
271 .map(|i| {
272 (0..n)
273 .map(|j| {
274 let eye = if i == j { 1.0 } else { 0.0 };
275 (eye + w[i][j]) / 2.0
276 })
277 .collect()
278 })
279 .collect();
280 Ok(Self { w, w_tilde })
281 }
282
283 pub fn solve<F>(
287 &self,
288 grad_fns: &[F],
289 n_vars: usize,
290 config: &ExtraConfig,
291 ) -> OptimizeResult<AdmmResult>
292 where
293 F: Fn(&[f64]) -> Vec<f64>,
294 {
295 let n_agents = self.w.len();
296 if grad_fns.len() != n_agents {
297 return Err(OptimizeError::InvalidInput(format!(
298 "Expected {} gradient functions but got {}",
299 n_agents,
300 grad_fns.len()
301 )));
302 }
303 if n_vars == 0 {
304 return Err(OptimizeError::InvalidInput("n_vars must be > 0".into()));
305 }
306
307 let alpha = config.alpha;
308
309 let mut x_curr: Vec<Vec<f64>> = (0..n_agents).map(|_| vec![0.0; n_vars]).collect();
312
313 let grad_curr: Vec<Vec<f64>> = (0..n_agents).map(|i| (grad_fns[i])(&x_curr[i])).collect();
315
316 let x_next: Vec<Vec<f64>> = (0..n_agents)
319 .map(|i| {
320 let wx_i: Vec<f64> = (0..n_vars)
322 .map(|k| {
323 (0..n_agents)
324 .map(|j| self.w[i][j] * x_curr[j][k])
325 .sum::<f64>()
326 })
327 .collect();
328 wx_i.iter()
330 .zip(grad_curr[i].iter())
331 .map(|(w, g)| w - alpha * g)
332 .collect()
333 })
334 .collect();
335
336 let mut x_prev = x_curr.clone();
337 let mut x_curr = x_next;
338 let mut grad_prev = grad_curr;
339
340 let mut primal_history = Vec::with_capacity(config.max_iter);
341 let mut dual_history = Vec::with_capacity(config.max_iter);
342 let mut converged = false;
343 let mut iterations = 1;
344
345 for iter in 1..config.max_iter {
346 iterations = iter + 1;
347
348 let grad_curr: Vec<Vec<f64>> =
349 (0..n_agents).map(|i| (grad_fns[i])(&x_curr[i])).collect();
350
351 let w_tilde_x_curr: Vec<Vec<f64>> = (0..n_agents)
353 .map(|i| {
354 (0..n_vars)
355 .map(|k| {
356 (0..n_agents)
357 .map(|j| self.w_tilde[i][j] * x_curr[j][k])
358 .sum::<f64>()
359 })
360 .collect()
361 })
362 .collect();
363
364 let w_tilde_x_prev: Vec<Vec<f64>> = (0..n_agents)
366 .map(|i| {
367 (0..n_vars)
368 .map(|k| {
369 (0..n_agents)
370 .map(|j| self.w_tilde[i][j] * x_prev[j][k])
371 .sum::<f64>()
372 })
373 .collect()
374 })
375 .collect();
376
377 let x_new: Vec<Vec<f64>> = (0..n_agents)
380 .map(|i| {
381 (0..n_vars)
382 .map(|k| {
383 w_tilde_x_curr[i][k] + x_curr[i][k]
384 - w_tilde_x_prev[i][k]
385 - alpha * (grad_curr[i][k] - grad_prev[i][k])
386 })
387 .collect()
388 })
389 .collect();
390
391 let x_bar: Vec<f64> = (0..n_vars)
393 .map(|k| x_new.iter().map(|xi| xi[k]).sum::<f64>() / n_agents as f64)
394 .collect();
395 let cons_res: f64 = x_new
396 .iter()
397 .map(|xi| {
398 xi.iter()
399 .zip(x_bar.iter())
400 .map(|(a, b)| (a - b).powi(2))
401 .sum::<f64>()
402 .sqrt()
403 })
404 .fold(0.0_f64, f64::max);
405
406 let dx: f64 = x_new
408 .iter()
409 .zip(x_curr.iter())
410 .map(|(xn, xc)| {
411 xn.iter()
412 .zip(xc.iter())
413 .map(|(a, b)| (a - b).powi(2))
414 .sum::<f64>()
415 })
416 .sum::<f64>()
417 .sqrt();
418
419 primal_history.push(cons_res);
420 dual_history.push(dx);
421
422 x_prev = x_curr;
423 x_curr = x_new;
424 grad_prev = grad_curr;
425
426 if cons_res < config.tol && dx < config.tol {
427 converged = true;
428 break;
429 }
430 }
431
432 let x_bar: Vec<f64> = (0..n_vars)
434 .map(|k| x_curr.iter().map(|xi| xi[k]).sum::<f64>() / n_agents as f64)
435 .collect();
436
437 Ok(AdmmResult {
438 x: x_bar,
439 primal_residual: primal_history,
440 dual_residual: dual_history,
441 converged,
442 iterations,
443 })
444 }
445}
446
/// Adjacency matrix of an undirected n-node ring: each node i is linked
/// to (i+1) mod n and (i-1) mod n. For n == 1 this produces a self-loop,
/// matching the modular-arithmetic definition.
pub fn ring_topology(n: usize) -> Vec<Vec<f64>> {
    let mut adj = vec![vec![0.0_f64; n]; n];
    for (i, row) in adj.iter_mut().enumerate() {
        row[(i + 1) % n] = 1.0;
        row[(i + n - 1) % n] = 1.0;
    }
    adj
}
464
/// Builds Metropolis-Hastings mixing weights for the graph `adj`:
/// off-diagonal edge weights are 1 / (1 + max(deg_i, deg_j)) and each
/// diagonal entry absorbs the remainder so every row sums to 1.
pub fn metropolis_hastings_weights(adj: &[Vec<f64>]) -> Vec<Vec<f64>> {
    let n = adj.len();
    // Node degrees: count of strictly positive entries per row.
    let degrees: Vec<usize> = adj
        .iter()
        .map(|row| row.iter().filter(|&&v| v > 0.0).count())
        .collect();

    let mut w = vec![vec![0.0_f64; n]; n];
    for i in 0..n {
        let mut off_diag_sum = 0.0;
        for j in 0..n {
            if i != j && adj[i][j] > 0.0 {
                w[i][j] = 1.0 / (1.0 + degrees[i].max(degrees[j]) as f64);
                off_diag_sum += w[i][j];
            }
        }
        // Self-weight makes the row sum to exactly 1.
        w[i][i] = 1.0 - off_diag_sum;
    }
    w
}
489
#[cfg(test)]
mod tests {
    use super::*;

    /// Metropolis-Hastings mixing matrix for an n-node ring graph.
    fn ring_w(n: usize) -> Vec<Vec<f64>> {
        let adj = ring_topology(n);
        metropolis_hastings_weights(&adj)
    }

    #[test]
    fn test_ring_topology() {
        let adj = ring_topology(4);
        // Node 0 is adjacent to exactly its ring neighbours 1 and 3.
        assert_eq!(adj[0][1], 1.0);
        assert_eq!(adj[0][3], 1.0);
        assert_eq!(adj[0][0], 0.0);
        assert_eq!(adj[0][2], 0.0);
    }

    #[test]
    fn test_metropolis_hastings_doubly_stochastic() {
        let w = ring_w(4);
        // Every row must sum to one ...
        for row in w.iter() {
            let s: f64 = row.iter().sum();
            assert!((s - 1.0).abs() < 1e-10, "Row sum = {}", s);
        }
        // ... and so must every column.
        let n = w.len();
        for j in 0..n {
            let s: f64 = w.iter().map(|row| row[j]).sum();
            assert!((s - 1.0).abs() < 1e-10, "Col {} sum = {}", j, s);
        }
    }

    #[test]
    fn test_pdmm_converges() {
        // Three agents each minimise (1/2)(x - c_i)^2 on a complete graph;
        // the consensus minimiser is the mean of the centers, 3.0.
        let n_vars = 1;
        let centers = vec![1.0_f64, 3.0, 5.0];
        let topology = vec![
            vec![0.0, 1.0, 1.0],
            vec![1.0, 0.0, 1.0],
            vec![1.0, 1.0, 0.0],
        ];
        let solver = PdmmSolver::new(topology).expect("PDMM creation failed");
        let config = PdmmConfig {
            stepsize: 0.2,
            max_iter: 2000,
            tol: 1e-4,
        };
        // Proximal map of f_i(x) = (1/2)(x - c_i)^2 with penalty rho:
        // argmin_x (1/2)(x - c)^2 + (rho/2)(x - v)^2 = (c + rho*v)/(1 + rho).
        let prox_fns: Vec<Box<dyn Fn(&[f64], f64) -> Vec<f64>>> = centers
            .iter()
            .map(|&c| {
                let f: Box<dyn Fn(&[f64], f64) -> Vec<f64>> =
                    Box::new(move |v: &[f64], rho: f64| vec![(c + rho * v[0]) / (1.0 + rho)]);
                f
            })
            .collect();

        let result = solver
            .solve(&prox_fns, n_vars, &config)
            .expect("PDMM solve failed");

        assert!(
            result.converged,
            "PDMM should converge, iters={}",
            result.iterations
        );
        assert!(
            (result.x[0] - 3.0).abs() < 0.1,
            "x = {:.4} (expected 3.0)",
            result.x[0]
        );
    }

    #[test]
    fn test_pdmm_topology_ring() {
        // Same quadratic consensus problem, but on a 4-node ring; the mean
        // of the centers is again 3.0.
        let centers = vec![0.0_f64, 2.0, 4.0, 6.0];
        let adj = ring_topology(4);
        let solver = PdmmSolver::new(adj).expect("PDMM ring creation failed");
        let config = PdmmConfig {
            stepsize: 0.1,
            max_iter: 5000,
            tol: 1e-3,
        };
        let prox_fns: Vec<Box<dyn Fn(&[f64], f64) -> Vec<f64>>> = centers
            .iter()
            .map(|&c| {
                let f: Box<dyn Fn(&[f64], f64) -> Vec<f64>> =
                    Box::new(move |v: &[f64], rho: f64| vec![(c + rho * v[0]) / (1.0 + rho)]);
                f
            })
            .collect();

        let result = solver
            .solve(&prox_fns, 1, &config)
            .expect("PDMM ring solve failed");

        assert!(
            (result.x[0] - 3.0).abs() < 0.5,
            "x = {:.4} (expected ~3.0)",
            result.x[0]
        );
    }

    #[test]
    fn test_extra_exact_consensus() {
        // Four agents minimise (x - c_i)^2 (gradient 2(x - c_i)); the
        // consensus optimum is the mean of the centers, 4.0.
        let centers = vec![1.0_f64, 3.0, 5.0, 7.0];
        let w = ring_w(4);
        let solver = ExtraSolver::new(w).expect("EXTRA creation failed");
        let config = ExtraConfig {
            alpha: 0.02,
            max_iter: 2000,
            tol: 1e-4,
        };
        let grad_fns: Vec<Box<dyn Fn(&[f64]) -> Vec<f64>>> = centers
            .iter()
            .map(|&c| {
                let f: Box<dyn Fn(&[f64]) -> Vec<f64>> =
                    Box::new(move |x: &[f64]| vec![2.0 * (x[0] - c)]);
                f
            })
            .collect();

        let result = solver
            .solve(&grad_fns, 1, &config)
            .expect("EXTRA solve failed");

        assert!(
            result.converged || result.iterations == config.max_iter,
            "EXTRA iterations: {}",
            result.iterations
        );
        assert!(
            (result.x[0] - 4.0).abs() < 0.1,
            "x = {:.4} (expected 4.0), iters={}",
            result.x[0],
            result.iterations
        );
    }

    #[test]
    fn test_extra_vs_admm_same_solution() {
        use super::super::admm::consensus_admm;

        // Both methods solve min sum_i (x - c_i)^2, whose optimum is the
        // mean of the centers (4.0), so their answers must agree.
        let centers = vec![2.0_f64, 4.0, 6.0];
        let n_agents = 3_usize;

        let w = ring_w(n_agents);
        let solver = ExtraSolver::new(w).expect("EXTRA creation failed");
        let config = ExtraConfig {
            alpha: 0.02,
            max_iter: 2000,
            tol: 1e-4,
        };
        let grad_fns: Vec<Box<dyn Fn(&[f64]) -> Vec<f64>>> = centers
            .iter()
            .map(|&c| {
                let f: Box<dyn Fn(&[f64]) -> Vec<f64>> =
                    Box::new(move |x: &[f64]| vec![2.0 * (x[0] - c)]);
                f
            })
            .collect();
        let extra_res = solver.solve(&grad_fns, 1, &config).expect("EXTRA failed");

        let admm_config = super::super::types::AdmmConfig {
            rho: 1.0,
            max_iter: 500,
            abs_tol: 1e-6,
            rel_tol: 1e-4,
            warm_start: false,
            over_relaxation: 1.0,
        };
        // Proximal map of f_i(x) = (x - c_i)^2 with penalty rho:
        // argmin_x (x - c)^2 + (rho/2)(x - v)^2 = (rho*v + 2c)/(rho + 2).
        let prox_fns: Vec<Box<dyn Fn(&[f64], f64) -> Vec<f64>>> = centers
            .iter()
            .map(|&c| {
                let f: Box<dyn Fn(&[f64], f64) -> Vec<f64>> =
                    Box::new(move |v: &[f64], rho: f64| {
                        vec![(rho * v[0] + 2.0 * c) / (rho + 2.0)]
                    });
                f
            })
            .collect();
        let admm_res = consensus_admm(&prox_fns, 1, &admm_config).expect("ADMM failed");

        assert!(
            (extra_res.x[0] - 4.0).abs() < 0.2,
            "EXTRA x = {:.4}",
            extra_res.x[0]
        );
        assert!(
            (admm_res.x[0] - 4.0).abs() < 0.1,
            "ADMM x = {:.4}",
            admm_res.x[0]
        );
        // Both solvers must land on (approximately) the same point; the
        // 0.3 bound is implied by the two asserts above.
        assert!(
            (extra_res.x[0] - admm_res.x[0]).abs() < 0.3,
            "EXTRA and ADMM disagree: {:.4} vs {:.4}",
            extra_res.x[0],
            admm_res.x[0]
        );
    }

    #[test]
    fn test_extra_solver_invalid_w() {
        // Rows sum to 1 but columns sum to 1.4 and 0.6, so W is not doubly
        // stochastic and construction must be rejected.
        let w = vec![vec![0.5, 0.5], vec![0.9, 0.1]];
        let result = ExtraSolver::new(w);
        assert!(result.is_err());
    }
}