scirs2_optimize/proximal/splitting.rs
1//! Operator Splitting Methods
2//!
3//! This module provides splitting algorithms for optimising sums of non-smooth
4//! convex functions, where no single proximal operator is available for the
5//! combined objective.
6//!
7//! # Algorithms
8//!
9//! ## Douglas-Rachford Splitting
10//! Minimises `f(x) + g(x)` using only `prox_f` and `prox_g`:
11//! ```text
12//! y_{k+1} = prox_{γg}(z_k)
13//! x_{k+1} = prox_{γf}(2y_{k+1} − z_k)
14//! z_{k+1} = z_k + x_{k+1} − y_{k+1}
15//! ```
16//!
17//! ## Peaceman-Rachford Splitting
18//! A less damped variant that requires strong monotonicity to converge.
19//!
20//! ## Primal-Dual (Chambolle-Pock)
21//! Solves `min_x f(x) + g(Kx)` where `K` is a linear operator.
22//!
23//! # References
24//! - Lions & Mercier (1979). "Splitting Algorithms for the Sum of Two Nonlinear
25//! Operators". *SIAM J. Numer. Anal.*
26//! - Eckstein & Bertsekas (1992). "On the Douglas-Rachford Splitting Method".
27//! *Math. Programming.*
28//! - Chambolle & Pock (2011). "A First-Order Primal-Dual Algorithm for Convex
29//! Problems with Applications to Imaging". *J. Math. Imaging Vision.*
30
31use crate::error::OptimizeError;
32
33// ─── Douglas-Rachford Splitting ──────────────────────────────────────────────
34
/// Minimise `f(x) + g(x)` using Douglas-Rachford (DR) splitting.
///
/// Only the proximal operators of `f` and `g` are needed; neither function
/// has to be differentiable. Each iteration performs
/// ```text
/// y_{k+1} = prox_{γg}(z_k)
/// x_{k+1} = prox_{γf}(2y_{k+1} − z_k)
/// z_{k+1} = z_k + x_{k+1} − y_{k+1}
/// ```
///
/// # Step size
/// `gamma` is **not** applied inside this function. The closures must already
/// implement the γ-scaled operators `prox_{γf}` and `prox_{γg}`. The argument
/// is kept for API symmetry with the other splitting routines.
///
/// # Convergence
/// For proper, closed, convex `f` and `g` (and γ > 0), the governing sequence
/// `{z_k}` converges. The minimiser is recovered as `prox_{γg}(z_∞)`.
///
/// # Arguments
/// * `prox_f` - γ-scaled proximal operator of f: `prox_{γf}(·)`
/// * `prox_g` - γ-scaled proximal operator of g: `prox_{γg}(·)`
/// * `x0` - Starting point (initialises `z₀ = x₀`)
/// * `gamma` - Step size γ > 0 that the closures were built with (unused here)
/// * `max_iter` - Number of DR iterations to run (there is no early stopping)
///
/// # Returns
/// The approximate minimiser `prox_{γg}(z)` after `max_iter` iterations.
/// When `max_iter == 0` this is simply `prox_g(x0)`.
pub fn douglas_rachford(
    prox_f: &dyn Fn(&[f64]) -> Vec<f64>,
    prox_g: &dyn Fn(&[f64]) -> Vec<f64>,
    x0: Vec<f64>,
    gamma: f64,
    max_iter: usize,
) -> Vec<f64> {
    // γ is baked into the prox closures by the caller (see "Step size" above).
    let _ = gamma;
    let mut z = x0;

    for _ in 0..max_iter {
        let y = prox_g(&z);
        // Reflection of z through y: 2y − z.
        let reflected: Vec<f64> = y
            .iter()
            .zip(&z)
            .map(|(&yi, &zi)| 2.0 * yi - zi)
            .collect();
        let x = prox_f(&reflected);
        // z_{k+1} = z_k + x_{k+1} − y_{k+1}
        z = z
            .iter()
            .zip(x.iter().zip(&y))
            .map(|(&zk, (&xk, &yk))| zk + xk - yk)
            .collect();
    }

    // Recover the solution from the governing sequence: x* = prox_g(z).
    prox_g(&z)
}
83
84/// Douglas-Rachford splitting with convergence tracking.
85///
86/// Returns the solution along with convergence diagnostics.
87///
88/// # Arguments
89/// Same as `douglas_rachford`, plus:
90/// * `tol` - Convergence tolerance on ‖z_{k+1} − z_k‖
91pub fn douglas_rachford_tracked(
92 prox_f: &dyn Fn(&[f64]) -> Vec<f64>,
93 prox_g: &dyn Fn(&[f64]) -> Vec<f64>,
94 x0: Vec<f64>,
95 gamma: f64,
96 max_iter: usize,
97 tol: f64,
98) -> DRResult {
99 let n = x0.len();
100 let mut z = x0;
101 let _ = gamma; // gamma is used implicitly through the prox scaling
102
103 for iter in 0..max_iter {
104 let z_prev = z.clone();
105
106 let y = prox_g(&z);
107 let two_y_minus_z: Vec<f64> = y
108 .iter()
109 .zip(z.iter())
110 .map(|(&yi, &zi)| 2.0 * yi - zi)
111 .collect();
112 let x = prox_f(&two_y_minus_z);
113 z = z
114 .iter()
115 .zip(x.iter().zip(y.iter()))
116 .map(|(&zk, (&xk1, &yk1))| zk + xk1 - yk1)
117 .collect();
118
119 let dz: f64 = z
120 .iter()
121 .zip(z_prev.iter())
122 .map(|(&a, &b)| (a - b) * (a - b))
123 .sum::<f64>()
124 .sqrt();
125
126 if dz < tol {
127 let x_star = prox_g(&z);
128 return DRResult {
129 x: x_star,
130 nit: iter + 1,
131 converged: true,
132 final_residual: dz,
133 };
134 }
135 }
136
137 let x_star = prox_g(&z);
138 let final_res: f64 = 0.0; // Would need extra iteration to compute
139 DRResult {
140 x: x_star,
141 nit: max_iter,
142 converged: false,
143 final_residual: final_res,
144 }
145}
146
/// Outcome of a tracked Douglas-Rachford run: the solution estimate together
/// with the diagnostics gathered while iterating.
#[derive(Clone, Debug)]
pub struct DRResult {
    /// Approximate minimiser, i.e. `prox_{γg}` applied to the last governing iterate.
    pub x: Vec<f64>,
    /// How many DR iterations were executed.
    pub nit: usize,
    /// `true` if the stopping test on ‖z_{k+1} − z_k‖ was satisfied.
    pub converged: bool,
    /// The ‖z_{k+1} − z_k‖ step-length residual reported by the solver.
    pub final_residual: f64,
}
159
160// ─── Peaceman-Rachford Splitting ─────────────────────────────────────────────
161
/// Minimise `f(x) + g(x)` with Peaceman-Rachford (PR) splitting.
///
/// PR applies the two reflected resolvents back to back and omits the
/// averaging step that Douglas-Rachford performs:
/// ```text
/// y_{k+1} = prox_{γg}(z_k)
/// x_{k+1} = prox_{γf}(2y_{k+1} − z_k)
/// z_{k+1} = 2x_{k+1} − (2y_{k+1} − z_k)
/// ```
///
/// Without averaging, the iteration is not guaranteed to converge for general
/// convex `f` and `g`. It typically needs extra structure, such as strong
/// convexity, to converge. Prefer `douglas_rachford` for general non-smooth
/// problems.
///
/// `_gamma` is ignored. The closures must already be the γ-scaled operators.
///
/// # Arguments
/// * `prox_f` - proximal operator `prox_{γf}(·)`
/// * `prox_g` - proximal operator `prox_{γg}(·)`
/// * `x0` - initial governing iterate `z₀`
/// * `_gamma` - step size the closures were built with (unused)
/// * `max_iter` - number of PR iterations (there is no early stopping)
///
/// # Returns
/// `prox_{γg}` evaluated at the final governing iterate.
pub fn peaceman_rachford(
    prox_f: &dyn Fn(&[f64]) -> Vec<f64>,
    prox_g: &dyn Fn(&[f64]) -> Vec<f64>,
    x0: Vec<f64>,
    _gamma: f64,
    max_iter: usize,
) -> Vec<f64> {
    /// Element-wise reflection `2·anchor − point` over the common prefix.
    fn reflect(anchor: &[f64], point: &[f64]) -> Vec<f64> {
        anchor
            .iter()
            .zip(point)
            .map(|(&a, &p)| 2.0 * a - p)
            .collect()
    }

    let z_final = (0..max_iter).fold(x0, |z, _| {
        // Reflect through prox_g, then reflect that result through prox_f.
        let reflected_g = reflect(&prox_g(&z), &z);
        reflect(&prox_f(&reflected_g), &reflected_g)
    });

    prox_g(&z_final)
}
204
205// ─── Forward-Backward Splitting ──────────────────────────────────────────────
206
/// Forward-backward (proximal gradient) splitting for `min f(x) + g(x)`,
/// where `f` is smooth and `g` only needs a proximal operator.
///
/// Each sweep takes an explicit gradient step on `f`, then an implicit
/// (proximal) step on `g`:
/// ```text
/// x_{k+1} = prox_{αg}(x_k − α·∇f(x_k))
/// ```
/// This generalises ISTA to arbitrary proximal operators.
///
/// Iteration stops early once `‖x_{k+1} − x_k‖₂ < tol`. The point returned is
/// the `x_{k+1}` produced by that final sweep.
///
/// # Arguments
/// * `grad_f` - gradient of the smooth term f
/// * `prox_g` - proximal operator of the non-smooth term g (scaled by α)
/// * `x0` - initial point
/// * `alpha` - step size, typically 1/L where L is the Lipschitz constant of ∇f
/// * `max_iter` - upper bound on the number of sweeps
/// * `tol` - early-stopping threshold on the step length
pub fn forward_backward(
    grad_f: &dyn Fn(&[f64]) -> Vec<f64>,
    prox_g: &dyn Fn(&[f64]) -> Vec<f64>,
    x0: Vec<f64>,
    alpha: f64,
    max_iter: usize,
    tol: f64,
) -> Vec<f64> {
    let mut current = x0;
    let mut remaining = max_iter;

    while remaining > 0 {
        remaining -= 1;

        // Forward (explicit gradient) step on the smooth term.
        let gradient = grad_f(&current);
        let forward: Vec<f64> = current
            .iter()
            .zip(&gradient)
            .map(|(&xi, &gi)| xi - alpha * gi)
            .collect();

        // Backward (proximal) step on the non-smooth term.
        let next = prox_g(&forward);

        let step_len = current
            .iter()
            .zip(&next)
            .map(|(&a, &b)| (a - b) * (a - b))
            .sum::<f64>()
            .sqrt();

        current = next;
        if step_len < tol {
            break;
        }
    }

    current
}
256
257// ─── Primal-Dual (Chambolle-Pock) ────────────────────────────────────────────
258
/// Primal-dual algorithm (Chambolle-Pock) for `min_x f(x) + g(Kx)`.
///
/// Iterates:
/// ```text
/// y_{k+1}  = prox_{σ g*}(y_k + σ·K·x̄_k)
/// x_{k+1}  = prox_{τ f}(x_k − τ·Kᵀ·y_{k+1})
/// x̄_{k+1} = x_{k+1} + θ·(x_{k+1} − x_k)
/// ```
/// where `g*` is the convex conjugate of `g`.
///
/// The step sizes enter in two places:
/// * the proximal closures supplied by the caller must already be scaled
///   (`prox_f` = `prox_{τf}`, `prox_g_conj` = `prox_{σg*}`);
/// * this function multiplies the operator terms `K·x̄` and `Kᵀ·y` by `σ` and
///   `τ` respectively.
///
/// The classic convergence condition for θ = 1 is `τ·σ·‖K‖² < 1`.
///
/// # Arguments
/// * `prox_f` - Proximal operator of f, scaled by τ
/// * `prox_g_conj` - Proximal operator of the conjugate g*, scaled by σ
/// * `k_op` - Linear operator K: x → Kx
/// * `kt_op` - Adjoint K*: y → Kᵀy
/// * `x0` - Primal initial point (also the initial x̄)
/// * `y0` - Dual initial point
/// * `tau` - Primal step size τ
/// * `sigma` - Dual step size σ
/// * `theta` - Over-relaxation (0 = no relaxation, 1 = full)
/// * `max_iter` - Number of iterations (there is no early stopping)
///
/// # Returns
/// `(x_star, y_star)` — primal and dual iterates after `max_iter` steps.
// The argument list mirrors the textbook algorithm's inputs one-to-one and is
// part of the public API, so it is kept flat rather than bundled.
#[allow(clippy::too_many_arguments)]
pub fn primal_dual_chambolle_pock(
    prox_f: &dyn Fn(&[f64]) -> Vec<f64>,
    prox_g_conj: &dyn Fn(&[f64]) -> Vec<f64>,
    k_op: &dyn Fn(&[f64]) -> Vec<f64>,
    kt_op: &dyn Fn(&[f64]) -> Vec<f64>,
    x0: Vec<f64>,
    y0: Vec<f64>,
    tau: f64,
    sigma: f64,
    theta: f64,
    max_iter: usize,
) -> (Vec<f64>, Vec<f64>) {
    let mut x = x0;
    let mut y = y0;
    let mut x_bar = x.clone();

    for _ in 0..max_iter {
        // Dual ascent step: y ← prox_{σg*}(y + σ·K·x̄)
        let kx_bar = k_op(&x_bar);
        let dual_arg: Vec<f64> = y
            .iter()
            .zip(&kx_bar)
            .map(|(&yi, &kxi)| yi + sigma * kxi)
            .collect();
        y = prox_g_conj(&dual_arg);

        // Primal descent step: x ← prox_{τf}(x − τ·Kᵀ·y)
        let kty = kt_op(&y);
        let primal_arg: Vec<f64> = x
            .iter()
            .zip(&kty)
            .map(|(&xi, &kti)| xi - tau * kti)
            .collect();
        let x_new = prox_f(&primal_arg);

        // Over-relaxation: x̄ ← x_new + θ·(x_new − x_old)
        x_bar = x_new
            .iter()
            .zip(&x)
            .map(|(&xn, &xo)| xn + theta * (xn - xo))
            .collect();
        x = x_new;
    }

    (x, y)
}
332
/// Compact outcome of a splitting solver: the primal point plus basic
/// convergence information.
#[derive(Clone, Debug)]
pub struct SplittingResult {
    /// Primal solution estimate.
    pub x: Vec<f64>,
    /// Iterations executed by the solver.
    pub nit: usize,
    /// `true` when the solver's stopping criterion was satisfied.
    pub converged: bool,
}
343
344/// Run Douglas-Rachford splitting and return a `SplittingResult`.
345pub fn dr_split(
346 prox_f: &dyn Fn(&[f64]) -> Vec<f64>,
347 prox_g: &dyn Fn(&[f64]) -> Vec<f64>,
348 x0: Vec<f64>,
349 gamma: f64,
350 max_iter: usize,
351 tol: f64,
352) -> Result<SplittingResult, OptimizeError> {
353 if gamma <= 0.0 {
354 return Err(OptimizeError::ValueError(
355 "gamma must be positive for Douglas-Rachford".to_string(),
356 ));
357 }
358 let res = douglas_rachford_tracked(prox_f, prox_g, x0, gamma, max_iter, tol);
359 Ok(SplittingResult {
360 x: res.x,
361 nit: res.nit,
362 converged: res.converged,
363 })
364}
365
#[cfg(test)]
mod tests {
    use super::*;
    use crate::proximal::operators::{prox_l1, prox_l2};
    use approx::assert_abs_diff_eq;

    /// Proximal operator of the zero function: returns a copy of its input.
    fn prox_id(v: &[f64]) -> Vec<f64> {
        v.to_vec()
    }

    #[test]
    fn test_douglas_rachford_l1_l2() {
        // Pair prox_l1 with prox_l2 (both weight 0.5), starting from [3, -2, 1].
        // The check only requires every coordinate to end inside (-1, 1).
        let (weight_l1, weight_l2) = (0.5, 0.5);
        let prox_f = |v: &[f64]| prox_l1(v, weight_l1);
        let prox_g = |v: &[f64]| prox_l2(v, weight_l2);
        let solution = douglas_rachford(&prox_f, &prox_g, vec![3.0, -2.0, 1.0], 1.0, 500);
        for coord in solution.iter().copied() {
            assert!(coord.abs() < 1.0, "DR solution out of expected range: {}", coord);
        }
    }

    #[test]
    fn test_douglas_rachford_identity_prox() {
        // With prox_g = identity, each DR sweep reduces to z ← prox_f(z), so the
        // iterates are repeated applications of prox_l1(·, 1). The check only
        // requires every coordinate to finish within [-1, 1].
        let prox_f = |v: &[f64]| prox_l1(v, 1.0);
        let solution = douglas_rachford(&prox_f, &prox_id, vec![2.0, -3.0], 1.0, 1000);
        for coord in solution.iter().copied() {
            assert!(coord.abs() <= 1.0 + 1e-8, "not in expected set: {}", coord);
        }
    }

    #[test]
    fn test_dr_tracked_convergence() {
        let prox_f = |v: &[f64]| prox_l1(v, 0.3);
        let prox_g = |v: &[f64]| prox_l2(v, 0.3);
        let report =
            douglas_rachford_tracked(&prox_f, &prox_g, vec![2.0, -1.0], 1.0, 2000, 1e-8);
        assert!(report.converged, "DR should converge within 2000 iters");
        assert!(report.nit < 2000, "DR should converge before max_iter");
    }

    #[test]
    fn test_forward_backward_quadratic() {
        // ∇f(x) = x for f(x) = ½‖x‖². With prox_g = identity the update is
        // x_{k+1} = (1 − α)·x_k, which shrinks geometrically to 0 for α = 0.5.
        let grad_f = |x: &[f64]| x.to_vec();
        let solution = forward_backward(&grad_f, &prox_id, vec![3.0, -2.0], 0.5, 500, 1e-8);
        for coord in solution.iter().copied() {
            assert_abs_diff_eq!(coord, 0.0, epsilon = 1e-4);
        }
    }

    #[test]
    fn test_peaceman_rachford_converges() {
        // Same prox_l2(·, 0.5) operator used for both f and g.
        let prox_both = |v: &[f64]| prox_l2(v, 0.5);
        let solution = peaceman_rachford(&prox_both, &prox_both, vec![2.0, -1.5], 1.0, 500);
        for coord in solution.iter().copied() {
            assert_abs_diff_eq!(coord, 0.0, epsilon = 0.1);
        }
    }

    #[test]
    fn test_dr_split_negative_gamma() {
        // A non-positive gamma must be rejected up front.
        let outcome = dr_split(&prox_id, &prox_id, vec![1.0], -1.0, 10, 1e-6);
        assert!(outcome.is_err());
    }

    #[test]
    fn test_primal_dual_basic() {
        // K = I with prox_l2-based primal and dual proxes; x* = 0 is expected.
        let prox_f = |v: &[f64]| prox_l2(v, 0.5);
        let prox_g_conj = |v: &[f64]| prox_l2(v, 0.5);
        let (x_star, _) = primal_dual_chambolle_pock(
            &prox_f,
            &prox_g_conj,
            &prox_id,
            &prox_id,
            vec![2.0, -1.0],
            vec![0.0, 0.0],
            0.5,
            0.5,
            1.0,
            500,
        );
        for coord in x_star.iter().copied() {
            assert_abs_diff_eq!(coord, 0.0, epsilon = 0.1);
        }
    }
}