// scirs2_optimize/distributed_admm/pdmm_extra.rs
//! PDMM (Primal-Dual Method of Multipliers) and EXTRA algorithms.
//!
//! **PDMM** (Chang et al. 2015):
//!
//! Decentralized optimization over a graph G = (V, E) where each agent i
//! communicates only with its neighbours N_i:
//!
//! ```text
//!   min  Σ_i f_i(x)   s.t. x_i = x_j  ∀(i,j) ∈ E
//! ```
//!
//! PDMM updates per agent i:
//!
//! ```text
//!   x_i^{k+1} = argmin_x [ f_i(x) + Σ_{j∈N_i} (λ_{ij}^T x + (ρ/2)||x - x_j^k||²) ]
//!   λ_{ij}^{k+1} = λ_{ij}^k + ρ (x_i^{k+1} - x_j^{k+1})
//! ```
//!
//! **EXTRA** (Shi et al. 2015):
//!
//! Uses two mixing matrices W (doubly stochastic) and W̃ = (I+W)/2.
//! Gradient tracking ensures exact consensus without diminishing step size:
//!
//! ```text
//!   x^1   = W x^0 - α ∇F(x^0)
//!   x^{k+2} = W̃ x^{k+1} + x^{k+1} - W̃ x^k - α (∇F(x^{k+1}) - ∇F(x^k))
//! ```
//!
//! # References
//! - Chang et al. (2015). "Multi-Agent Distributed Optimization via Inexact
//!   Consensus ADMM." IEEE Trans. Signal Processing.
//! - Shi et al. (2015). "EXTRA: An Exact First-Order Algorithm for Decentralized
//!   Consensus Optimization." SIAM J. Optim.

use super::types::{AdmmResult, ExtraConfig, PdmmConfig};
use crate::error::{OptimizeError, OptimizeResult};

// ─────────────────────────────────────────────────────────────────────────────
// Internal helpers
// ─────────────────────────────────────────────────────────────────────────────

/// Euclidean (L2) norm of a vector; 0.0 for an empty slice.
fn norm2(v: &[f64]) -> f64 {
    let mut acc = 0.0_f64;
    for &x in v {
        acc += x * x;
    }
    acc.sqrt()
}

/// Element-wise sum a + b (truncates to the shorter slice, like `zip`).
fn vec_add(a: &[f64], b: &[f64]) -> Vec<f64> {
    let mut out = Vec::with_capacity(a.len().min(b.len()));
    for (x, y) in a.iter().zip(b.iter()) {
        out.push(x + y);
    }
    out
}

/// Element-wise difference a - b (truncates to the shorter slice, like `zip`).
fn vec_sub(a: &[f64], b: &[f64]) -> Vec<f64> {
    let mut out = Vec::with_capacity(a.len().min(b.len()));
    for (x, y) in a.iter().zip(b.iter()) {
        out.push(x - y);
    }
    out
}

/// Multiply every element of `a` by the scalar `s`.
fn vec_scale(a: &[f64], s: f64) -> Vec<f64> {
    let mut out = Vec::with_capacity(a.len());
    for &x in a {
        out.push(x * s);
    }
    out
}

/// Matrix-vector product y = W x (row-major W).
///
/// Each output component is the dot product of the corresponding row with x;
/// rows longer than x are truncated by the zip.
fn mat_vec(w: &[Vec<f64>], x: &[f64]) -> Vec<f64> {
    let mut y = Vec::with_capacity(w.len());
    for row in w {
        let mut dot = 0.0_f64;
        for (wij, xj) in row.iter().zip(x.iter()) {
            dot += wij * xj;
        }
        y.push(dot);
    }
    y
}

/// Check that `w` is square with every row and column summing to 1 within `tol`.
///
/// Returns `false` for non-square input. Note that entry non-negativity is
/// NOT verified, so this is a necessary but not sufficient condition for
/// double stochasticity.
fn check_doubly_stochastic(w: &[Vec<f64>], tol: f64) -> bool {
    let n = w.len();
    // Reject ragged / non-square matrices before any column indexing.
    if w.iter().any(|row| row.len() != n) {
        return false;
    }
    // Every row must sum to 1.
    let rows_ok = w
        .iter()
        .all(|row| (row.iter().sum::<f64>() - 1.0).abs() <= tol);
    // Every column must sum to 1.
    let cols_ok = (0..n).all(|j| {
        let col_sum: f64 = w.iter().map(|row| row[j]).sum();
        (col_sum - 1.0).abs() <= tol
    });
    rows_ok && cols_ok
}

// ─────────────────────────────────────────────────────────────────────────────
// PDMM Solver
// ─────────────────────────────────────────────────────────────────────────────

/// PDMM solver for decentralized consensus optimization over a network.
///
/// Each agent holds a local objective (e.g. a quadratic
/// f_i(x) = (1/2)||x - c_i||²) and exchanges information only with the
/// agents adjacent to it in the communication graph.
#[derive(Debug)]
pub struct PdmmSolver {
    /// Adjacency matrix (symmetric, 0-1 valued). Entry `[i][j]` = 1 iff agents i and j are connected.
    pub topology: Vec<Vec<f64>>,
}

102impl PdmmSolver {
103    /// Create a new PDMM solver with the given adjacency matrix.
104    pub fn new(topology: Vec<Vec<f64>>) -> OptimizeResult<Self> {
105        let n = topology.len();
106        for (i, row) in topology.iter().enumerate() {
107            if row.len() != n {
108                return Err(OptimizeError::InvalidInput(format!(
109                    "Topology row {} has length {} but expected {}",
110                    i,
111                    row.len(),
112                    n
113                )));
114            }
115        }
116        Ok(Self { topology })
117    }
118
119    /// Solve the consensus problem where each agent i has the local function
120    /// `local_fns[i](x)` (which must have a known proximal operator).
121    ///
122    /// `local_fns[i]` is the proximal operator prox_{f_i/ρ}(v, rho).
123    pub fn solve<F>(
124        &self,
125        local_fns: &[F],
126        n_vars: usize,
127        config: &PdmmConfig,
128    ) -> OptimizeResult<AdmmResult>
129    where
130        F: Fn(&[f64], f64) -> Vec<f64>,
131    {
132        let n_agents = self.topology.len();
133        if local_fns.len() != n_agents {
134            return Err(OptimizeError::InvalidInput(format!(
135                "Expected {} local functions but got {}",
136                n_agents,
137                local_fns.len()
138            )));
139        }
140        if n_vars == 0 {
141            return Err(OptimizeError::InvalidInput("n_vars must be > 0".into()));
142        }
143
144        let rho = config.stepsize;
145
146        // Primal variables x[i], dual edge variables λ[i][j] (for j in N_i)
147        let mut x: Vec<Vec<f64>> = (0..n_agents).map(|_| vec![0.0; n_vars]).collect();
148        // λ_{ij}: for each directed edge i→j, one dual vector
149        let mut lam: Vec<Vec<Vec<f64>>> = (0..n_agents)
150            .map(|_| (0..n_agents).map(|_| vec![0.0_f64; n_vars]).collect())
151            .collect();
152
153        let mut primal_history = Vec::with_capacity(config.max_iter);
154        let mut dual_history = Vec::with_capacity(config.max_iter);
155        let mut converged = false;
156        let mut iterations = 0;
157
158        for iter in 0..config.max_iter {
159            iterations = iter + 1;
160            let x_old = x.clone();
161
162            // ── x-updates ─────────────────────────────────────────────────────
163            for i in 0..n_agents {
164                // Aggregate neighbour contribution:
165                // v_i = -( Σ_{j∈N_i} (λ_{ij} - ρ x_j^old) ) / (ρ * |N_i|)
166                let mut neighbours = 0usize;
167                let mut agg = vec![0.0_f64; n_vars];
168                for j in 0..n_agents {
169                    if self.topology[i][j] > 0.0 {
170                        neighbours += 1;
171                        for k in 0..n_vars {
172                            agg[k] += lam[i][j][k] - rho * x_old[j][k];
173                        }
174                    }
175                }
176                // prox argument: v = -(1/ρ) agg_i / |N_i| (per-neighbour average)
177                // PDMM prox: prox_{f_i/ρ̃}(v) where ρ̃ = ρ * |N_i|
178                let rho_eff = rho * (neighbours.max(1) as f64);
179                let prox_arg: Vec<f64> = agg.iter().map(|a| -a / rho_eff).collect();
180                x[i] = (local_fns[i])(&prox_arg, rho_eff);
181            }
182
183            // ── λ-updates ──────────────────────────────────────────────────────
184            for i in 0..n_agents {
185                for j in 0..n_agents {
186                    if self.topology[i][j] > 0.0 {
187                        for k in 0..n_vars {
188                            lam[i][j][k] += rho * (x[i][k] - x[j][k]);
189                        }
190                    }
191                }
192            }
193
194            // ── Consensus residual (max disagreement between neighbours) ───────
195            let mut primal_sq = 0.0_f64;
196            let mut dual_sq = 0.0_f64;
197            for i in 0..n_agents {
198                for j in 0..n_agents {
199                    if self.topology[i][j] > 0.0 {
200                        for k in 0..n_vars {
201                            primal_sq += (x[i][k] - x[j][k]).powi(2);
202                        }
203                    }
204                }
205                for k in 0..n_vars {
206                    dual_sq += (x[i][k] - x_old[i][k]).powi(2);
207                }
208            }
209            let primal_res = primal_sq.sqrt();
210            let dual_res = rho * dual_sq.sqrt();
211
212            primal_history.push(primal_res);
213            dual_history.push(dual_res);
214
215            if primal_res < config.tol {
216                converged = true;
217                break;
218            }
219        }
220
221        // Consensus solution: average over agents
222        let mut x_consensus = vec![0.0_f64; n_vars];
223        let scale = 1.0 / n_agents as f64;
224        for xi in x.iter() {
225            for k in 0..n_vars {
226                x_consensus[k] += scale * xi[k];
227            }
228        }
229
230        Ok(AdmmResult {
231            x: x_consensus,
232            primal_residual: primal_history,
233            dual_residual: dual_history,
234            converged,
235            iterations,
236        })
237    }
238}
239
// ─────────────────────────────────────────────────────────────────────────────
// EXTRA Solver
// ─────────────────────────────────────────────────────────────────────────────

/// EXTRA solver for exact decentralized consensus.
///
/// Gradient tracking lets EXTRA reach the exact minimizer of Σ_i f_i(x)
/// while keeping the step size fixed.
///
/// Requires:
/// - W: doubly stochastic mixing matrix (n_agents × n_agents)
/// - `grad_fns[i]`: gradient of f\_i at any point x
#[derive(Debug)]
pub struct ExtraSolver {
    /// Doubly stochastic mixing matrix W (n_agents × n_agents).
    pub w: Vec<Vec<f64>>,
    /// W̃ = (I + W) / 2.
    pub w_tilde: Vec<Vec<f64>>,
}

260impl ExtraSolver {
261    /// Create a new EXTRA solver from a doubly stochastic mixing matrix W.
262    pub fn new(w: Vec<Vec<f64>>) -> OptimizeResult<Self> {
263        let n = w.len();
264        if !check_doubly_stochastic(&w, 1e-6) {
265            return Err(OptimizeError::InvalidInput(
266                "W must be doubly stochastic".into(),
267            ));
268        }
269        // W̃ = (I + W) / 2
270        let w_tilde: Vec<Vec<f64>> = (0..n)
271            .map(|i| {
272                (0..n)
273                    .map(|j| {
274                        let eye = if i == j { 1.0 } else { 0.0 };
275                        (eye + w[i][j]) / 2.0
276                    })
277                    .collect()
278            })
279            .collect();
280        Ok(Self { w, w_tilde })
281    }
282
283    /// Solve the decentralized problem using EXTRA.
284    ///
285    /// `grad_fns[i](x)` returns the gradient of f_i at x.
286    pub fn solve<F>(
287        &self,
288        grad_fns: &[F],
289        n_vars: usize,
290        config: &ExtraConfig,
291    ) -> OptimizeResult<AdmmResult>
292    where
293        F: Fn(&[f64]) -> Vec<f64>,
294    {
295        let n_agents = self.w.len();
296        if grad_fns.len() != n_agents {
297            return Err(OptimizeError::InvalidInput(format!(
298                "Expected {} gradient functions but got {}",
299                n_agents,
300                grad_fns.len()
301            )));
302        }
303        if n_vars == 0 {
304            return Err(OptimizeError::InvalidInput("n_vars must be > 0".into()));
305        }
306
307        let alpha = config.alpha;
308
309        // x^0: all agents start at zero
310        // Agent states: matrix n_agents × n_vars
311        let mut x_curr: Vec<Vec<f64>> = (0..n_agents).map(|_| vec![0.0; n_vars]).collect();
312
313        // Compute gradients at x^0
314        let grad_curr: Vec<Vec<f64>> = (0..n_agents).map(|i| (grad_fns[i])(&x_curr[i])).collect();
315
316        // Stack gradients (sum of per-agent grad, column-wise mixing)
317        // Stacked agent vectors as rows for mixing: y[i] = Σ_j W[i][j] x[j]
318        let x_next: Vec<Vec<f64>> = (0..n_agents)
319            .map(|i| {
320                // W x^0 row i
321                let wx_i: Vec<f64> = (0..n_vars)
322                    .map(|k| {
323                        (0..n_agents)
324                            .map(|j| self.w[i][j] * x_curr[j][k])
325                            .sum::<f64>()
326                    })
327                    .collect();
328                // x^1_i = (W x^0)_i - α ∇f_i(x^0_i)
329                wx_i.iter()
330                    .zip(grad_curr[i].iter())
331                    .map(|(w, g)| w - alpha * g)
332                    .collect()
333            })
334            .collect();
335
336        let mut x_prev = x_curr.clone();
337        let mut x_curr = x_next;
338        let mut grad_prev = grad_curr;
339
340        let mut primal_history = Vec::with_capacity(config.max_iter);
341        let mut dual_history = Vec::with_capacity(config.max_iter);
342        let mut converged = false;
343        let mut iterations = 1;
344
345        for iter in 1..config.max_iter {
346            iterations = iter + 1;
347
348            let grad_curr: Vec<Vec<f64>> =
349                (0..n_agents).map(|i| (grad_fns[i])(&x_curr[i])).collect();
350
351            // W̃ x^{k+1}
352            let w_tilde_x_curr: Vec<Vec<f64>> = (0..n_agents)
353                .map(|i| {
354                    (0..n_vars)
355                        .map(|k| {
356                            (0..n_agents)
357                                .map(|j| self.w_tilde[i][j] * x_curr[j][k])
358                                .sum::<f64>()
359                        })
360                        .collect()
361                })
362                .collect();
363
364            // W̃ x^k
365            let w_tilde_x_prev: Vec<Vec<f64>> = (0..n_agents)
366                .map(|i| {
367                    (0..n_vars)
368                        .map(|k| {
369                            (0..n_agents)
370                                .map(|j| self.w_tilde[i][j] * x_prev[j][k])
371                                .sum::<f64>()
372                        })
373                        .collect()
374                })
375                .collect();
376
377            // x^{k+2}_i = W̃ x^{k+1}_i + x^{k+1}_i - W̃ x^k_i
378            //              - α (∇f_i(x^{k+1}) - ∇f_i(x^k))
379            let x_new: Vec<Vec<f64>> = (0..n_agents)
380                .map(|i| {
381                    (0..n_vars)
382                        .map(|k| {
383                            w_tilde_x_curr[i][k] + x_curr[i][k]
384                                - w_tilde_x_prev[i][k]
385                                - alpha * (grad_curr[i][k] - grad_prev[i][k])
386                        })
387                        .collect()
388                })
389                .collect();
390
391            // Consensus residual: max ||x_i - x̄||
392            let x_bar: Vec<f64> = (0..n_vars)
393                .map(|k| x_new.iter().map(|xi| xi[k]).sum::<f64>() / n_agents as f64)
394                .collect();
395            let cons_res: f64 = x_new
396                .iter()
397                .map(|xi| {
398                    xi.iter()
399                        .zip(x_bar.iter())
400                        .map(|(a, b)| (a - b).powi(2))
401                        .sum::<f64>()
402                        .sqrt()
403                })
404                .fold(0.0_f64, f64::max);
405
406            // Dual residual: ||x^{k+1} - x^k||
407            let dx: f64 = x_new
408                .iter()
409                .zip(x_curr.iter())
410                .map(|(xn, xc)| {
411                    xn.iter()
412                        .zip(xc.iter())
413                        .map(|(a, b)| (a - b).powi(2))
414                        .sum::<f64>()
415                })
416                .sum::<f64>()
417                .sqrt();
418
419            primal_history.push(cons_res);
420            dual_history.push(dx);
421
422            x_prev = x_curr;
423            x_curr = x_new;
424            grad_prev = grad_curr;
425
426            if cons_res < config.tol && dx < config.tol {
427                converged = true;
428                break;
429            }
430        }
431
432        // Return average of agent states as consensus solution
433        let x_bar: Vec<f64> = (0..n_vars)
434            .map(|k| x_curr.iter().map(|xi| xi[k]).sum::<f64>() / n_agents as f64)
435            .collect();
436
437        Ok(AdmmResult {
438            x: x_bar,
439            primal_residual: primal_history,
440            dual_residual: dual_history,
441            converged,
442            iterations,
443        })
444    }
445}
446
// ─────────────────────────────────────────────────────────────────────────────
// Topology builders
// ─────────────────────────────────────────────────────────────────────────────

/// Build a ring topology adjacency matrix for n agents.
///
/// Agent i is connected to (i-1) mod n and (i+1) mod n. (For n == 1 this
/// degenerates to a self-loop, and for n == 2 to a single edge.)
pub fn ring_topology(n: usize) -> Vec<Vec<f64>> {
    let mut adj = vec![vec![0.0_f64; n]; n];
    for (i, row) in adj.iter_mut().enumerate() {
        row[(i + 1) % n] = 1.0;
        row[(i + n - 1) % n] = 1.0;
    }
    adj
}

/// Build a Metropolis-Hastings doubly stochastic mixing matrix from an adjacency matrix.
///
/// W_{ij} = 1 / (1 + max(deg_i, deg_j))  for (i,j) ∈ E
/// W_{ii} = 1 - Σ_{j≠i} W_{ij}
pub fn metropolis_hastings_weights(adj: &[Vec<f64>]) -> Vec<Vec<f64>> {
    let n = adj.len();
    // deg_i = number of strictly positive entries in row i.
    let degrees: Vec<usize> = adj
        .iter()
        .map(|row| row.iter().filter(|&&v| v > 0.0).count())
        .collect();

    let mut w = vec![vec![0.0_f64; n]; n];
    for i in 0..n {
        let mut off_diag_sum = 0.0;
        for j in 0..n {
            if i != j && adj[i][j] > 0.0 {
                w[i][j] = 1.0 / (1.0 + degrees[i].max(degrees[j]) as f64);
                off_diag_sum += w[i][j];
            }
        }
        // The self-weight absorbs the remaining mass so each row sums to 1.
        w[i][i] = 1.0 - off_diag_sum;
    }
    w
}

// ─────────────────────────────────────────────────────────────────────────────
// Tests
// ─────────────────────────────────────────────────────────────────────────────

#[cfg(test)]
mod tests {
    use super::*;

    /// Build a ring mixing matrix for n agents.
    fn ring_w(n: usize) -> Vec<Vec<f64>> {
        let adj = ring_topology(n);
        metropolis_hastings_weights(&adj)
    }

    #[test]
    fn test_ring_topology() {
        let adj = ring_topology(4);
        // Agent 0 connects to 1 and 3, but not to itself or agent 2.
        assert_eq!(adj[0][1], 1.0);
        assert_eq!(adj[0][3], 1.0);
        assert_eq!(adj[0][0], 0.0);
        assert_eq!(adj[0][2], 0.0);
    }

    #[test]
    fn test_metropolis_hastings_doubly_stochastic() {
        let w = ring_w(4);
        // Row sums should be 1
        for row in w.iter() {
            let s: f64 = row.iter().sum();
            assert!((s - 1.0).abs() < 1e-10, "Row sum = {}", s);
        }
        // Column sums should be 1
        let n = w.len();
        for j in 0..n {
            let s: f64 = w.iter().map(|row| row[j]).sum();
            assert!((s - 1.0).abs() < 1e-10, "Col {} sum = {}", j, s);
        }
    }

    #[test]
    fn test_pdmm_converges() {
        // 3 agents on a complete graph, each minimising f_i(x) = (x - c_i)^2 / 2
        // Optimal consensus: x* = mean(c_i)
        let n_vars = 1;
        let centers = vec![1.0_f64, 3.0, 5.0]; // mean = 3.0
        let topology = vec![
            vec![0.0, 1.0, 1.0],
            vec![1.0, 0.0, 1.0],
            vec![1.0, 1.0, 0.0],
        ];
        let solver = PdmmSolver::new(topology).expect("PDMM creation failed");
        let config = PdmmConfig {
            stepsize: 0.2,
            max_iter: 2000,
            tol: 1e-4,
        };
        // Proximal operator: prox_{f_i/ρ}(v) = (c_i + ρ * v) / (1 + ρ)
        let prox_fns: Vec<Box<dyn Fn(&[f64], f64) -> Vec<f64>>> = centers
            .iter()
            .map(|&c| {
                let f: Box<dyn Fn(&[f64], f64) -> Vec<f64>> =
                    Box::new(move |v: &[f64], rho: f64| vec![(c + rho * v[0]) / (1.0 + rho)]);
                f
            })
            .collect();

        let result = solver
            .solve(&prox_fns, n_vars, &config)
            .expect("PDMM solve failed");

        assert!(
            result.converged,
            "PDMM should converge, iters={}",
            result.iterations
        );
        assert!(
            (result.x[0] - 3.0).abs() < 0.1,
            "x = {:.4} (expected 3.0)",
            result.x[0]
        );
    }

    #[test]
    fn test_pdmm_topology_ring() {
        // 4 agents on a ring, each with f_i(x) = (x - c_i)^2 / 2
        let centers = vec![0.0_f64, 2.0, 4.0, 6.0]; // mean = 3.0
        let adj = ring_topology(4);
        let solver = PdmmSolver::new(adj).expect("PDMM ring creation failed");
        let config = PdmmConfig {
            stepsize: 0.1,
            max_iter: 5000,
            tol: 1e-3,
        };
        let prox_fns: Vec<Box<dyn Fn(&[f64], f64) -> Vec<f64>>> = centers
            .iter()
            .map(|&c| {
                let f: Box<dyn Fn(&[f64], f64) -> Vec<f64>> =
                    Box::new(move |v: &[f64], rho: f64| vec![(c + rho * v[0]) / (1.0 + rho)]);
                f
            })
            .collect();

        let result = solver
            .solve(&prox_fns, 1, &config)
            .expect("PDMM ring solve failed");

        // Ring topology converges more slowly; check approximate consensus
        assert!(
            (result.x[0] - 3.0).abs() < 0.5,
            "x = {:.4} (expected ~3.0)",
            result.x[0]
        );
    }

    #[test]
    fn test_extra_exact_consensus() {
        // 4 agents: f_i(x) = (x - c_i)^2, consensus → x* = mean(c_i)
        let centers = vec![1.0_f64, 3.0, 5.0, 7.0]; // mean = 4.0
        let w = ring_w(4);
        let solver = ExtraSolver::new(w).expect("EXTRA creation failed");
        let config = ExtraConfig {
            alpha: 0.02,
            max_iter: 2000,
            tol: 1e-4,
        };
        // Gradient of f_i(x) = 2(x - c_i)
        let grad_fns: Vec<Box<dyn Fn(&[f64]) -> Vec<f64>>> = centers
            .iter()
            .map(|&c| {
                let f: Box<dyn Fn(&[f64]) -> Vec<f64>> =
                    Box::new(move |x: &[f64]| vec![2.0 * (x[0] - c)]);
                f
            })
            .collect();

        let result = solver
            .solve(&grad_fns, 1, &config)
            .expect("EXTRA solve failed");

        assert!(
            result.converged || result.iterations == config.max_iter,
            "EXTRA iterations: {}",
            result.iterations
        );
        assert!(
            (result.x[0] - 4.0).abs() < 0.1,
            "x = {:.4} (expected 4.0), iters={}",
            result.x[0],
            result.iterations
        );
    }

    #[test]
    fn test_extra_vs_admm_same_solution() {
        use super::super::admm::consensus_admm;

        // Both should solve the mean-consensus problem consistently
        // EXTRA: f_i(x) = (x - c_i)^2, grad = 2(x-c_i)
        let centers = vec![2.0_f64, 4.0, 6.0]; // mean = 4.0
        let n_agents = 3_usize;

        // EXTRA solution
        let w = ring_w(n_agents);
        let solver = ExtraSolver::new(w).expect("EXTRA creation failed");
        let config = ExtraConfig {
            alpha: 0.02,
            max_iter: 2000,
            tol: 1e-4,
        };
        let grad_fns: Vec<Box<dyn Fn(&[f64]) -> Vec<f64>>> = centers
            .iter()
            .map(|&c| {
                let f: Box<dyn Fn(&[f64]) -> Vec<f64>> =
                    Box::new(move |x: &[f64]| vec![2.0 * (x[0] - c)]);
                f
            })
            .collect();
        let extra_res = solver.solve(&grad_fns, 1, &config).expect("EXTRA failed");

        // ADMM consensus solution (same problem via consensus_admm)
        let admm_config = super::super::types::AdmmConfig {
            rho: 1.0,
            max_iter: 500,
            abs_tol: 1e-6,
            rel_tol: 1e-4,
            warm_start: false,
            over_relaxation: 1.0,
        };
        let prox_fns: Vec<Box<dyn Fn(&[f64], f64) -> Vec<f64>>> = centers
            .iter()
            .map(|&c| {
                let f: Box<dyn Fn(&[f64], f64) -> Vec<f64>> =
                    Box::new(move |v: &[f64], rho: f64| {
                        // prox for f_i(x) = (x-c)^2: solution is (rho*v + 2*c) / (rho + 2)
                        vec![(rho * v[0] + 2.0 * c) / (rho + 2.0)]
                    });
                f
            })
            .collect();
        let admm_res = consensus_admm(&prox_fns, 1, &admm_config).expect("ADMM failed");

        // Both should be close to the true mean (4.0)
        assert!(
            (extra_res.x[0] - 4.0).abs() < 0.2,
            "EXTRA x = {:.4}",
            extra_res.x[0]
        );
        assert!(
            (admm_res.x[0] - 4.0).abs() < 0.1,
            "ADMM x = {:.4}",
            admm_res.x[0]
        );
    }

    #[test]
    fn test_extra_solver_invalid_w() {
        // Non-doubly-stochastic W should fail
        let w = vec![vec![0.5, 0.5], vec![0.9, 0.1]]; // col 0 sums to 1.4 ≠ 1
        let result = ExtraSolver::new(w);
        assert!(result.is_err());
    }
}