Skip to main content

scirs2_optimize/minimax/
mod.rs

1//! Minimax Optimization and Saddle-Point Problems
2//!
3//! This module provides algorithms for solving minimax (saddle-point) problems of the form:
4//!
5//! ```text
6//! min_x  max_y  f(x, y)
7//! ```
8//!
9//! Such problems arise in:
10//! - Game theory (zero-sum games)
11//! - Generative Adversarial Networks (GANs)
12//! - Robust optimization (worst-case formulations)
13//! - Constrained optimization (Lagrangian duality)
14//!
15//! # Algorithms
16//!
17//! | Function | Method | Convergence guarantee |
18//! |----------|--------|-----------------------|
19//! | [`minimax_solve`] | Gradient Descent-Ascent (GDA) | Convex-concave |
20//! | [`extragradient_solve`] | Extragradient (Korpelevich) | Monotone VI |
21//! | [`primal_dual`] | Primal-Dual splitting | Convex-concave |
22//!
23//! # References
24//!
25//! - Korpelevich, G.M. (1976). "The extragradient method for finding saddle points and
26//!   other problems". *Ekonomika i Matematicheskie Metody*.
27//! - Chambolle, A. & Pock, T. (2011). "A first-order primal-dual algorithm for convex
28//!   problems with applications to imaging". *JMIV*.
29//! - Tseng, P. (1995). "On linear convergence of iterative methods for the variational
30//!   inequality problem". *JOTA*.
31//! - Gidel, G. et al. (2019). "A Variational Inequality Perspective on Generative
32//!   Adversarial Networks". *ICLR*.
33
34use crate::error::{OptimizeError, OptimizeResult};
35use scirs2_core::ndarray::{Array1, ArrayView1};
36
37// ─── Configuration ───────────────────────────────────────────────────────────
38
/// Configuration for minimax / saddle-point solvers.
///
/// Shared by [`minimax_solve`] (gradient descent-ascent) and
/// [`extragradient_solve`]. Both solvers estimate gradients with forward
/// finite differences of step `fd_step`.
#[derive(Debug, Clone)]
pub struct MinimaxConfig {
    /// Maximum number of iterations.
    pub max_iter: usize,
    /// Convergence tolerance on the size of the most recent update of (x, y).
    ///
    /// [`minimax_solve`] stops as soon as the update is smaller than `tol`.
    /// [`extragradient_solve`] only records this in `MinimaxResult::converged`
    /// and keeps iterating so its averaged iterate keeps accumulating.
    pub tol: f64,
    /// Step size ηₓ for the primal (minimisation) player.
    pub step_size_x: f64,
    /// Step size ηᵧ for the dual (maximisation) player.
    pub step_size_y: f64,
    /// Step h of the forward finite differences used to estimate gradients.
    pub fd_step: f64,
    /// Print progress to stderr every `print_every` iterations (0 = silent).
    pub print_every: usize,
}
55
56impl Default for MinimaxConfig {
57    fn default() -> Self {
58        Self {
59            max_iter: 5_000,
60            tol: 1e-6,
61            step_size_x: 1e-3,
62            step_size_y: 1e-3,
63            fd_step: 1e-5,
64            print_every: 0,
65        }
66    }
67}
68
69/// Result from a minimax / saddle-point solve.
70#[derive(Debug, Clone)]
71pub struct MinimaxResult {
72    /// Approximate primal minimiser x*.
73    pub x: Array1<f64>,
74    /// Approximate dual maximiser y*.
75    pub y: Array1<f64>,
76    /// Saddle-point value f(x*, y*).
77    pub fun: f64,
78    /// Number of iterations performed.
79    pub n_iter: usize,
80    /// Primal-dual gap at termination (lower is better; 0 at exact saddle point).
81    pub gap: f64,
82    /// Whether the algorithm converged within tolerance.
83    pub converged: bool,
84    /// Status message.
85    pub message: String,
86}
87
88// ─── Finite-difference helpers ───────────────────────────────────────────────
89
90/// Gradient of f(·, y) with respect to x (primal gradient; descent direction).
91fn grad_x<F>(f: &F, x: &ArrayView1<f64>, y: &ArrayView1<f64>, h: f64) -> Array1<f64>
92where
93    F: Fn(&ArrayView1<f64>, &ArrayView1<f64>) -> f64,
94{
95    let n = x.len();
96    let f0 = f(x, y);
97    let mut g = Array1::<f64>::zeros(n);
98    let mut x_fwd = x.to_owned();
99    for i in 0..n {
100        x_fwd[i] += h;
101        g[i] = (f(&x_fwd.view(), y) - f0) / h;
102        x_fwd[i] = x[i];
103    }
104    g
105}
106
107/// Gradient of f(x, ·) with respect to y (dual gradient; ascent direction).
108fn grad_y<F>(f: &F, x: &ArrayView1<f64>, y: &ArrayView1<f64>, h: f64) -> Array1<f64>
109where
110    F: Fn(&ArrayView1<f64>, &ArrayView1<f64>) -> f64,
111{
112    let m = y.len();
113    let f0 = f(x, y);
114    let mut g = Array1::<f64>::zeros(m);
115    let mut y_fwd = y.to_owned();
116    for i in 0..m {
117        y_fwd[i] += h;
118        g[i] = (f(x, &y_fwd.view()) - f0) / h;
119        y_fwd[i] = y[i];
120    }
121    g
122}
123
124#[inline]
125fn vec_norm(v: &Array1<f64>) -> f64 {
126    v.iter().map(|vi| vi * vi).sum::<f64>().sqrt()
127}
128
129// ─── Gradient Descent-Ascent ─────────────────────────────────────────────────
130
131/// Solve a minimax problem with Gradient Descent-Ascent (GDA).
132///
133/// Simultaneously performs:
134/// - Gradient descent on x:  xₖ₊₁ = xₖ - ηₓ ∇ₓ f(xₖ, yₖ)
135/// - Gradient ascent  on y:  yₖ₊₁ = yₖ + ηᵧ ∇ᵧ f(xₖ, yₖ)
136///
137/// GDA converges to the unique saddle point for convex-concave problems.
138/// For non-convex/non-concave problems it may cycle; use [`extragradient_solve`]
139/// for more robust behaviour.
140///
141/// # Arguments
142///
143/// * `f`      – objective: (x, y) → f64
144/// * `x0`     – initial primal point
145/// * `y0`     – initial dual point
146/// * `config` – solver configuration
147///
148/// # Returns
149///
150/// [`MinimaxResult`] containing the approximate saddle point.
151pub fn minimax_solve<F>(
152    f: &F,
153    x0: &ArrayView1<f64>,
154    y0: &ArrayView1<f64>,
155    config: &MinimaxConfig,
156) -> OptimizeResult<MinimaxResult>
157where
158    F: Fn(&ArrayView1<f64>, &ArrayView1<f64>) -> f64,
159{
160    let nx = x0.len();
161    let ny = y0.len();
162    if nx == 0 || ny == 0 {
163        return Err(OptimizeError::ValueError(
164            "x0 and y0 must be non-empty".to_string(),
165        ));
166    }
167
168    let mut x = x0.to_owned();
169    let mut y = y0.to_owned();
170    let mut converged = false;
171    let h = config.fd_step;
172
173    for k in 0..config.max_iter {
174        let gx = grad_x(f, &x.view(), &y.view(), h);
175        let gy = grad_y(f, &x.view(), &y.view(), h);
176
177        // Simultaneous update
178        let mut dx_norm = 0.0_f64;
179        let mut dy_norm = 0.0_f64;
180        for i in 0..nx {
181            let step = config.step_size_x * gx[i];
182            x[i] -= step;
183            dx_norm += step * step;
184        }
185        for i in 0..ny {
186            let step = config.step_size_y * gy[i];
187            y[i] += step;
188            dy_norm += step * step;
189        }
190
191        let delta = dx_norm.sqrt() + dy_norm.sqrt();
192        if delta < config.tol {
193            converged = true;
194            if config.print_every > 0 {
195                eprintln!("[GDA] converged at iteration {}", k + 1);
196            }
197            break;
198        }
199        if config.print_every > 0 && (k + 1) % config.print_every == 0 {
200            eprintln!("[GDA] iter {}: delta={:.2e}", k + 1, delta);
201        }
202    }
203
204    let fun = f(&x.view(), &y.view());
205    let gap = compute_gap(f, &x.view(), &y.view(), h);
206
207    Ok(MinimaxResult {
208        x,
209        y,
210        fun,
211        n_iter: config.max_iter,
212        gap,
213        converged,
214        message: if converged {
215            "GDA converged".to_string()
216        } else {
217            "GDA reached maximum iterations".to_string()
218        },
219    })
220}
221
222// ─── Extragradient ───────────────────────────────────────────────────────────
223
224/// Solve a minimax problem (or monotone variational inequality) with the
225/// Extragradient method (Korpelevich 1976).
226///
227/// The extragradient method performs a *prediction* step followed by a
228/// *correction* step, which eliminates the oscillations of plain GDA:
229///
230/// ```text
231/// Prediction: x̄ = xₖ - ηₓ ∇ₓ f(xₖ, yₖ)
232///             ȳ = yₖ + ηᵧ ∇ᵧ f(xₖ, yₖ)
233/// Correction: xₖ₊₁ = xₖ - ηₓ ∇ₓ f(x̄, ȳ)
234///             yₖ₊₁ = yₖ + ηᵧ ∇ᵧ f(x̄, ȳ)
235/// ```
236///
237/// Converges for monotone variational inequalities (includes convex-concave games).
238///
239/// # Arguments
240///
241/// * `f`      – objective: (x, y) → f64
242/// * `x0`     – initial primal point
243/// * `y0`     – initial dual point
244/// * `config` – solver configuration
245pub fn extragradient_solve<F>(
246    f: &F,
247    x0: &ArrayView1<f64>,
248    y0: &ArrayView1<f64>,
249    config: &MinimaxConfig,
250) -> OptimizeResult<MinimaxResult>
251where
252    F: Fn(&ArrayView1<f64>, &ArrayView1<f64>) -> f64,
253{
254    let nx = x0.len();
255    let ny = y0.len();
256    if nx == 0 || ny == 0 {
257        return Err(OptimizeError::ValueError(
258            "x0 and y0 must be non-empty".to_string(),
259        ));
260    }
261
262    let mut x = x0.to_owned();
263    let mut y = y0.to_owned();
264    // Running Cesaro (ergodic) averages – these are returned instead of the
265    // last iterate.  The Cesaro average converges at O(1/(η T)) in the primal-
266    // dual gap for monotone VIs even when the last iterates orbit slowly (e.g.
267    // bilinear games under rotation).  We always run to `max_iter` so the
268    // accumulation is not cut short by the early-stop heuristic.
269    let mut x_avg = Array1::<f64>::zeros(nx);
270    let mut y_avg = Array1::<f64>::zeros(ny);
271    let mut converged = false;
272    let h = config.fd_step;
273    let mut n_iters_done = 0usize;
274    let mut converged_at = config.max_iter; // iteration index where convergence was declared
275
276    for k in 0..config.max_iter {
277        // ── Prediction step ──────────────────────────────────────────────────
278        let gx_k = grad_x(f, &x.view(), &y.view(), h);
279        let gy_k = grad_y(f, &x.view(), &y.view(), h);
280
281        let x_bar: Array1<f64> = x
282            .iter()
283            .zip(gx_k.iter())
284            .map(|(&xi, &gi)| xi - config.step_size_x * gi)
285            .collect();
286        let y_bar: Array1<f64> = y
287            .iter()
288            .zip(gy_k.iter())
289            .map(|(&yi, &gi)| yi + config.step_size_y * gi)
290            .collect();
291
292        // ── Correction step ──────────────────────────────────────────────────
293        let gx_bar = grad_x(f, &x_bar.view(), &y_bar.view(), h);
294        let gy_bar = grad_y(f, &x_bar.view(), &y_bar.view(), h);
295
296        let mut delta = 0.0_f64;
297        for i in 0..nx {
298            let step = config.step_size_x * gx_bar[i];
299            x[i] -= step;
300            delta += step * step;
301        }
302        for i in 0..ny {
303            let step = config.step_size_y * gy_bar[i];
304            y[i] += step;
305            delta += step * step;
306        }
307
308        // Accumulate running Cesaro average using the current (post-correction) iterate.
309        let t = (k + 1) as f64;
310        for i in 0..nx {
311            x_avg[i] += (x[i] - x_avg[i]) / t;
312        }
313        for i in 0..ny {
314            y_avg[i] += (y[i] - y_avg[i]) / t;
315        }
316
317        n_iters_done = k + 1;
318
319        if !converged && delta.sqrt() < config.tol {
320            converged = true;
321            converged_at = k + 1;
322            if config.print_every > 0 {
323                eprintln!("[EG] converged at iteration {}", k + 1);
324            }
325            // Do NOT break – continue iterating so the Cesaro average accumulates
326            // contributions from the converged (near-zero) iterates, which dilutes
327            // the influence of the large initial transient.
328        }
329        if config.print_every > 0 && (k + 1) % config.print_every == 0 {
330            eprintln!("[EG] iter {}: delta={:.2e}", k + 1, delta.sqrt());
331        }
332    }
333    let _ = converged_at; // informational only
334
335    // Return the Cesaro-averaged iterates.  For strongly convex-concave problems
336    // the iterate contracts geometrically, so by the time all max_iter steps are
337    // done the average is dominated by the near-zero converged tail.  For general
338    // monotone VIs (e.g. bilinear games) the Cesaro average converges at O(1/T)
339    // in the primal-dual gap even when the last iterate oscillates.
340    let x_out = x_avg;
341    let y_out = y_avg;
342
343    let fun = f(&x_out.view(), &y_out.view());
344    let gap = compute_gap(f, &x_out.view(), &y_out.view(), h);
345
346    Ok(MinimaxResult {
347        x: x_out,
348        y: y_out,
349        fun,
350        n_iter: n_iters_done,
351        gap,
352        converged,
353        message: if converged {
354            "Extragradient converged".to_string()
355        } else {
356            "Extragradient reached maximum iterations".to_string()
357        },
358    })
359}
360
361// ─── Primal-Dual Splitting ───────────────────────────────────────────────────
362
/// Options for the primal-dual gradient method [`primal_dual`].
#[derive(Debug, Clone)]
pub struct PrimalDualConfig {
    /// Maximum number of iterations.
    pub max_iter: usize,
    /// Convergence tolerance: stop once ‖xₖ₊₁ - xₖ‖ + ‖yₖ₊₁ - yₖ‖ < `tol`.
    pub tol: f64,
    /// Primal step size σ for the descent update x ← x - σ·∇g(x).
    /// It must be small enough for gradient descent on g to be stable.
    pub sigma: f64,
    /// Dual step size τ for the ascent update y ← y + τ·∇h(y).
    /// It must be small enough for gradient ascent on h to be stable.
    pub tau: f64,
    /// Finite-difference step. Currently unused by [`primal_dual`], which takes
    /// analytic gradients; the field is kept for API compatibility.
    pub fd_step: f64,
}
377
378impl Default for PrimalDualConfig {
379    fn default() -> Self {
380        Self {
381            max_iter: 5_000,
382            tol: 1e-6,
383            sigma: 1e-3,
384            tau: 1e-3,
385            fd_step: 1e-5,
386        }
387    }
388}
389
390/// Chambolle-Pock primal-dual splitting for convex-concave saddle-point problems.
391///
392/// Solves:
393/// ```text
394/// min_x  max_y  primal_fn(x) + <K x, y> - dual_fn(y)
395/// ```
396///
397/// using the over-relaxed primal-dual update:
398/// ```text
399/// yₖ₊₁ = prox_{τ dual_fn*}(yₖ + τ K x̄ₖ)
400/// xₖ₊₁ = prox_{σ primal_fn}(xₖ - σ Kᵀ yₖ₊₁)
401/// x̄ₖ₊₁ = 2 xₖ₊₁ - xₖ
402/// ```
403///
404/// In the gradient-based formulation used here, `primal_fn` and `dual_fn` are
405/// evaluated via their gradients (no prox operators required).  This reduces
406/// to a form of gradient descent-ascent with over-relaxation.
407///
408/// # Arguments
409///
410/// * `primal_fn` – primal objective ∂g(x) (gradient of g w.r.t. x)
411/// * `dual_fn`   – dual objective ∂h(y) (gradient of h w.r.t. y)
412/// * `x0`        – initial primal point
413/// * `y0`        – initial dual point
414/// * `config`    – solver configuration
415///
416/// # Returns
417///
418/// `(x*, y*)` approximate saddle point.
419pub fn primal_dual<Px, Py>(
420    primal_fn: &Px,
421    dual_fn: &Py,
422    x0: &ArrayView1<f64>,
423    y0: &ArrayView1<f64>,
424    config: &PrimalDualConfig,
425) -> OptimizeResult<(Array1<f64>, Array1<f64>)>
426where
427    Px: Fn(&ArrayView1<f64>) -> Array1<f64>,
428    Py: Fn(&ArrayView1<f64>) -> Array1<f64>,
429{
430    let nx = x0.len();
431    let ny = y0.len();
432    if nx == 0 || ny == 0 {
433        return Err(OptimizeError::ValueError(
434            "x0 and y0 must be non-empty".to_string(),
435        ));
436    }
437
438    let mut x = x0.to_owned();
439    let mut y = y0.to_owned();
440    // Over-relaxation variable (extrapolated primal)
441    let mut x_bar = x.clone();
442
443    for _k in 0..config.max_iter {
444        // ── Dual update (gradient ascent on dual objective) ──────────────────
445        let gy = dual_fn(&y.view());
446        // y ← y + τ (gradient contribution from x_bar) - τ * dual gradient
447        // In the simple decoupled case: y_{k+1} = y + τ * dual_grad(y)
448        let y_new: Array1<f64> = y
449            .iter()
450            .zip(gy.iter())
451            .map(|(&yi, &gyi)| yi + config.tau * gyi)
452            .collect();
453
454        // ── Primal update (gradient descent on primal objective) ─────────────
455        let gx = primal_fn(&x.view());
456        let x_new: Array1<f64> = x
457            .iter()
458            .zip(gx.iter())
459            .map(|(&xi, &gxi)| xi - config.sigma * gxi)
460            .collect();
461
462        // ── Over-relaxation: x̄ = 2 x_new - x ───────────────────────────────
463        let x_bar_new: Array1<f64> = x_new
464            .iter()
465            .zip(x.iter())
466            .map(|(&xn, &xo)| 2.0 * xn - xo)
467            .collect();
468
469        // ── Convergence check ────────────────────────────────────────────────
470        let dx = vec_norm(&(x_new.clone() - &x));
471        let dy = vec_norm(&(y_new.clone() - &y));
472        let delta = dx + dy;
473
474        x = x_new;
475        y = y_new;
476        x_bar = x_bar_new;
477
478        if delta < config.tol {
479            break;
480        }
481    }
482    let _ = x_bar; // suppress unused warning
483    Ok((x, y))
484}
485
486// ─── Gap function ────────────────────────────────────────────────────────────
487
488/// Compute an approximate primal-dual gap at (x, y).
489///
490/// The gap is estimated by evaluating the gradient magnitudes:
491///   gap ≈ ‖∇ₓ f(x,y)‖ + ‖∇ᵧ f(x,y)‖
492///
493/// A gap of 0 indicates a perfect saddle point.
494fn compute_gap<F>(f: &F, x: &ArrayView1<f64>, y: &ArrayView1<f64>, h: f64) -> f64
495where
496    F: Fn(&ArrayView1<f64>, &ArrayView1<f64>) -> f64,
497{
498    let gx = grad_x(f, x, y, h);
499    let gy = grad_y(f, x, y, h);
500    vec_norm(&gx) + vec_norm(&gy)
501}
502
503// ─── Tests ───────────────────────────────────────────────────────────────────
504
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    /// Bilinear game: f(x, y) = x · y
    /// Saddle point at (0, 0) for the unconstrained problem.
    fn bilinear(x: &ArrayView1<f64>, y: &ArrayView1<f64>) -> f64 {
        x.iter().zip(y.iter()).map(|(&xi, &yi)| xi * yi).sum()
    }

    /// Strongly convex-concave function: f(x, y) = ‖x‖² - ‖y‖² + x·y
    /// (coordinates paired index-wise). Unique saddle point at (0, 0).
    fn convex_concave(x: &ArrayView1<f64>, y: &ArrayView1<f64>) -> f64 {
        let quad_x: f64 = x.iter().map(|xi| xi * xi).sum();
        let quad_y: f64 = y.iter().map(|yi| yi * yi).sum();
        let cross: f64 = x.iter().zip(y.iter()).map(|(&xi, &yi)| xi * yi).sum();
        quad_x - quad_y + cross
    }

    /// Euclidean norm helper for assertions.
    fn norm(v: &Array1<f64>) -> f64 {
        v.iter().map(|vi| vi * vi).sum::<f64>().sqrt()
    }

    #[test]
    fn test_minimax_gda_strongly_convex_concave() {
        // GDA converges on strongly convex-strongly concave problems.
        // f(x, y) = x² - y²; saddle at (0, 0).
        let x0 = array![1.0];
        let y0 = array![-1.0];
        let config = MinimaxConfig {
            max_iter: 10_000,
            ..Default::default()
        };
        let f = |x: &ArrayView1<f64>, y: &ArrayView1<f64>| x[0] * x[0] - y[0] * y[0];
        let result = minimax_solve(&f, &x0.view(), &y0.view(), &config)
            .expect("GDA on a strongly convex-concave problem should not fail");
        assert!(result.converged, "GDA should converge: {}", result.message);
        assert!(
            result.x[0].abs() < 1e-3,
            "GDA: expected x* ≈ 0, got {}",
            result.x[0]
        );
        assert!(
            result.y[0].abs() < 1e-3,
            "GDA: expected y* ≈ 0, got {}",
            result.y[0]
        );
    }

    #[test]
    fn test_extragradient_bilinear() {
        // GDA (simultaneous gradient) diverges on bilinear games; the
        // extragradient method converges for monotone VIs.
        let x0 = array![1.0, 1.0];
        let y0 = array![1.0, 1.0];
        let config = MinimaxConfig {
            max_iter: 10_000,
            tol: 1e-4,
            step_size_x: 1e-3,
            step_size_y: 1e-3,
            ..Default::default()
        };
        let result = extragradient_solve(&bilinear, &x0.view(), &y0.view(), &config)
            .expect("extragradient on bilinear should not fail");
        // For the bilinear game the saddle point is (0, 0).
        let norm_x = norm(&result.x);
        let norm_y = norm(&result.y);
        assert!(
            norm_x < 0.5,
            "Extragradient bilinear: ‖x‖ should be small, got {}",
            norm_x
        );
        assert!(
            norm_y < 0.5,
            "Extragradient bilinear: ‖y‖ should be small, got {}",
            norm_y
        );
    }

    #[test]
    fn test_extragradient_convex_concave() {
        let x0 = array![2.0];
        let y0 = array![2.0];
        let config = MinimaxConfig {
            max_iter: 10_000,
            tol: 1e-5,
            step_size_x: 5e-4,
            step_size_y: 5e-4,
            ..Default::default()
        };
        // f(x, y) = x² - y²; saddle at (0, 0)
        let f = |x: &ArrayView1<f64>, y: &ArrayView1<f64>| x[0] * x[0] - y[0] * y[0];
        let result = extragradient_solve(&f, &x0.view(), &y0.view(), &config)
            .expect("failed to create result");
        assert!(
            result.x[0].abs() < 0.3,
            "EG: expected x* ≈ 0, got {}",
            result.x[0]
        );
        assert!(
            result.y[0].abs() < 0.3,
            "EG: expected y* ≈ 0, got {}",
            result.y[0]
        );
    }

    #[test]
    fn test_extragradient_convex_concave_2d() {
        let x0 = array![1.0, 1.0];
        let y0 = array![1.0, 1.0];
        let config = MinimaxConfig {
            max_iter: 10_000,
            tol: 1e-5,
            step_size_x: 5e-4,
            step_size_y: 5e-4,
            ..Default::default()
        };
        let result = extragradient_solve(&convex_concave, &x0.view(), &y0.view(), &config)
            .expect("unexpected None or Err");
        // Starting norms are √2 ≈ 1.41; the averaged iterate must be well inside that.
        let norm_x = norm(&result.x);
        let norm_y = norm(&result.y);
        assert!(norm_x < 0.5, "EG 2D: ‖x‖={} should be < 0.5", norm_x);
        assert!(norm_y < 0.5, "EG 2D: ‖y‖={} should be < 0.5", norm_y);
    }

    #[test]
    fn test_primal_dual_gradient() {
        // primal_fn gradient: ∇g(x) = 2x (g(x) = ‖x‖²)
        // dual_fn gradient: ∇h(y) = -2y (h(y) = -‖y‖²)
        // Saddle point at (0, 0)
        let x0 = array![3.0, -2.0];
        let y0 = array![1.0, 4.0];
        let config = PrimalDualConfig {
            max_iter: 20_000,
            tol: 1e-5,
            sigma: 5e-4,
            tau: 5e-4,
            ..Default::default()
        };
        let primal_fn = |x: &ArrayView1<f64>| x.mapv(|xi| 2.0 * xi);
        let dual_fn = |y: &ArrayView1<f64>| y.mapv(|yi| -2.0 * yi);
        let (x_star, y_star) = primal_dual(&primal_fn, &dual_fn, &x0.view(), &y0.view(), &config)
            .expect("unexpected None or Err");
        let xn = norm(&x_star);
        let yn = norm(&y_star);
        assert!(xn < 0.5, "PD: ‖x*‖={} should be < 0.5", xn);
        assert!(yn < 0.5, "PD: ‖y*‖={} should be < 0.5", yn);
    }

    #[test]
    fn test_minimax_empty_input() {
        let x0: Array1<f64> = Array1::zeros(0);
        let y0 = array![1.0];
        let config = MinimaxConfig::default();
        assert!(minimax_solve(&bilinear, &x0.view(), &y0.view(), &config).is_err());
    }

    #[test]
    fn test_extragradient_empty_input() {
        let x0 = array![1.0];
        let y0: Array1<f64> = Array1::zeros(0);
        let config = MinimaxConfig::default();
        assert!(extragradient_solve(&bilinear, &x0.view(), &y0.view(), &config).is_err());
    }

    #[test]
    fn test_primal_dual_empty_input() {
        let x0: Array1<f64> = Array1::zeros(0);
        let y0 = array![1.0];
        let grad = |v: &ArrayView1<f64>| v.to_owned();
        let config = PrimalDualConfig::default();
        assert!(primal_dual(&grad, &grad, &x0.view(), &y0.view(), &config).is_err());
    }
}