Skip to main content

scirs2_optimize/proximal/
splitting.rs

1//! Operator Splitting Methods
2//!
3//! This module provides splitting algorithms for optimising sums of non-smooth
4//! convex functions, where no single proximal operator is available for the
5//! combined objective.
6//!
7//! # Algorithms
8//!
9//! ## Douglas-Rachford Splitting
10//! Minimises `f(x) + g(x)` using only `prox_f` and `prox_g`:
11//! ```text
12//! y_{k+1} = prox_{γg}(z_k)
13//! x_{k+1} = prox_{γf}(2y_{k+1} − z_k)
14//! z_{k+1} = z_k + x_{k+1} − y_{k+1}
15//! ```
16//!
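//! ## Peaceman-Rachford Splitting
//! The undamped variant of Douglas-Rachford: `z_{k+1}` is a full reflection
//! rather than an average of `z_k` and its reflection. It can be faster, but
//! convergence needs extra assumptions (e.g. strong convexity), so prefer
//! Douglas-Rachford for general non-smooth problems.
//!
//! ## Forward-Backward Splitting
//! Minimises `f(x) + g(x)` for smooth `f` by alternating a gradient step on `f`
//! with a proximal step on `g`: `x_{k+1} = prox_{αg}(x_k − α·∇f(x_k))`.
//!
//! ## Primal-Dual (Chambolle-Pock)
//! Solves `min_x f(x) + g(Kx)` where `K` is a linear operator.
//!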
23//! # References
24//! - Lions & Mercier (1979). "Splitting Algorithms for the Sum of Two Nonlinear
25//!   Operators". *SIAM J. Numer. Anal.*
26//! - Eckstein & Bertsekas (1992). "On the Douglas-Rachford Splitting Method".
27//!   *Math. Programming.*
28//! - Chambolle & Pock (2011). "A First-Order Primal-Dual Algorithm for Convex
29//!   Problems with Applications to Imaging". *J. Math. Imaging Vision.*
30
31use crate::error::OptimizeError;
32
33// ─── Douglas-Rachford Splitting ──────────────────────────────────────────────
34
/// Minimise `f(x) + g(x)` using Douglas-Rachford (DR) splitting.
///
/// Only the proximal operators of `f` and `g` are evaluated; neither function
/// has to be differentiable. Every iteration performs
/// ```text
/// y_{k+1} = prox_{γg}(z_k)
/// x_{k+1} = prox_{γf}(2y_{k+1} − z_k)
/// z_{k+1} = z_k + x_{k+1} − y_{k+1}
/// ```
///
/// # Convergence
/// Converges for any pair of proper, closed, convex functions when γ > 0.
/// The fixed-point iterates `{z_k}` converge; the actual solution is
/// `prox_{γg}(z_∞)`.
///
/// # Arguments
/// * `prox_f` - Proximal operator of f, already scaled by γ: `prox_{γf}(·)`
/// * `prox_g` - Proximal operator of g, already scaled by γ: `prox_{γg}(·)`
/// * `x0` - Starting point (initialises z₀ = x₀)
/// * `gamma` - Step size / scaling parameter (γ > 0). It is **not** read by
///   this function: γ only takes effect through the scaling the caller has
///   built into `prox_f` and `prox_g`.
/// * `max_iter` - Number of DR iterations; all of them are executed because
///   this variant has no stopping test.
///
/// # Returns
/// The approximate minimiser `prox_g(z)` evaluated at the final `z`. With
/// `max_iter == 0` this is simply `prox_g(x0)`.
///
/// # Notes
/// Vectors are combined element-wise with `zip`, so a closure that returns a
/// shorter vector than its input silently shortens the iterate.
pub fn douglas_rachford(
    prox_f: &dyn Fn(&[f64]) -> Vec<f64>,
    prox_g: &dyn Fn(&[f64]) -> Vec<f64>,
    x0: Vec<f64>,
    gamma: f64,
    max_iter: usize,
) -> Vec<f64> {
    // γ is part of the public signature but acts only through the pre-scaled
    // proximal operators; discard it explicitly to document that.
    let _ = gamma;
    let mut z = x0;

    for _ in 0..max_iter {
        let y = prox_g(&z);
        // Reflection of z through y: 2y − z.
        let reflected: Vec<f64> = y
            .iter()
            .zip(z.iter())
            .map(|(&yi, &zi)| 2.0 * yi - zi)
            .collect();
        let x = prox_f(&reflected);
        // z_{k+1} = z_k + x_{k+1} − y_{k+1}
        z = z
            .iter()
            .zip(x.iter().zip(y.iter()))
            .map(|(&zk, (&xk1, &yk1))| zk + xk1 - yk1)
            .collect();
    }

    // The minimiser is recovered from the fixed point: x* = prox_g(z).
    prox_g(&z)
}
83
/// Douglas-Rachford splitting with convergence tracking.
///
/// Runs the same iteration as the untracked variant,
/// ```text
/// y_{k+1} = prox_{γg}(z_k)
/// x_{k+1} = prox_{γf}(2y_{k+1} − z_k)
/// z_{k+1} = z_k + x_{k+1} − y_{k+1}
/// ```
/// but stops as soon as `‖z_{k+1} − z_k‖ < tol` and reports diagnostics.
///
/// # Arguments
/// * `prox_f` - Proximal operator of f, already scaled by γ
/// * `prox_g` - Proximal operator of g, already scaled by γ
/// * `x0` - Starting point (initialises z₀ = x₀)
/// * `gamma` - Step size γ; not read here, it acts only through the scaling
///   built into `prox_f` and `prox_g`
/// * `max_iter` - Maximum number of DR iterations
/// * `tol` - Convergence tolerance on ‖z_{k+1} − z_k‖ (Euclidean norm)
///
/// # Returns
/// A [`DRResult`] whose `x` is `prox_g(z)` at the last iterate. When the
/// iteration budget is exhausted, `converged` is `false` and
/// `final_residual` holds the residual of the last iteration that was run
/// (`f64::INFINITY` if `max_iter == 0`), never a placeholder zero.
pub fn douglas_rachford_tracked(
    prox_f: &dyn Fn(&[f64]) -> Vec<f64>,
    prox_g: &dyn Fn(&[f64]) -> Vec<f64>,
    x0: Vec<f64>,
    gamma: f64,
    max_iter: usize,
    tol: f64,
) -> DRResult {
    let _ = gamma; // gamma is used implicitly through the prox scaling
    let mut z = x0;
    // Residual of the most recent iteration; stays infinite until one has run
    // so that an exhausted budget can never masquerade as a zero residual.
    let mut residual = f64::INFINITY;

    for iter in 0..max_iter {
        let y = prox_g(&z);
        let reflected: Vec<f64> = y
            .iter()
            .zip(z.iter())
            .map(|(&yi, &zi)| 2.0 * yi - zi)
            .collect();
        let x = prox_f(&reflected);
        let z_next: Vec<f64> = z
            .iter()
            .zip(x.iter().zip(y.iter()))
            .map(|(&zk, (&xk1, &yk1))| zk + xk1 - yk1)
            .collect();

        // Measure the step before overwriting z; this avoids cloning the old iterate.
        residual = z_next
            .iter()
            .zip(z.iter())
            .map(|(&a, &b)| (a - b) * (a - b))
            .sum::<f64>()
            .sqrt();
        z = z_next;

        if residual < tol {
            return DRResult {
                x: prox_g(&z),
                nit: iter + 1,
                converged: true,
                final_residual: residual,
            };
        }
    }

    DRResult {
        x: prox_g(&z),
        nit: max_iter,
        converged: false,
        final_residual: residual,
    }
}

/// Result of a tracked Douglas-Rachford run.
#[derive(Debug, Clone)]
pub struct DRResult {
    /// Approximate minimiser, `prox_g(z)` at the last iterate
    pub x: Vec<f64>,
    /// Number of iterations performed
    pub nit: usize,
    /// Whether ‖z_{k+1} − z_k‖ dropped below the tolerance
    pub converged: bool,
    /// ‖z_{k+1} − z_k‖ of the last iteration executed
    /// (`f64::INFINITY` when no iteration was run)
    pub final_residual: f64,
}
159
160// ─── Peaceman-Rachford Splitting ─────────────────────────────────────────────
161
/// Peaceman-Rachford splitting: Douglas-Rachford without the averaging step.
///
/// Each iteration reflects twice, once through `prox_g` and once through
/// `prox_f`:
/// ```text
/// y_{k+1} = prox_{γg}(z_k)
/// r_{k+1} = 2y_{k+1} − z_k
/// x_{k+1} = prox_{γf}(r_{k+1})
/// z_{k+1} = 2x_{k+1} − r_{k+1}
/// ```
///
/// Converges faster when both f and g are strongly convex, but may diverge
/// otherwise. Use `douglas_rachford` for general non-smooth problems.
///
/// # Arguments
/// * `prox_f` - Proximal operator of f, already scaled by γ
/// * `prox_g` - Proximal operator of g, already scaled by γ
/// * `x0` - Starting point (initialises z₀ = x₀)
/// * `_gamma` - Step size γ; unused here, it acts only through the prox scaling
/// * `max_iter` - Number of iterations; all of them are executed
///
/// # Returns
/// `prox_g(z)` evaluated at the final `z`.
pub fn peaceman_rachford(
    prox_f: &dyn Fn(&[f64]) -> Vec<f64>,
    prox_g: &dyn Fn(&[f64]) -> Vec<f64>,
    x0: Vec<f64>,
    _gamma: f64,
    max_iter: usize,
) -> Vec<f64> {
    // Thread the fixed-point iterate through `max_iter` applications of the
    // double-reflection map.
    let fixed_point = (0..max_iter).fold(x0, |current, _| {
        let backward_g = prox_g(&current);
        // First reflection: 2y − z.
        let mirrored: Vec<f64> = backward_g
            .iter()
            .zip(current.iter())
            .map(|(&yi, &zi)| 2.0 * yi - zi)
            .collect();
        let backward_f = prox_f(&mirrored);
        // Second reflection: 2x − (2y − z).
        backward_f
            .iter()
            .zip(mirrored.iter())
            .map(|(&xi, &ri)| 2.0 * xi - ri)
            .collect()
    });

    prox_g(&fixed_point)
}
204
205// ─── Forward-Backward Splitting ──────────────────────────────────────────────
206
/// Forward-backward splitting for `min f(x) + g(x)` with smooth `f`.
///
/// Alternates an explicit gradient step on `f` with a proximal step on `g`:
/// ```text
/// x_{k+1} = prox_{αg}(x_k − α·∇f(x_k))
/// ```
///
/// This is exactly ISTA generalised to arbitrary proximal operators.
///
/// # Arguments
/// * `grad_f` - Gradient of the smooth term f
/// * `prox_g` - Proximal operator of the non-smooth term g
/// * `x0` - Initial point
/// * `alpha` - Step size (1/Lipschitz constant of ∇f)
/// * `max_iter` - Maximum iterations
/// * `tol` - Stop once the Euclidean step length ‖x_{k+1} − x_k‖ is below this
///
/// # Returns
/// The last iterate computed, whether or not the tolerance was reached.
pub fn forward_backward(
    grad_f: &dyn Fn(&[f64]) -> Vec<f64>,
    prox_g: &dyn Fn(&[f64]) -> Vec<f64>,
    x0: Vec<f64>,
    alpha: f64,
    max_iter: usize,
    tol: f64,
) -> Vec<f64> {
    let mut current = x0;
    let mut remaining = max_iter;

    while remaining > 0 {
        remaining -= 1;

        // Forward (explicit) step along the negative gradient.
        let slope = grad_f(&current);
        let after_gradient: Vec<f64> = current
            .iter()
            .zip(slope.iter())
            .map(|(&xi, &gi)| xi - alpha * gi)
            .collect();

        // Backward (implicit) step through the proximal operator.
        let next = prox_g(&after_gradient);

        let squared_step = current
            .iter()
            .zip(next.iter())
            .map(|(&a, &b)| (a - b) * (a - b))
            .sum::<f64>();
        let step_len = squared_step.sqrt();

        // The new iterate is kept even when this step triggers termination.
        current = next;
        if step_len < tol {
            break;
        }
    }

    current
}
256
257// ─── Primal-Dual (Chambolle-Pock) ────────────────────────────────────────────
258
/// Primal-dual algorithm (Chambolle-Pock) for `min_x f(x) + g(Kx)`.
///
/// Iterates:
/// ```text
/// y_{k+1}   = prox_{σ g*}(y_k + σ·K·x_bar_k)
/// x_{k+1}   = prox_{τ f}(x_k − τ·Kᵀ·y_{k+1})
/// x_bar_{k+1} = x_{k+1} + θ·(x_{k+1} − x_k)
/// ```
///
/// where `g*` is the convex conjugate of `g` and `x_bar_0 = x_0`.
///
/// The step sizes play two roles. The caller bakes them into the proximal
/// operators (`prox_{τ f}`, `prox_{σ g*}`), and this function additionally
/// multiplies the linear terms `K·x_bar` and `Kᵀ·y` by `sigma` and `tau`.
/// That second scaling acts on the *argument* of the prox and cannot be
/// absorbed into the prox closure. With `tau == sigma == 1.0` the linear
/// terms are used unscaled.
///
/// # Convergence
/// For θ = 1 the standard condition is `τ·σ·‖K‖² < 1`.
///
/// # Arguments
/// * `prox_f` - Proximal operator of f (scaled by τ)
/// * `prox_g_conj` - Proximal operator of conjugate g* (scaled by σ)
/// * `k_op` - Linear operator K: x → Kx
/// * `kt_op` - Adjoint K*: y → Kᵀy
/// * `x0` - Primal initial point
/// * `y0` - Dual initial point
/// * `tau` - Primal step size τ
/// * `sigma` - Dual step size σ
/// * `theta` - Extrapolation parameter θ (0 = no extrapolation, 1 = full)
/// * `max_iter` - Number of iterations; all of them are executed
///
/// # Returns
/// `(x_star, y_star)` — primal and dual solutions.
// The argument list mirrors the algorithm's parameters one-to-one, so the
// lint is silenced rather than hiding them behind a config struct.
#[allow(clippy::too_many_arguments)]
pub fn primal_dual_chambolle_pock(
    prox_f: &dyn Fn(&[f64]) -> Vec<f64>,
    prox_g_conj: &dyn Fn(&[f64]) -> Vec<f64>,
    k_op: &dyn Fn(&[f64]) -> Vec<f64>,
    kt_op: &dyn Fn(&[f64]) -> Vec<f64>,
    x0: Vec<f64>,
    y0: Vec<f64>,
    tau: f64,
    sigma: f64,
    theta: f64,
    max_iter: usize,
) -> (Vec<f64>, Vec<f64>) {
    let mut x = x0;
    let mut y = y0;
    let mut x_bar = x.clone();

    for _ in 0..max_iter {
        // Dual update: y ← prox_{σ g*}(y + σ·K·x_bar)
        let kx_bar = k_op(&x_bar);
        let y_input: Vec<f64> = y
            .iter()
            .zip(kx_bar.iter())
            .map(|(&yi, &kxi)| yi + sigma * kxi)
            .collect();
        y = prox_g_conj(&y_input);

        // Primal update: x ← prox_{τ f}(x − τ·Kᵀ·y)
        let kty = kt_op(&y);
        let x_input: Vec<f64> = x
            .iter()
            .zip(kty.iter())
            .map(|(&xi, &kti)| xi - tau * kti)
            .collect();
        let x_next = prox_f(&x_input);

        // Extrapolation uses the previous primal iterate, which is still held
        // in `x` at this point, so no copy of it is needed.
        x_bar = x_next
            .iter()
            .zip(x.iter())
            .map(|(&xn, &xo)| xn + theta * (xn - xo))
            .collect();
        x = x_next;
    }
    (x, y)
}
332
333/// Result of a splitting algorithm with diagnostics.
334#[derive(Debug, Clone)]
335pub struct SplittingResult {
336    /// Primal solution
337    pub x: Vec<f64>,
338    /// Number of iterations
339    pub nit: usize,
340    /// Whether convergence criterion was met
341    pub converged: bool,
342}
343
344/// Run Douglas-Rachford splitting and return a `SplittingResult`.
345pub fn dr_split(
346    prox_f: &dyn Fn(&[f64]) -> Vec<f64>,
347    prox_g: &dyn Fn(&[f64]) -> Vec<f64>,
348    x0: Vec<f64>,
349    gamma: f64,
350    max_iter: usize,
351    tol: f64,
352) -> Result<SplittingResult, OptimizeError> {
353    if gamma <= 0.0 {
354        return Err(OptimizeError::ValueError(
355            "gamma must be positive for Douglas-Rachford".to_string(),
356        ));
357    }
358    let res = douglas_rachford_tracked(prox_f, prox_g, x0, gamma, max_iter, tol);
359    Ok(SplittingResult {
360        x: res.x,
361        nit: res.nit,
362        converged: res.converged,
363    })
364}
365
#[cfg(test)]
mod tests {
    use super::*;
    use crate::proximal::operators::{prox_l1, prox_l2};
    use approx::assert_abs_diff_eq;

    /// Proximal operator that returns its input unchanged (no regularisation).
    fn prox_id(v: &[f64]) -> Vec<f64> {
        Vec::from(v)
    }

    #[test]
    fn test_douglas_rachford_l1_l2() {
        // L1 term plus L2 term, started from [3, -2, 1]; every component of
        // the result is expected to end up strictly inside (-1, 1).
        let (weight_l1, weight_l2) = (0.5, 0.5);
        let shrink_l1 = |v: &[f64]| prox_l1(v, weight_l1);
        let shrink_l2 = |v: &[f64]| prox_l2(v, weight_l2);
        let start = vec![3.0, -2.0, 1.0];
        let solution = douglas_rachford(&shrink_l1, &shrink_l2, start, 1.0, 500);
        for xi in solution.iter().copied() {
            assert!(xi.abs() < 1.0, "DR solution out of expected range: {}", xi);
        }
    }

    #[test]
    fn test_douglas_rachford_identity_prox() {
        // With prox_g = identity the DR update collapses to z ← prox_f(z), so
        // the iterate should settle where |x| ≤ 1.
        let shrink_l1 = |v: &[f64]| prox_l1(v, 1.0);
        let start = vec![2.0, -3.0];
        let solution = douglas_rachford(&shrink_l1, &prox_id, start, 1.0, 1000);
        for xi in solution.iter().copied() {
            assert!(xi.abs() <= 1.0 + 1e-8, "not in expected set: {}", xi);
        }
    }

    #[test]
    fn test_dr_tracked_convergence() {
        let shrink_l1 = |v: &[f64]| prox_l1(v, 0.3);
        let shrink_l2 = |v: &[f64]| prox_l2(v, 0.3);
        let iteration_cap = 2000;
        let outcome = douglas_rachford_tracked(
            &shrink_l1,
            &shrink_l2,
            vec![2.0, -1.0],
            1.0,
            iteration_cap,
            1e-8,
        );
        assert!(outcome.converged, "DR should converge within 2000 iters");
        assert!(
            outcome.nit < iteration_cap,
            "DR should converge before max_iter"
        );
    }

    #[test]
    fn test_forward_backward_quadratic() {
        // f(x) = ½‖x‖² has gradient x; with prox_g = identity each step is
        // x_{k+1} = (1 − α)·x_k, which contracts towards the origin.
        let gradient = |x: &[f64]| x.to_vec();
        let start = vec![3.0, -2.0];
        let solution = forward_backward(&gradient, &prox_id, start, 0.5, 500, 1e-8);
        for xi in solution.iter().copied() {
            assert_abs_diff_eq!(xi, 0.0, epsilon = 1e-4);
        }
    }

    #[test]
    fn test_peaceman_rachford_converges() {
        let shrink_f = |v: &[f64]| prox_l2(v, 0.5);
        let shrink_g = |v: &[f64]| prox_l2(v, 0.5);
        let start = vec![2.0, -1.5];
        let solution = peaceman_rachford(&shrink_f, &shrink_g, start, 1.0, 500);
        for xi in solution.iter().copied() {
            assert_abs_diff_eq!(xi, 0.0, epsilon = 0.1);
        }
    }

    #[test]
    fn test_dr_split_negative_gamma() {
        // A negative step size must be rejected before any iteration runs.
        let pass_f = |v: &[f64]| v.to_vec();
        let pass_g = |v: &[f64]| v.to_vec();
        let outcome = dr_split(&pass_f, &pass_g, vec![1.0], -1.0, 10, 1e-6);
        assert!(outcome.is_err());
    }

    #[test]
    fn test_primal_dual_basic() {
        // Trivial set-up with K = I; the primal solution is expected near 0.
        let shrink_primal = |v: &[f64]| prox_l2(v, 0.5);
        let shrink_dual = |v: &[f64]| prox_l2(v, 0.5);
        let apply_k = |x: &[f64]| x.to_vec();
        let apply_kt = |y: &[f64]| y.to_vec();
        let primal_start = vec![2.0, -1.0];
        let dual_start = vec![0.0, 0.0];
        let (primal, _dual) = primal_dual_chambolle_pock(
            &shrink_primal,
            &shrink_dual,
            &apply_k,
            &apply_kt,
            primal_start,
            dual_start,
            0.5,
            0.5,
            1.0,
            500,
        );
        for xi in primal.iter().copied() {
            assert_abs_diff_eq!(xi, 0.0, epsilon = 0.1);
        }
    }
}