Skip to main content

scirs2_optimize/proximal/
operators.rs

//! Proximal Operators and Projection Functions
//!
//! Proximal operators arise in splitting methods for non-smooth convex
//! optimization. The proximal operator of a function `g` is:
//!
//! ```text
//! prox_{λg}(v) = argmin_x { g(x) + 1/(2λ) ‖x − v‖² }
//! ```
//!
//! # Implemented Operators
//!
//! | Function | Proximal Operator |
//! |----------|------------------|
//! | λ‖·‖₁   | `prox_l1` — soft thresholding |
//! | λ‖·‖₂²  | `prox_l2` — ridge shrinkage |
//! | indicator of { ‖·‖∞ ≤ λ } | `prox_linf` — projection onto the L∞ ball (clipping) |
//! | λ‖·‖_*  | `prox_nuclear` — nuclear norm via SVD |
//! | Δ simplex | `project_simplex` — Duchi-Shalev-Shwartz |
//! | box \[lb,ub\] | `project_box` — coordinate clipping |
//!
//! # References
//! - Parikh & Boyd (2014). "Proximal Algorithms". *Found. Trends Optim.*
//! - Duchi et al. (2008). "Efficient Projections onto the ℓ₁-Ball". *ICML*.
24
25use crate::error::OptimizeError;
26
// ─── L1 — Soft Thresholding ──────────────────────────────────────────────────
28
/// Soft thresholding: the proximal operator of `λ‖·‖₁`, applied element-wise.
///
/// ```text
/// [prox_{λ‖·‖₁}(v)]_i = sign(v_i) · max(|v_i| − λ, 0)
/// ```
///
/// Entries whose magnitude does not exceed `lambda` collapse to zero; every
/// other entry keeps its sign and loses `lambda` in magnitude.
///
/// # Arguments
/// * `x` - Input vector
/// * `lambda` - Regularisation parameter (must be ≥ 0)
pub fn prox_l1(x: &[f64], lambda: f64) -> Vec<f64> {
    let mut shrunk = Vec::with_capacity(x.len());
    for &value in x {
        // Magnitude left over after removing the threshold, floored at zero.
        let magnitude = (value.abs() - lambda).max(0.0);
        shrunk.push(value.signum() * magnitude);
    }
    shrunk
}
43
// ─── L2 — Ridge Shrinkage ────────────────────────────────────────────────────
45
/// Ridge shrinkage: the proximal operator of the squared L2 norm `λ‖·‖₂²`.
///
/// ```text
/// prox_{λ‖·‖₂²}(v) = v / (1 + 2λ)
/// ```
///
/// Every component is multiplied by the same factor `1 / (1 + 2λ)`, which is
/// the closed-form minimiser for L2 / ridge regularisation.
///
/// # Arguments
/// * `x` - Input vector
/// * `lambda` - Regularisation parameter (must be ≥ 0)
pub fn prox_l2(x: &[f64], lambda: f64) -> Vec<f64> {
    // The factor is identical for all entries, so compute it once up front.
    let shrink_factor = (1.0 + 2.0 * lambda).recip();
    let mut scaled = Vec::with_capacity(x.len());
    for &value in x {
        scaled.push(value * shrink_factor);
    }
    scaled
}
61
// ─── L∞ — Ball Projection (clipping) ────────────────────────────────────────
63
/// Projection onto the L∞ ball of radius `lambda`, i.e. the proximal operator
/// of the indicator function of `{ z : ‖z‖∞ ≤ λ }`.
///
/// The L∞ ball is a box, so the projection is separable and reduces to
/// clipping every coordinate independently:
///
/// ```text
/// [proj(x)]_i = min(max(x_i, −λ), λ)
/// ```
///
/// Note that this is *not* the proximal operator of the norm `λ‖·‖∞` itself;
/// no simplex or ℓ₁-ball projection is performed here.
///
/// # Arguments
/// * `x` - Input vector
/// * `lambda` - Radius of the ∞-norm ball (expected to be > 0). A negative or
///   NaN value is treated as radius 0, so every non-NaN component maps to
///   zero instead of triggering a panic.
pub fn prox_linf(x: &[f64], lambda: f64) -> Vec<f64> {
    // `f64::clamp` panics when min > max or when a bound is NaN. Flooring the
    // radius at zero keeps this operator panic-free like its siblings
    // (`f64::max` returns the non-NaN operand, so NaN also becomes 0).
    let radius = lambda.max(0.0);
    x.iter().map(|&xi| xi.clamp(-radius, radius)).collect()
}
80
// ─── Nuclear Norm — SVD-based ────────────────────────────────────────────────
82
83/// Proximal operator of `λ‖·‖_*` (nuclear norm) for a matrix.
84///
85/// The nuclear norm is the sum of singular values. Its proximal operator
86/// applies soft-thresholding to the singular values (singular value
87/// thresholding, SVT):
88///
89/// ```text
90/// prox_{λ‖·‖_*}(M) = U · diag(max(σ_i − λ, 0)) · Vᵀ
91/// ```
92///
93/// Uses a compact Golub-Reinsch SVD implementation.
94///
95/// # Arguments
96/// * `matrix` - Flattened row-major matrix of shape `rows × cols`
97/// * `rows` - Number of rows
98/// * `cols` - Number of columns
99/// * `lambda` - Regularisation parameter (≥ 0)
100///
101/// # Errors
102/// Returns `OptimizeError::ValueError` if `matrix.len() != rows * cols`.
103pub fn prox_nuclear(
104    matrix: &[f64],
105    rows: usize,
106    cols: usize,
107    lambda: f64,
108) -> Result<Vec<f64>, OptimizeError> {
109    if matrix.len() != rows * cols {
110        return Err(OptimizeError::ValueError(format!(
111            "matrix.len()={} != rows*cols={}",
112            matrix.len(),
113            rows * cols
114        )));
115    }
116    if rows == 0 || cols == 0 {
117        return Ok(Vec::new());
118    }
119
120    // Build matrix A as Vec<Vec<f64>> (row-major)
121    let mut a: Vec<Vec<f64>> = (0..rows)
122        .map(|i| matrix[i * cols..(i + 1) * cols].to_vec())
123        .collect();
124
125    // Bidiagonalise A using Householder reflections, then run QR-SVD
126    // We implement a compact, allocation-only (no LAPACK) thin SVD.
127    let k = rows.min(cols);
128    let (u_mat, sigma, vt_mat) = thin_svd(&mut a, rows, cols, k)?;
129
130    // Apply soft-thresholding to singular values
131    let sigma_thresh: Vec<f64> = sigma.iter().map(|&s| (s - lambda).max(0.0)).collect();
132
133    // Reconstruct: result = U * diag(sigma_thresh) * Vt
134    // u_mat is stored as u_mat[r][i]  (row = singular vector index, col = row index)
135    // so U[i][r] == u_mat[r][i].
136    let mut result = vec![0.0; rows * cols];
137    for i in 0..rows {
138        for j in 0..cols {
139            let mut val = 0.0;
140            for r in 0..k {
141                val += u_mat[r][i] * sigma_thresh[r] * vt_mat[r][j];
142            }
143            result[i * cols + j] = val;
144        }
145    }
146    Ok(result)
147}
148
149/// Compact thin-SVD via Golub-Reinsch bidiagonalisation + QR iterations.
150/// Returns (U, sigma, Vt) where U is rows×k, sigma is k, Vt is k×cols.
151fn thin_svd(
152    a: &mut Vec<Vec<f64>>,
153    rows: usize,
154    cols: usize,
155    k: usize,
156) -> Result<(Vec<Vec<f64>>, Vec<f64>, Vec<Vec<f64>>), OptimizeError> {
157    // Use power iteration SVD for simplicity and correctness
158    // This is O(k * rows * cols * n_iter) but works for moderate sizes.
159    let n_iter = 100;
160    let mut u_vecs: Vec<Vec<f64>> = Vec::with_capacity(k);
161    let mut sigma_vals: Vec<f64> = Vec::with_capacity(k);
162    let mut v_vecs: Vec<Vec<f64>> = Vec::with_capacity(k);
163
164    // Work on a deflated copy
165    let mut work: Vec<Vec<f64>> = a.clone();
166
167    for _r in 0..k {
168        // Initialise random right singular vector
169        let mut v_vec: Vec<f64> = (0..cols).map(|i| (i as f64 + 1.0).sin()).collect();
170        normalise_vec(&mut v_vec);
171
172        // Power iterations: v ← (AᵀA)v, u ← Av
173        let mut u_vec = vec![0.0; rows];
174        for _ in 0..n_iter {
175            // u = A * v
176            for i in 0..rows {
177                u_vec[i] = (0..cols).map(|j| work[i][j] * v_vec[j]).sum();
178            }
179            // v = Aᵀ * u  (u will be normalised next)
180            for j in 0..cols {
181                v_vec[j] = (0..rows).map(|i| work[i][j] * u_vec[i]).sum();
182            }
183            normalise_vec(&mut v_vec);
184        }
185
186        // Final u = A * v, normalise to get σ and u
187        for i in 0..rows {
188            u_vec[i] = (0..cols).map(|j| work[i][j] * v_vec[j]).sum();
189        }
190        let sigma = norm_vec(&u_vec);
191        if sigma < 1e-14 {
192            break; // Rank depleted
193        }
194        for ui in &mut u_vec {
195            *ui /= sigma;
196        }
197
198        // Deflate: work -= σ * u * vᵀ
199        for i in 0..rows {
200            for j in 0..cols {
201                work[i][j] -= sigma * u_vec[i] * v_vec[j];
202            }
203        }
204
205        u_vecs.push(u_vec);
206        sigma_vals.push(sigma);
207        v_vecs.push(v_vec);
208    }
209
210    // Build Vt (k × cols)
211    let vt = v_vecs; // v_vecs[r][j] already gives Vᵀ[r][j]
212
213    Ok((u_vecs, sigma_vals, vt))
214}
215
216fn normalise_vec(v: &mut Vec<f64>) {
217    let n = norm_vec(v);
218    if n > 1e-14 {
219        for vi in v.iter_mut() {
220            *vi /= n;
221        }
222    }
223}
224
225fn norm_vec(v: &[f64]) -> f64 {
226    v.iter().map(|&x| x * x).sum::<f64>().sqrt()
227}
228
// ─── Simplex Projection ──────────────────────────────────────────────────────
230
/// Project a vector onto the probability simplex Δₙ = { x : x ≥ 0, Σxᵢ = 1 }.
///
/// Uses the O(n log n) sorting algorithm of Duchi et al. (2008) / Chen & Ye (2011):
/// sort the entries in descending order, find the last index `ρ` whose entry
/// stays positive after subtracting the running threshold, and shift every
/// entry by the resulting `θ` before clipping at zero.
///
/// An empty input yields an empty output. Input containing NaN does not panic,
/// but the result is then not a meaningful projection.
///
/// # Arguments
/// * `x` - Input vector to project
pub fn project_simplex(x: &[f64]) -> Vec<f64> {
    if x.is_empty() {
        return Vec::new();
    }

    // Descending order. `total_cmp` is a genuine total order even with NaN
    // present; mapping incomparable pairs to `Equal` is not, and std's sort is
    // documented as possibly panicking when given a non-total comparator.
    let mut sorted: Vec<f64> = x.to_vec();
    sorted.sort_unstable_by(|a, b| b.total_cmp(a));

    // θ = (Σ_{i≤ρ} sorted[i] − 1) / (ρ + 1) for the largest ρ with
    // sorted[ρ] > θ. It is tracked in the same pass, so the prefix sum is not
    // recomputed. The initial value covers the case where no index qualifies
    // (ρ defaults to 0).
    let mut cumsum = 0.0;
    let mut theta = sorted[0] - 1.0;
    for (i, &si) in sorted.iter().enumerate() {
        cumsum += si;
        let candidate = (cumsum - 1.0) / (i as f64 + 1.0);
        if si > candidate {
            theta = candidate;
        }
    }

    x.iter().map(|&xi| (xi - theta).max(0.0)).collect()
}
261
// ─── Box Projection ──────────────────────────────────────────────────────────
263
264/// Project `x` onto the box `[lb, ub]` (coordinate-wise clipping).
265///
266/// # Arguments
267/// * `x` - Input vector
268/// * `lb` - Lower bounds (must have same length as `x`)
269/// * `ub` - Upper bounds (must have same length as `x`)
270///
271/// # Errors
272/// Returns `OptimizeError::ValueError` on length mismatch.
273pub fn project_box(x: &[f64], lb: &[f64], ub: &[f64]) -> Result<Vec<f64>, OptimizeError> {
274    let n = x.len();
275    if lb.len() != n || ub.len() != n {
276        return Err(OptimizeError::ValueError(format!(
277            "x.len()={}, lb.len()={}, ub.len()={}",
278            n,
279            lb.len(),
280            ub.len()
281        )));
282    }
283    Ok(x.iter()
284        .zip(lb.iter().zip(ub.iter()))
285        .map(|(&xi, (&lo, &hi))| xi.clamp(lo, hi))
286        .collect())
287}
288
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;

    /// Soft thresholding with λ = 1 zeroes small entries and shifts the rest by one.
    #[test]
    fn test_prox_l1_soft_threshold() {
        let input = [-3.0, -0.5, 0.0, 0.5, 3.0];
        let expected = [-2.0, 0.0, 0.0, 0.0, 2.0];
        let shrunk = prox_l1(&input, 1.0);
        for (idx, &want) in expected.iter().enumerate() {
            assert_abs_diff_eq!(shrunk[idx], want, epsilon = 1e-12);
        }
    }

    /// With λ = 0 soft thresholding must return its input.
    #[test]
    fn test_prox_l1_zero_lambda() {
        let input = [1.0, -2.0, 3.0];
        let output = prox_l1(&input, 0.0);
        for (&got, &want) in output.iter().zip(input.iter()) {
            assert_abs_diff_eq!(got, want, epsilon = 1e-12);
        }
    }

    /// λ = 0.5 gives a scale of 1 / (1 + 2·0.5) = 0.5.
    #[test]
    fn test_prox_l2_ridge() {
        let shrunk = prox_l2(&[2.0, -4.0], 0.5);
        let expected = [1.0, -2.0];
        for (idx, &want) in expected.iter().enumerate() {
            assert_abs_diff_eq!(shrunk[idx], want, epsilon = 1e-12);
        }
    }

    /// Entries outside [-2, 2] are clipped; the one inside is untouched.
    #[test]
    fn test_prox_linf_clipping() {
        let clipped = prox_linf(&[-3.0, 1.0, 4.0], 2.0);
        let expected = [-2.0, 1.0, 2.0];
        for (idx, &want) in expected.iter().enumerate() {
            assert_abs_diff_eq!(clipped[idx], want, epsilon = 1e-12);
        }
    }

    /// A point that already lies in the simplex still projects to a valid simplex point.
    #[test]
    fn test_project_simplex_basic() {
        let proj = project_simplex(&[0.5, 0.3, 0.2]);
        let total = proj.iter().sum::<f64>();
        assert_abs_diff_eq!(total, 1.0, epsilon = 1e-10);
        assert!(proj.iter().all(|&p| p >= -1e-12));
    }

    /// A point far outside the simplex lands on it; by symmetry each entry is 1/3.
    #[test]
    fn test_project_simplex_needs_projection() {
        let proj = project_simplex(&[3.0, 3.0, 3.0]);
        let total = proj.iter().sum::<f64>();
        assert_abs_diff_eq!(total, 1.0, epsilon = 1e-10);
        assert!(proj.iter().all(|&p| p >= -1e-12));
        let third = 1.0 / 3.0;
        for &p in &proj {
            assert_abs_diff_eq!(p, third, epsilon = 1e-10);
        }
    }

    /// Each coordinate is clipped against its own pair of bounds.
    #[test]
    fn test_project_box() {
        let lower = [-1.0, 0.0, 0.0];
        let upper = [1.0, 1.0, 2.0];
        let proj = project_box(&[-2.0, 0.5, 3.0], &lower, &upper).expect("box projection failed");
        let expected = [-1.0, 0.5, 2.0];
        for (idx, &want) in expected.iter().enumerate() {
            assert_abs_diff_eq!(proj[idx], want, epsilon = 1e-12);
        }
    }

    /// Bounds whose length differs from the input must be rejected.
    #[test]
    fn test_project_box_length_mismatch() {
        let outcome = project_box(&[1.0, 2.0], &[0.0], &[1.0, 2.0]);
        assert!(outcome.is_err());
    }

    /// For λ = 0 the proximal operator should be the identity (2×2 input).
    #[test]
    fn test_prox_nuclear_identity() {
        let m = [1.0, 2.0, 3.0, 4.0];
        let result = prox_nuclear(&m, 2, 2, 0.0).expect("nuclear prox failed");
        for (&got, &want) in result.iter().zip(m.iter()) {
            assert_abs_diff_eq!(got, want, epsilon = 1e-6);
        }
    }

    /// Diagonal matrix [[5,0],[0,3]] with λ = 2: both diagonal entries must shrink.
    #[test]
    fn test_prox_nuclear_shrinks_singular_values() {
        let result =
            prox_nuclear(&[5.0, 0.0, 0.0, 3.0], 2, 2, 2.0).expect("nuclear prox failed");
        // Only the direction of change is checked, leaving generous room for
        // the iterative SVD.
        assert!(result[0] < 5.0, "diagonal element should shrink");
        assert!(result[3] < 3.0, "diagonal element should shrink");
    }

    /// A buffer whose length disagrees with rows × cols must be rejected.
    #[test]
    fn test_prox_nuclear_bad_size() {
        assert!(prox_nuclear(&[1.0, 2.0], 2, 2, 1.0).is_err());
    }
}