1use crate::error::{OptimizeError, OptimizeResult};
26use scirs2_core::random::{rngs::StdRng, RngExt, SeedableRng};
27
/// Rule used to pick which coordinate to update at each iteration.
#[non_exhaustive]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CoordSelectionRule {
    /// Pick a coordinate uniformly at random.
    Random,
    /// Sweep coordinates in order 0, 1, ..., n-1, repeating.
    Cyclic,
    /// Pick the coordinate with the largest gradient magnitude (Gauss-Southwell).
    GreedyGradient,
    /// Gauss-Southwell-style rule; currently selected identically to
    /// `GreedyGradient` (see `select_coordinate`).
    GreedyGauss,
    /// Pick the largest-gradient coordinate from a random subset of size
    /// `greedy_subset_size`.
    StochasticGreedy,
}
43
44impl Default for CoordSelectionRule {
45 fn default() -> Self {
46 CoordSelectionRule::Cyclic
47 }
48}
49
/// Configuration shared by the coordinate-descent solvers in this module.
#[derive(Clone, Debug)]
pub struct CoordDescentConfig {
    /// How the coordinate to update is chosen on each iteration.
    pub selection: CoordSelectionRule,
    /// Maximum number of outer iterations.
    pub max_iter: usize,
    /// Convergence tolerance on the (sub)gradient norm.
    pub tol: f64,
    /// If true, use Armijo backtracking line search; otherwise take a fixed
    /// gradient step of `step_size`.
    pub line_search: bool,
    /// Fixed step size used when `line_search` is false.
    pub step_size: f64,
    /// Number of coordinates sampled by the `StochasticGreedy` rule.
    pub greedy_subset_size: usize,
    /// Seed for the RNG backing the random selection rules.
    pub seed: u64,
    /// Armijo sufficient-decrease constant (commonly called c1).
    pub armijo_c1: f64,
    /// Initial trial step length for the Armijo backtracking search.
    pub armijo_alpha0: f64,
    /// Backtracking shrink factor; each rejected trial multiplies alpha by this.
    pub armijo_tau: f64,
    /// Maximum number of backtracking shrinks per line search.
    pub armijo_max_iter: usize,
}
77
78impl Default for CoordDescentConfig {
79 fn default() -> Self {
80 Self {
81 selection: CoordSelectionRule::Cyclic,
82 max_iter: 1000,
83 tol: 1e-6,
84 line_search: true,
85 step_size: 1e-3,
86 greedy_subset_size: 10,
87 seed: 42,
88 armijo_c1: 1e-4,
89 armijo_alpha0: 1.0,
90 armijo_tau: 0.5,
91 armijo_max_iter: 50,
92 }
93 }
94}
95
/// Outcome of a coordinate-descent run.
#[derive(Debug, Clone)]
pub struct CoordDescentResult {
    /// Final iterate.
    pub x: Vec<f64>,
    /// Objective value at `x`.
    pub f_val: f64,
    /// Number of outer iterations performed before returning.
    pub n_iter: usize,
    /// Whether the tolerance on the (sub)gradient norm was met.
    pub converged: bool,
    /// Gradient norm at `x`; for the proximal solver this is the norm of the
    /// minimal-norm composite subgradient.
    pub gradient_norm: f64,
}
110
111fn armijo_line_search<F>(
117 f: &F,
118 x: &[f64],
119 i: usize,
120 direction: f64,
121 grad_i: f64,
122 f_x: f64,
123 config: &CoordDescentConfig,
124) -> f64
125where
126 F: Fn(&[f64]) -> f64,
127{
128 let mut alpha = config.armijo_alpha0;
129 let mut x_trial = x.to_vec();
130
131 for _ in 0..config.armijo_max_iter {
132 x_trial[i] = x[i] + alpha * direction;
133 let f_trial = f(&x_trial);
134 if f_trial <= f_x + config.armijo_c1 * alpha * grad_i * direction {
135 return alpha;
136 }
137 alpha *= config.armijo_tau;
138 }
139 alpha
140}
141
/// Golden-section search for a 1-D minimum of `f` along coordinate `i`,
/// restricted to the bracket `[lo, hi]` with all other coordinates held
/// fixed at their values in `x`.
///
/// Returns `(argmin, f(argmin))`, where the argmin is the midpoint of the
/// final bracket once it shrinks below tolerance (or the iteration cap hits).
fn golden_section_1d<F>(f: &F, x: &[f64], i: usize, lo: f64, hi: f64) -> (f64, f64)
where
    F: Fn(&[f64]) -> f64,
{
    // 1/phi: keeping the probes in golden ratio lets each iteration reuse
    // one previous evaluation, so only one new call to `f` is needed.
    const PHI: f64 = 0.618_033_988_749_895;
    const MAX_ITER: usize = 100;
    const TOL: f64 = 1e-10;

    let (mut a, mut b) = (lo, hi);
    let mut c = b - PHI * (b - a);
    let mut d = a + PHI * (b - a);

    let mut probe_c = x.to_vec();
    let mut probe_d = x.to_vec();
    probe_c[i] = c;
    probe_d[i] = d;
    let mut f_c = f(&probe_c);
    let mut f_d = f(&probe_d);

    let mut iter = 0;
    while iter < MAX_ITER && (b - a).abs() >= TOL {
        if f_c < f_d {
            // Minimum lies in [a, d]: shrink from the right.
            b = d;
            d = c;
            f_d = f_c;
            c = b - PHI * (b - a);
            probe_c[i] = c;
            f_c = f(&probe_c);
        } else {
            // Minimum lies in [c, b]: shrink from the left.
            a = c;
            c = d;
            f_c = f_d;
            d = a + PHI * (b - a);
            probe_d[i] = d;
            f_d = f(&probe_d);
        }
        iter += 1;
    }

    // Report the bracket midpoint; reuse a probe vector for the final eval.
    let best = 0.5 * (a + b);
    probe_c[i] = best;
    (best, f(&probe_c))
}
189
190fn select_coordinate(
192 rule: &CoordSelectionRule,
193 n: usize,
194 grad: &[f64],
195 cycle_idx: usize,
196 rng: &mut StdRng,
197 subset_size: usize,
198) -> usize {
199 match rule {
200 CoordSelectionRule::Random => rng.random_range(0..n),
201 CoordSelectionRule::Cyclic => cycle_idx % n,
202 CoordSelectionRule::GreedyGradient | CoordSelectionRule::GreedyGauss => grad
203 .iter()
204 .enumerate()
205 .max_by(|(_, a), (_, b)| {
206 a.abs()
207 .partial_cmp(&b.abs())
208 .unwrap_or(std::cmp::Ordering::Equal)
209 })
210 .map(|(i, _)| i)
211 .unwrap_or(0),
212 CoordSelectionRule::StochasticGreedy => {
213 let actual_size = subset_size.min(n);
214 let mut best_i = 0;
215 let mut best_val = f64::NEG_INFINITY;
216 for _ in 0..actual_size {
217 let i = rng.random_range(0..n);
218 let v = grad[i].abs();
219 if v > best_val {
220 best_val = v;
221 best_i = i;
222 }
223 }
224 best_i
225 }
226 _ => cycle_idx % n,
227 }
228}
229
230pub fn coordinate_descent<F, G>(
245 f: F,
246 grad: G,
247 x0: &[f64],
248 config: &CoordDescentConfig,
249) -> OptimizeResult<CoordDescentResult>
250where
251 F: Fn(&[f64]) -> f64,
252 G: Fn(&[f64]) -> Vec<f64>,
253{
254 let n = x0.len();
255 if n == 0 {
256 return Err(OptimizeError::InvalidInput(
257 "x0 must be non-empty".to_string(),
258 ));
259 }
260
261 let mut x = x0.to_vec();
262 let mut rng = StdRng::seed_from_u64(config.seed);
263 let mut cycle_idx: usize = 0;
264
265 for iter in 0..config.max_iter {
266 let g = grad(&x);
267
268 let gnorm = g.iter().map(|v| v * v).sum::<f64>().sqrt();
270 if gnorm < config.tol {
271 let f_val = f(&x);
272 return Ok(CoordDescentResult {
273 x,
274 f_val,
275 n_iter: iter,
276 converged: true,
277 gradient_norm: gnorm,
278 });
279 }
280
281 let i = select_coordinate(
283 &config.selection,
284 n,
285 &g,
286 cycle_idx,
287 &mut rng,
288 config.greedy_subset_size,
289 );
290 cycle_idx = cycle_idx.wrapping_add(1);
291
292 let gi = g[i];
293 if gi.abs() < f64::EPSILON * 100.0 {
294 continue;
295 }
296
297 if config.line_search {
298 let direction = -gi.signum();
300 let f_x = f(&x);
301 let alpha = armijo_line_search(&f, &x, i, direction, gi, f_x, config);
302 x[i] += alpha * direction;
303 } else {
304 x[i] -= config.step_size * gi;
305 }
306 }
307
308 let g_final = grad(&x);
310 let gnorm = g_final.iter().map(|v| v * v).sum::<f64>().sqrt();
311 let f_val = f(&x);
312
313 Ok(CoordDescentResult {
314 x,
315 f_val,
316 n_iter: config.max_iter,
317 converged: gnorm < config.tol,
318 gradient_norm: gnorm,
319 })
320}
321
/// Soft-thresholding (shrinkage) operator — the proximal map of
/// `threshold * |.|`: shrinks `u` toward zero by `threshold`, clamping
/// values inside `[-threshold, threshold]` to exactly zero.
#[inline]
fn soft_threshold(u: f64, threshold: f64) -> f64 {
    match u {
        v if v > threshold => v - threshold,
        v if v < -threshold => v + threshold,
        _ => 0.0,
    }
}
337
338pub fn proximal_coord_descent<F, G>(
357 f: F,
358 grad_f: G,
359 lambda: f64,
360 x0: &[f64],
361 config: &CoordDescentConfig,
362) -> OptimizeResult<CoordDescentResult>
363where
364 F: Fn(&[f64]) -> f64,
365 G: Fn(&[f64]) -> Vec<f64>,
366{
367 if lambda < 0.0 {
368 return Err(OptimizeError::InvalidParameter(
369 "lambda must be non-negative".to_string(),
370 ));
371 }
372 let n = x0.len();
373 if n == 0 {
374 return Err(OptimizeError::InvalidInput(
375 "x0 must be non-empty".to_string(),
376 ));
377 }
378
379 let mut x = x0.to_vec();
380 let mut rng = StdRng::seed_from_u64(config.seed);
381 let mut cycle_idx: usize = 0;
382
383 for iter in 0..config.max_iter {
384 let g = grad_f(&x);
385
386 let gnorm_sq: f64 = g
388 .iter()
389 .enumerate()
390 .map(|(i, &gi)| {
391 let composite = if x[i].abs() > f64::EPSILON {
394 gi + lambda * x[i].signum()
395 } else {
396 let abs_gi = gi.abs();
397 if abs_gi > lambda {
398 abs_gi - lambda
399 } else {
400 0.0
401 }
402 };
403 composite * composite
404 })
405 .sum();
406 let gnorm = gnorm_sq.sqrt();
407
408 if gnorm < config.tol {
409 let f_val = f(&x);
410 return Ok(CoordDescentResult {
411 x,
412 f_val,
413 n_iter: iter,
414 converged: true,
415 gradient_norm: gnorm,
416 });
417 }
418
419 let i = select_coordinate(
421 &config.selection,
422 n,
423 &g,
424 cycle_idx,
425 &mut rng,
426 config.greedy_subset_size,
427 );
428 cycle_idx = cycle_idx.wrapping_add(1);
429
430 let step = if config.line_search {
432 let direction = -g[i].signum();
434 let f_x = f(&x);
435 armijo_line_search(&f, &x, i, direction, g[i], f_x, config)
436 } else {
437 config.step_size
438 };
439
440 let u = x[i] - step * g[i];
442 x[i] = soft_threshold(u, lambda * step);
443 }
444
445 let g_final = grad_f(&x);
446 let gnorm: f64 = g_final
447 .iter()
448 .enumerate()
449 .map(|(i, &gi)| {
450 let c = if x[i].abs() > f64::EPSILON {
451 gi + lambda * x[i].signum()
452 } else {
453 let abs_gi = gi.abs();
454 if abs_gi > lambda {
455 abs_gi - lambda
456 } else {
457 0.0
458 }
459 };
460 c * c
461 })
462 .sum::<f64>()
463 .sqrt();
464 let f_val = f(&x);
465
466 Ok(CoordDescentResult {
467 x,
468 f_val,
469 n_iter: config.max_iter,
470 converged: gnorm < config.tol,
471 gradient_norm: gnorm,
472 })
473}
474
#[cfg(test)]
mod tests {
    use super::*;

    // Separable convex quadratic with unique minimizer (2, 3); the shared
    // fixture for most tests below.
    fn quadratic(x: &[f64]) -> f64 {
        0.5 * (x[0] - 2.0).powi(2) + 0.5 * (x[1] - 3.0).powi(2)
    }

    // Analytic gradient of `quadratic`.
    fn quadratic_grad(x: &[f64]) -> Vec<f64> {
        vec![x[0] - 2.0, x[1] - 3.0]
    }

    // Cyclic CD with line search should converge to the exact minimizer.
    #[test]
    fn test_coord_descent_quadratic() {
        let config = CoordDescentConfig {
            selection: CoordSelectionRule::Cyclic,
            max_iter: 5000,
            tol: 1e-8,
            line_search: true,
            ..CoordDescentConfig::default()
        };
        let result = coordinate_descent(quadratic, quadratic_grad, &[0.0, 0.0], &config)
            .expect("optimization should succeed");
        assert!(result.converged, "should converge");
        assert!((result.x[0] - 2.0).abs() < 1e-5, "x[0] should be ≈ 2");
        assert!((result.x[1] - 3.0).abs() < 1e-5, "x[1] should be ≈ 3");
    }

    // Random selection needs more iterations but should reach the same
    // minimizer as the cyclic sweep.
    #[test]
    fn test_coord_descent_random_vs_cyclic() {
        let x0 = vec![0.0, 0.0];
        let config_cyclic = CoordDescentConfig {
            selection: CoordSelectionRule::Cyclic,
            max_iter: 5000,
            tol: 1e-6,
            line_search: true,
            ..CoordDescentConfig::default()
        };
        let config_random = CoordDescentConfig {
            selection: CoordSelectionRule::Random,
            max_iter: 20000,
            tol: 1e-6,
            line_search: true,
            ..CoordDescentConfig::default()
        };

        let r_cyclic = coordinate_descent(quadratic, quadratic_grad, &x0, &config_cyclic)
            .expect("cyclic should succeed");
        let r_random = coordinate_descent(quadratic, quadratic_grad, &x0, &config_random)
            .expect("random should succeed");

        assert!(r_cyclic.converged, "cyclic should converge");
        assert!(r_random.converged, "random should converge");

        assert!((r_cyclic.x[0] - 2.0).abs() < 1e-4);
        assert!((r_random.x[0] - 2.0).abs() < 1e-4);
    }

    // Badly-scaled quadratic (condition number 100): greedy selection should
    // still find the minimizer at (1, 1).
    #[test]
    fn test_coord_descent_greedy() {
        let f = |x: &[f64]| 0.5 * (100.0 * (x[0] - 1.0).powi(2) + (x[1] - 1.0).powi(2));
        let g = |x: &[f64]| vec![100.0 * (x[0] - 1.0), x[1] - 1.0];

        let config = CoordDescentConfig {
            selection: CoordSelectionRule::GreedyGradient,
            max_iter: 5000,
            tol: 1e-6,
            line_search: true,
            ..CoordDescentConfig::default()
        };
        let result = coordinate_descent(f, g, &[0.0, 0.0], &config).expect("greedy should succeed");
        assert!(result.converged);
        assert!((result.x[0] - 1.0).abs() < 1e-4);
        assert!((result.x[1] - 1.0).abs() < 1e-4);
    }

    // Stochastic greedy is randomized, so accept either the converged flag or
    // a small final gradient norm.
    #[test]
    fn test_coord_descent_stochastic_greedy() {
        let config = CoordDescentConfig {
            selection: CoordSelectionRule::StochasticGreedy,
            max_iter: 10000,
            tol: 1e-5,
            greedy_subset_size: 2,
            line_search: true,
            ..CoordDescentConfig::default()
        };
        let result = coordinate_descent(quadratic, quadratic_grad, &[0.0, 0.0], &config)
            .expect("stochastic greedy should succeed");
        assert!(result.converged || result.gradient_norm < 1e-4);
    }

    // Separable lasso: f(x) = 0.5*||x - c||^2 + lambda*||x||_1 has the
    // closed-form solution soft_threshold(c_i, lambda), so components of c
    // below lambda in magnitude must be driven to (near) zero.
    #[test]
    fn test_proximal_coord_descent_lasso() {
        let c = vec![5.0_f64, 0.3, -0.2, 0.0];
        let c_clone = c.clone();
        let f = move |x: &[f64]| {
            x.iter()
                .zip(c.iter())
                .map(|(xi, ci)| 0.5 * (xi - ci).powi(2))
                .sum::<f64>()
        };
        let g = move |x: &[f64]| {
            x.iter()
                .zip(c_clone.iter())
                .map(|(xi, ci)| xi - ci)
                .collect::<Vec<_>>()
        };

        let lambda = 1.0_f64;
        let config = CoordDescentConfig {
            selection: CoordSelectionRule::Cyclic,
            max_iter: 5000,
            tol: 1e-8,
            line_search: false,
            step_size: 0.5,
            ..CoordDescentConfig::default()
        };

        let result = proximal_coord_descent(f, g, lambda, &[0.0; 4], &config)
            .expect("proximal CD should succeed");

        assert!(result.x[0] > 3.0, "x[0] should be large positive");
        assert!(result.x[1].abs() < 0.5, "x[1] should be sparse");
        assert!(result.x[2].abs() < 0.5, "x[2] should be sparse");
        assert!(result.x[3].abs() < 0.5, "x[3] should be sparse");
    }

    // With lambda = 0 the proximal solver should reduce to plain coordinate
    // descent on the smooth objective.
    #[test]
    fn test_proximal_coord_descent_zero_lambda() {
        let config = CoordDescentConfig {
            selection: CoordSelectionRule::Cyclic,
            max_iter: 5000,
            tol: 1e-8,
            line_search: false,
            step_size: 0.5,
            ..CoordDescentConfig::default()
        };
        let result = proximal_coord_descent(quadratic, quadratic_grad, 0.0, &[0.0, 0.0], &config)
            .expect("zero-lambda proximal should succeed");
        assert!((result.x[0] - 2.0).abs() < 0.1, "x[0] ≈ 2 for lambda=0");
    }

    // Empty x0 must be rejected with an error, not panic.
    #[test]
    fn test_coord_descent_empty_input_error() {
        let result = coordinate_descent(
            |_: &[f64]| 0.0,
            |_: &[f64]| vec![],
            &[],
            &CoordDescentConfig::default(),
        );
        assert!(result.is_err());
    }

    // Negative regularization strength must be rejected.
    #[test]
    fn test_proximal_negative_lambda_error() {
        let result = proximal_coord_descent(
            |_: &[f64]| 0.0,
            |_: &[f64]| vec![0.0],
            -1.0,
            &[0.0],
            &CoordDescentConfig::default(),
        );
        assert!(result.is_err());
    }

    // GreedyGauss currently selects like GreedyGradient; verify it still
    // solves the quadratic.
    #[test]
    fn test_coord_descent_gauss_southwell() {
        let config = CoordDescentConfig {
            selection: CoordSelectionRule::GreedyGauss,
            max_iter: 5000,
            tol: 1e-6,
            line_search: true,
            ..CoordDescentConfig::default()
        };
        let result = coordinate_descent(quadratic, quadratic_grad, &[0.0, 0.0], &config)
            .expect("Gauss-Southwell should succeed");
        assert!(result.converged);
        assert!((result.x[0] - 2.0).abs() < 1e-4);
        assert!((result.x[1] - 3.0).abs() < 1e-4);
    }

    // Fixed-step variant: slower, so only loose accuracy is required.
    #[test]
    fn test_coord_descent_no_line_search() {
        let config = CoordDescentConfig {
            selection: CoordSelectionRule::Cyclic,
            max_iter: 100_000,
            tol: 1e-5,
            line_search: false,
            step_size: 0.1,
            ..CoordDescentConfig::default()
        };
        let result = coordinate_descent(quadratic, quadratic_grad, &[0.0, 0.0], &config)
            .expect("fixed step CD should succeed");
        assert!((result.x[0] - 2.0).abs() < 0.1);
        assert!((result.x[1] - 3.0).abs() < 0.1);
    }
}
683}