scirs2_optimize/proximal/
operators.rs1use crate::error::OptimizeError;
26
/// Soft-thresholding: the proximal operator of `lambda * ||x||_1`.
///
/// Every entry keeps its sign while its magnitude is reduced by `lambda`;
/// for `lambda >= 0`, entries with `|x_i| <= lambda` collapse to zero.
///
/// # Arguments
/// * `x` - input vector
/// * `lambda` - shrinkage threshold
///
/// # Returns
/// A freshly allocated vector with the same length as `x`.
pub fn prox_l1(x: &[f64], lambda: f64) -> Vec<f64> {
    let mut shrunk = Vec::with_capacity(x.len());
    for &value in x {
        // Magnitude left after shrinking, floored at zero.
        let magnitude = (value.abs() - lambda).max(0.0);
        shrunk.push(value.signum() * magnitude);
    }
    shrunk
}
43
/// Proximal operator of the squared L2 penalty `lambda * ||x||_2^2`.
///
/// The minimiser of `lambda * ||z||^2 + 0.5 * ||z - x||^2` is a uniform
/// rescaling of `x` by `1 / (1 + 2 * lambda)`.
///
/// # Arguments
/// * `x` - input vector
/// * `lambda` - penalty weight
///
/// # Returns
/// A freshly allocated vector with the same length as `x`.
pub fn prox_l2(x: &[f64], lambda: f64) -> Vec<f64> {
    let shrink = 1.0 / (1.0 + 2.0 * lambda);
    let mut scaled = x.to_vec();
    for value in &mut scaled {
        *value *= shrink;
    }
    scaled
}
61
/// Clips every entry of `x` to the interval `[-lambda, lambda]`.
///
/// This is the entry-wise Euclidean projection onto the L-infinity ball of
/// radius `lambda`.
///
/// # Arguments
/// * `x` - input vector
/// * `lambda` - ball radius; a negative or NaN value is treated as radius `0`,
///   so every finite entry maps to zero
///
/// # Returns
/// A freshly allocated vector with the same length as `x`. NaN entries of `x`
/// stay NaN.
pub fn prox_linf(x: &[f64], lambda: f64) -> Vec<f64> {
    // `f64::clamp` panics when min > max or when a bound is NaN. Flooring the
    // radius at zero (`NaN.max(0.0)` is `0.0`) keeps this function panic-free.
    let radius = lambda.max(0.0);
    x.iter().map(|&xi| xi.clamp(-radius, radius)).collect()
}
80
81pub fn prox_nuclear(
104 matrix: &[f64],
105 rows: usize,
106 cols: usize,
107 lambda: f64,
108) -> Result<Vec<f64>, OptimizeError> {
109 if matrix.len() != rows * cols {
110 return Err(OptimizeError::ValueError(format!(
111 "matrix.len()={} != rows*cols={}",
112 matrix.len(),
113 rows * cols
114 )));
115 }
116 if rows == 0 || cols == 0 {
117 return Ok(Vec::new());
118 }
119
120 let mut a: Vec<Vec<f64>> = (0..rows)
122 .map(|i| matrix[i * cols..(i + 1) * cols].to_vec())
123 .collect();
124
125 let k = rows.min(cols);
128 let (u_mat, sigma, vt_mat) = thin_svd(&mut a, rows, cols, k)?;
129
130 let sigma_thresh: Vec<f64> = sigma.iter().map(|&s| (s - lambda).max(0.0)).collect();
132
133 let mut result = vec![0.0; rows * cols];
137 for i in 0..rows {
138 for j in 0..cols {
139 let mut val = 0.0;
140 for r in 0..k {
141 val += u_mat[r][i] * sigma_thresh[r] * vt_mat[r][j];
142 }
143 result[i * cols + j] = val;
144 }
145 }
146 Ok(result)
147}
148
149fn thin_svd(
152 a: &mut Vec<Vec<f64>>,
153 rows: usize,
154 cols: usize,
155 k: usize,
156) -> Result<(Vec<Vec<f64>>, Vec<f64>, Vec<Vec<f64>>), OptimizeError> {
157 let n_iter = 100;
160 let mut u_vecs: Vec<Vec<f64>> = Vec::with_capacity(k);
161 let mut sigma_vals: Vec<f64> = Vec::with_capacity(k);
162 let mut v_vecs: Vec<Vec<f64>> = Vec::with_capacity(k);
163
164 let mut work: Vec<Vec<f64>> = a.clone();
166
167 for _r in 0..k {
168 let mut v_vec: Vec<f64> = (0..cols).map(|i| (i as f64 + 1.0).sin()).collect();
170 normalise_vec(&mut v_vec);
171
172 let mut u_vec = vec![0.0; rows];
174 for _ in 0..n_iter {
175 for i in 0..rows {
177 u_vec[i] = (0..cols).map(|j| work[i][j] * v_vec[j]).sum();
178 }
179 for j in 0..cols {
181 v_vec[j] = (0..rows).map(|i| work[i][j] * u_vec[i]).sum();
182 }
183 normalise_vec(&mut v_vec);
184 }
185
186 for i in 0..rows {
188 u_vec[i] = (0..cols).map(|j| work[i][j] * v_vec[j]).sum();
189 }
190 let sigma = norm_vec(&u_vec);
191 if sigma < 1e-14 {
192 break; }
194 for ui in &mut u_vec {
195 *ui /= sigma;
196 }
197
198 for i in 0..rows {
200 for j in 0..cols {
201 work[i][j] -= sigma * u_vec[i] * v_vec[j];
202 }
203 }
204
205 u_vecs.push(u_vec);
206 sigma_vals.push(sigma);
207 v_vecs.push(v_vec);
208 }
209
210 let vt = v_vecs; Ok((u_vecs, sigma_vals, vt))
214}
215
216fn normalise_vec(v: &mut Vec<f64>) {
217 let n = norm_vec(v);
218 if n > 1e-14 {
219 for vi in v.iter_mut() {
220 *vi /= n;
221 }
222 }
223}
224
/// Euclidean (L2) norm of `v`: the square root of the sum of squared entries.
fn norm_vec(v: &[f64]) -> f64 {
    let sum_of_squares: f64 = v.iter().map(|component| component * component).sum();
    sum_of_squares.sqrt()
}
228
/// Euclidean projection of `x` onto the probability simplex
/// `{ z : z_i >= 0, sum(z) = 1 }`.
///
/// Uses the sort-based algorithm: with the entries sorted in descending
/// order, find the last index `rho` whose entry exceeds
/// `(prefix_sum - 1) / (rho + 1)`. That quotient is the threshold `theta`,
/// and the projection is `max(x_i - theta, 0)`.
///
/// # Arguments
/// * `x` - input vector
///
/// # Returns
/// A freshly allocated vector with the same length as `x`; empty for empty
/// input. If `x` contains NaN the result is not a meaningful projection, but
/// the function does not panic.
pub fn project_simplex(x: &[f64]) -> Vec<f64> {
    if x.is_empty() {
        return Vec::new();
    }

    let mut sorted = x.to_vec();
    // `total_cmp` is a genuine total order even with NaN present. A comparator
    // that maps incomparable pairs to `Equal` is not, and the slice sort
    // routines document that they may panic in that case.
    sorted.sort_unstable_by(|a, b| b.total_cmp(a));

    // Track the threshold of the last index that satisfies the condition. The
    // initial value is the threshold for index 0, used when none qualifies.
    let mut prefix_sum = 0.0;
    let mut theta = sorted[0] - 1.0;
    for (i, &si) in sorted.iter().enumerate() {
        prefix_sum += si;
        let candidate = (prefix_sum - 1.0) / (i as f64 + 1.0);
        if si > candidate {
            theta = candidate;
        }
    }

    x.iter().map(|&xi| (xi - theta).max(0.0)).collect()
}
261
262pub fn project_box(x: &[f64], lb: &[f64], ub: &[f64]) -> Result<Vec<f64>, OptimizeError> {
274 let n = x.len();
275 if lb.len() != n || ub.len() != n {
276 return Err(OptimizeError::ValueError(format!(
277 "x.len()={}, lb.len()={}, ub.len()={}",
278 n,
279 lb.len(),
280 ub.len()
281 )));
282 }
283 Ok(x.iter()
284 .zip(lb.iter().zip(ub.iter()))
285 .map(|(&xi, (&lo, &hi))| xi.clamp(lo, hi))
286 .collect())
287}
288
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;

    /// Compares `actual[i]` with each `expected[i]` under an absolute tolerance.
    fn assert_leading_entries(actual: &[f64], expected: &[f64], tol: f64) {
        for (idx, &want) in expected.iter().enumerate() {
            assert_abs_diff_eq!(actual[idx], want, epsilon = tol);
        }
    }

    /// Checks that the entries sum to one and that none is meaningfully negative.
    fn assert_on_simplex(point: &[f64]) {
        let total: f64 = point.iter().sum();
        assert_abs_diff_eq!(total, 1.0, epsilon = 1e-10);
        assert!(point.iter().all(|&v| v >= -1e-12));
    }

    // Soft-thresholding with threshold 1 shrinks large entries and zeroes small ones.
    #[test]
    fn test_prox_l1_soft_threshold() {
        let shrunk = prox_l1(&[-3.0, -0.5, 0.0, 0.5, 3.0], 1.0);
        assert_leading_entries(&shrunk, &[-2.0, 0.0, 0.0, 0.0, 2.0], 1e-12);
    }

    // A zero threshold must return the input unchanged.
    #[test]
    fn test_prox_l1_zero_lambda() {
        let input = [1.0, -2.0, 3.0];
        let output = prox_l1(&input, 0.0);
        output.iter().zip(&input).for_each(|(got, want)| {
            assert_abs_diff_eq!(*got, *want, epsilon = 1e-12);
        });
    }

    // With lambda = 0.5 the scale factor is 1 / (1 + 2 * 0.5) = 0.5.
    #[test]
    fn test_prox_l2_ridge() {
        let scaled = prox_l2(&[2.0, -4.0], 0.5);
        assert_leading_entries(&scaled, &[1.0, -2.0], 1e-12);
    }

    // Entries outside [-2, 2] are clipped; the inner entry is kept.
    #[test]
    fn test_prox_linf_clipping() {
        let clipped = prox_linf(&[-3.0, 1.0, 4.0], 2.0);
        assert_leading_entries(&clipped, &[-2.0, 1.0, 2.0], 1e-12);
    }

    // A point that already lies on the simplex must still satisfy its constraints.
    #[test]
    fn test_project_simplex_basic() {
        let projected = project_simplex(&[0.5, 0.3, 0.2]);
        assert_on_simplex(&projected);
    }

    // A uniform off-simplex point projects onto the uniform distribution.
    #[test]
    fn test_project_simplex_needs_projection() {
        let projected = project_simplex(&[3.0, 3.0, 3.0]);
        assert_on_simplex(&projected);
        for &entry in &projected {
            assert_abs_diff_eq!(entry, 1.0 / 3.0, epsilon = 1e-10);
        }
    }

    // Each coordinate is clamped to its own [lb, ub] interval.
    #[test]
    fn test_project_box() {
        let lower = [-1.0, 0.0, 0.0];
        let upper = [1.0, 1.0, 2.0];
        let projected =
            project_box(&[-2.0, 0.5, 3.0], &lower, &upper).expect("box projection failed");
        assert_leading_entries(&projected, &[-1.0, 0.5, 2.0], 1e-12);
    }

    // Bounds of the wrong length must be rejected.
    #[test]
    fn test_project_box_length_mismatch() {
        let outcome = project_box(&[1.0, 2.0], &[0.0], &[1.0, 2.0]);
        assert!(outcome.is_err());
    }

    // With lambda = 0 no singular value changes, so the matrix is reproduced.
    #[test]
    fn test_prox_nuclear_identity() {
        let matrix = [1.0, 2.0, 3.0, 4.0];
        let rebuilt = prox_nuclear(&matrix, 2, 2, 0.0).expect("nuclear prox failed");
        rebuilt.iter().zip(&matrix).for_each(|(got, want)| {
            assert_abs_diff_eq!(*got, *want, epsilon = 1e-6);
        });
    }

    // For a diagonal matrix, thresholding must reduce both diagonal entries.
    #[test]
    fn test_prox_nuclear_shrinks_singular_values() {
        let shrunk = prox_nuclear(&[5.0, 0.0, 0.0, 3.0], 2, 2, 2.0).expect("nuclear prox failed");
        assert!(shrunk[0] < 5.0, "diagonal element should shrink");
        assert!(shrunk[3] < 3.0, "diagonal element should shrink");
    }

    // Two values cannot fill a 2x2 matrix.
    #[test]
    fn test_prox_nuclear_bad_size() {
        assert!(prox_nuclear(&[1.0, 2.0], 2, 2, 1.0).is_err());
    }
}