1use crate::error::{OptimizeError, OptimizeResult};
15use scirs2_core::ndarray::{Array1, Array2, ArrayView1};
16use scirs2_core::random::{rngs::StdRng, RngExt, SeedableRng};
17
/// Rule used to choose the next coordinate (or block) to update.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum CoordinateSelectionStrategy {
    /// Sweep coordinates in order 0, 1, ..., n-1 on every outer iteration.
    Cyclic,
    /// Pick a coordinate uniformly at random (seeded via the config, so
    /// runs are reproducible).
    Randomized,
    /// Pick the coordinate with the largest absolute gradient entry
    /// (Gauss-Southwell rule); costs a full gradient evaluation per step.
    Greedy,
}
28
/// Penalty added to the smooth objective.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum RegularizationType {
    /// No penalty.
    None,
    /// `lambda * ||x||_1` — promotes sparse solutions.
    L1,
    /// `lambda * ||x||_2^2` (ridge; note there is no 1/2 factor here).
    L2,
    /// `alpha * lambda * ||x||_1 + (1 - alpha) * lambda * ||x||_2^2`.
    ElasticNet,
}
41
/// Configuration shared by all coordinate-descent solvers in this module.
#[derive(Debug, Clone)]
pub struct CoordinateDescentConfig {
    /// Maximum number of outer iterations (full sweeps).
    pub max_iter: usize,
    /// Convergence tolerance on the absolute change of the regularized
    /// objective between consecutive outer iterations.
    pub tol: f64,
    /// Coordinate/block selection rule.
    pub strategy: CoordinateSelectionStrategy,
    /// Fixed step size for gradient-based updates; `None` falls back to 0.01.
    pub step_size: Option<f64>,
    /// Which penalty to add to the smooth objective.
    pub regularization: RegularizationType,
    /// Regularization strength.
    pub lambda: f64,
    /// ElasticNet mixing weight of the L1 part (typically in [0, 1]).
    pub alpha: f64,
    /// RNG seed for the `Randomized` strategy.
    pub seed: u64,
    /// When true, record the regularized objective after each outer iteration
    /// (plus the initial value) in `objective_history`.
    pub track_objective: bool,
}
64
65impl Default for CoordinateDescentConfig {
66 fn default() -> Self {
67 Self {
68 max_iter: 1000,
69 tol: 1e-8,
70 strategy: CoordinateSelectionStrategy::Cyclic,
71 step_size: None,
72 regularization: RegularizationType::None,
73 lambda: 0.0,
74 alpha: 0.5,
75 seed: 42,
76 track_objective: false,
77 }
78 }
79}
80
/// Outcome of a coordinate-descent run.
#[derive(Debug, Clone)]
pub struct CoordinateDescentResult {
    /// Final iterate.
    pub x: Array1<f64>,
    /// Smooth objective value at `x` (without the regularization penalty).
    pub fun: f64,
    /// `fun` plus the regularization penalty at `x`.
    pub fun_regularized: f64,
    /// Number of outer iterations actually performed.
    pub iterations: usize,
    /// Whether the objective-change tolerance was met before `max_iter`.
    pub converged: bool,
    /// Regularized objective per outer iteration; empty unless
    /// `track_objective` was set in the config.
    pub objective_history: Vec<f64>,
    /// Euclidean norm of the smooth gradient at `x`.
    pub grad_norm: f64,
}
99
/// Soft-thresholding operator (proximal map of `threshold * |t|`):
/// shrinks `x` toward zero by `threshold` and collapses the dead zone
/// `[-threshold, threshold]` to exactly zero.
fn soft_threshold(x: f64, threshold: f64) -> f64 {
    let above_band = x > threshold;
    let below_band = x < -threshold;
    if above_band {
        x - threshold
    } else if below_band {
        x + threshold
    } else {
        0.0
    }
}
112
113fn regularization_penalty(
115 x: &Array1<f64>,
116 reg_type: RegularizationType,
117 lambda: f64,
118 alpha: f64,
119) -> f64 {
120 match reg_type {
121 RegularizationType::None => 0.0,
122 RegularizationType::L1 => lambda * x.mapv(f64::abs).sum(),
123 RegularizationType::L2 => lambda * x.dot(x),
124 RegularizationType::ElasticNet => {
125 let l1_part = alpha * lambda * x.mapv(f64::abs).sum();
126 let l2_part = (1.0 - alpha) * lambda * x.dot(x);
127 l1_part + l2_part
128 }
129 }
130}
131
/// Coordinate-descent solver driven by user-supplied objective and gradient
/// closures; behavior is controlled entirely by `CoordinateDescentConfig`.
pub struct CoordinateDescentSolver {
    config: CoordinateDescentConfig,
}
139
140impl CoordinateDescentSolver {
141 pub fn new(config: CoordinateDescentConfig) -> Self {
143 Self { config }
144 }
145
146 pub fn default_solver() -> Self {
148 Self::new(CoordinateDescentConfig::default())
149 }
150
151 pub fn minimize<F, G>(
161 &self,
162 objective: F,
163 gradient: G,
164 x0: &Array1<f64>,
165 ) -> OptimizeResult<CoordinateDescentResult>
166 where
167 F: Fn(&ArrayView1<f64>) -> f64,
168 G: Fn(&ArrayView1<f64>) -> Array1<f64>,
169 {
170 let n = x0.len();
171 if n == 0 {
172 return Err(OptimizeError::InvalidInput(
173 "Initial point must have at least one dimension".to_string(),
174 ));
175 }
176
177 let mut x = x0.clone();
178 let mut rng = StdRng::seed_from_u64(self.config.seed);
179 let mut objective_history = Vec::new();
180
181 let step_size = self.config.step_size.unwrap_or(0.01);
182
183 let mut prev_obj = objective(&x.view())
184 + regularization_penalty(
185 &x,
186 self.config.regularization,
187 self.config.lambda,
188 self.config.alpha,
189 );
190
191 if self.config.track_objective {
192 objective_history.push(prev_obj);
193 }
194
195 let mut converged = false;
196 let mut iterations = 0;
197
198 for iter in 0..self.config.max_iter {
199 iterations = iter + 1;
200
201 for _coord_step in 0..n {
203 let coord = match self.config.strategy {
204 CoordinateSelectionStrategy::Cyclic => _coord_step,
205 CoordinateSelectionStrategy::Randomized => rng.random_range(0..n),
206 CoordinateSelectionStrategy::Greedy => {
207 let grad = gradient(&x.view());
208 let mut best_coord = 0;
210 let mut best_abs_grad = f64::NEG_INFINITY;
211 for i in 0..n {
212 let abs_g = grad[i].abs();
213 if abs_g > best_abs_grad {
214 best_abs_grad = abs_g;
215 best_coord = i;
216 }
217 }
218 best_coord
219 }
220 };
221
222 let grad = gradient(&x.view());
224 let grad_coord = grad[coord];
225
226 match self.config.regularization {
228 RegularizationType::None => {
229 x[coord] -= step_size * grad_coord;
230 }
231 RegularizationType::L1 => {
232 let proposal = x[coord] - step_size * grad_coord;
234 x[coord] = soft_threshold(proposal, step_size * self.config.lambda);
235 }
236 RegularizationType::L2 => {
237 let total_grad = grad_coord + 2.0 * self.config.lambda * x[coord];
239 x[coord] -= step_size * total_grad;
240 }
241 RegularizationType::ElasticNet => {
242 let l2_grad =
244 2.0 * (1.0 - self.config.alpha) * self.config.lambda * x[coord];
245 let proposal = x[coord] - step_size * (grad_coord + l2_grad);
246 x[coord] = soft_threshold(
248 proposal,
249 step_size * self.config.alpha * self.config.lambda,
250 );
251 }
252 }
253 }
254
255 let smooth_obj = objective(&x.view());
256 let total_obj = smooth_obj
257 + regularization_penalty(
258 &x,
259 self.config.regularization,
260 self.config.lambda,
261 self.config.alpha,
262 );
263
264 if self.config.track_objective {
265 objective_history.push(total_obj);
266 }
267
268 let change = (prev_obj - total_obj).abs();
269 prev_obj = total_obj;
270
271 if change < self.config.tol {
272 converged = true;
273 break;
274 }
275 }
276
277 let final_grad = gradient(&x.view());
278 let grad_norm = final_grad.dot(&final_grad).sqrt();
279 let smooth_obj = objective(&x.view());
280 let reg_penalty = regularization_penalty(
281 &x,
282 self.config.regularization,
283 self.config.lambda,
284 self.config.alpha,
285 );
286
287 Ok(CoordinateDescentResult {
288 x,
289 fun: smooth_obj,
290 fun_regularized: smooth_obj + reg_penalty,
291 iterations,
292 converged,
293 objective_history,
294 grad_norm,
295 })
296 }
297}
298
/// Convenience wrapper that runs proximal coordinate descent for the common
/// penalized problems (lasso / ridge / elastic net), overriding the
/// regularization type of the stored config per call.
pub struct ProximalCoordinateDescent {
    config: CoordinateDescentConfig,
}
308
309impl ProximalCoordinateDescent {
310 pub fn new(config: CoordinateDescentConfig) -> Self {
312 Self { config }
313 }
314
315 pub fn minimize_lasso<F, G>(
324 &self,
325 objective: F,
326 gradient: G,
327 x0: &Array1<f64>,
328 ) -> OptimizeResult<CoordinateDescentResult>
329 where
330 F: Fn(&ArrayView1<f64>) -> f64,
331 G: Fn(&ArrayView1<f64>) -> Array1<f64>,
332 {
333 let mut config = self.config.clone();
334 config.regularization = RegularizationType::L1;
335 let solver = CoordinateDescentSolver::new(config);
336 solver.minimize(objective, gradient, x0)
337 }
338
339 pub fn minimize_ridge<F, G>(
346 &self,
347 objective: F,
348 gradient: G,
349 x0: &Array1<f64>,
350 ) -> OptimizeResult<CoordinateDescentResult>
351 where
352 F: Fn(&ArrayView1<f64>) -> f64,
353 G: Fn(&ArrayView1<f64>) -> Array1<f64>,
354 {
355 let mut config = self.config.clone();
356 config.regularization = RegularizationType::L2;
357 let solver = CoordinateDescentSolver::new(config);
358 solver.minimize(objective, gradient, x0)
359 }
360
361 pub fn minimize_elastic_net<F, G>(
368 &self,
369 objective: F,
370 gradient: G,
371 x0: &Array1<f64>,
372 ) -> OptimizeResult<CoordinateDescentResult>
373 where
374 F: Fn(&ArrayView1<f64>) -> f64,
375 G: Fn(&ArrayView1<f64>) -> Array1<f64>,
376 {
377 let mut config = self.config.clone();
378 config.regularization = RegularizationType::ElasticNet;
379 let solver = CoordinateDescentSolver::new(config);
380 solver.minimize(objective, gradient, x0)
381 }
382}
383
/// Coordinate descent that updates a whole group ("block") of coordinates
/// per inner step.
pub struct BlockCoordinateDescent {
    config: CoordinateDescentConfig,
    /// Each inner vec lists the coordinate indices belonging to one block.
    /// Index validity is checked in `minimize`; disjointness is not enforced.
    blocks: Vec<Vec<usize>>,
}
393
394impl BlockCoordinateDescent {
395 pub fn new(config: CoordinateDescentConfig, blocks: Vec<Vec<usize>>) -> Self {
401 Self { config, blocks }
402 }
403
404 pub fn with_uniform_blocks(
411 config: CoordinateDescentConfig,
412 n: usize,
413 block_size: usize,
414 ) -> Self {
415 let mut blocks = Vec::new();
416 let mut start = 0;
417 while start < n {
418 let end = (start + block_size).min(n);
419 blocks.push((start..end).collect());
420 start = end;
421 }
422 Self { config, blocks }
423 }
424
425 pub fn minimize<F, G>(
435 &self,
436 objective: F,
437 gradient: G,
438 x0: &Array1<f64>,
439 ) -> OptimizeResult<CoordinateDescentResult>
440 where
441 F: Fn(&ArrayView1<f64>) -> f64,
442 G: Fn(&ArrayView1<f64>) -> Array1<f64>,
443 {
444 let n = x0.len();
445 if n == 0 {
446 return Err(OptimizeError::InvalidInput(
447 "Initial point must have at least one dimension".to_string(),
448 ));
449 }
450
451 for (bi, block) in self.blocks.iter().enumerate() {
453 for &idx in block {
454 if idx >= n {
455 return Err(OptimizeError::InvalidInput(format!(
456 "Block {} contains index {} which exceeds dimension {}",
457 bi, idx, n
458 )));
459 }
460 }
461 }
462
463 let mut x = x0.clone();
464 let step_size = self.config.step_size.unwrap_or(0.01);
465 let mut rng = StdRng::seed_from_u64(self.config.seed);
466 let mut objective_history = Vec::new();
467
468 let mut prev_obj = objective(&x.view())
469 + regularization_penalty(
470 &x,
471 self.config.regularization,
472 self.config.lambda,
473 self.config.alpha,
474 );
475
476 if self.config.track_objective {
477 objective_history.push(prev_obj);
478 }
479
480 let mut converged = false;
481 let mut iterations = 0;
482 let num_blocks = self.blocks.len();
483
484 for iter in 0..self.config.max_iter {
485 iterations = iter + 1;
486
487 for block_step in 0..num_blocks {
489 let block_idx = match self.config.strategy {
490 CoordinateSelectionStrategy::Cyclic => block_step,
491 CoordinateSelectionStrategy::Randomized => rng.random_range(0..num_blocks),
492 CoordinateSelectionStrategy::Greedy => {
493 let grad = gradient(&x.view());
495 let mut best_block = 0;
496 let mut best_norm = f64::NEG_INFINITY;
497 for (bi, block) in self.blocks.iter().enumerate() {
498 let block_norm_sq: f64 = block.iter().map(|&i| grad[i] * grad[i]).sum();
499 if block_norm_sq > best_norm {
500 best_norm = block_norm_sq;
501 best_block = bi;
502 }
503 }
504 best_block
505 }
506 };
507
508 let block = &self.blocks[block_idx];
509 let grad = gradient(&x.view());
510
511 for &coord in block {
513 match self.config.regularization {
514 RegularizationType::None => {
515 x[coord] -= step_size * grad[coord];
516 }
517 RegularizationType::L1 => {
518 let proposal = x[coord] - step_size * grad[coord];
519 x[coord] = soft_threshold(proposal, step_size * self.config.lambda);
520 }
521 RegularizationType::L2 => {
522 let total_grad = grad[coord] + 2.0 * self.config.lambda * x[coord];
523 x[coord] -= step_size * total_grad;
524 }
525 RegularizationType::ElasticNet => {
526 let l2_grad =
527 2.0 * (1.0 - self.config.alpha) * self.config.lambda * x[coord];
528 let proposal = x[coord] - step_size * (grad[coord] + l2_grad);
529 x[coord] = soft_threshold(
530 proposal,
531 step_size * self.config.alpha * self.config.lambda,
532 );
533 }
534 }
535 }
536 }
537
538 let smooth_obj = objective(&x.view());
539 let total_obj = smooth_obj
540 + regularization_penalty(
541 &x,
542 self.config.regularization,
543 self.config.lambda,
544 self.config.alpha,
545 );
546
547 if self.config.track_objective {
548 objective_history.push(total_obj);
549 }
550
551 let change = (prev_obj - total_obj).abs();
552 prev_obj = total_obj;
553
554 if change < self.config.tol {
555 converged = true;
556 break;
557 }
558 }
559
560 let final_grad = gradient(&x.view());
561 let grad_norm = final_grad.dot(&final_grad).sqrt();
562 let smooth_obj = objective(&x.view());
563 let reg_penalty = regularization_penalty(
564 &x,
565 self.config.regularization,
566 self.config.lambda,
567 self.config.alpha,
568 );
569
570 Ok(CoordinateDescentResult {
571 x,
572 fun: smooth_obj,
573 fun_regularized: smooth_obj + reg_penalty,
574 iterations,
575 converged,
576 objective_history,
577 grad_norm,
578 })
579 }
580}
581
582pub fn coordinate_descent_minimize<F, G>(
590 objective: F,
591 gradient: G,
592 x0: &Array1<f64>,
593 config: Option<CoordinateDescentConfig>,
594) -> OptimizeResult<CoordinateDescentResult>
595where
596 F: Fn(&ArrayView1<f64>) -> f64,
597 G: Fn(&ArrayView1<f64>) -> Array1<f64>,
598{
599 let config = config.unwrap_or_default();
600 let solver = CoordinateDescentSolver::new(config);
601 solver.minimize(objective, gradient, x0)
602}
603
604pub fn lasso_coordinate_descent<F, G>(
608 objective: F,
609 gradient: G,
610 x0: &Array1<f64>,
611 lambda: f64,
612 config: Option<CoordinateDescentConfig>,
613) -> OptimizeResult<CoordinateDescentResult>
614where
615 F: Fn(&ArrayView1<f64>) -> f64,
616 G: Fn(&ArrayView1<f64>) -> Array1<f64>,
617{
618 let mut config = config.unwrap_or_default();
619 config.regularization = RegularizationType::L1;
620 config.lambda = lambda;
621 let solver = CoordinateDescentSolver::new(config);
622 solver.minimize(objective, gradient, x0)
623}
624
625pub fn quadratic_coordinate_descent(
636 a: &Array2<f64>,
637 b: &Array1<f64>,
638 x0: &Array1<f64>,
639 config: Option<CoordinateDescentConfig>,
640) -> OptimizeResult<CoordinateDescentResult> {
641 let n = x0.len();
642 let config = config.unwrap_or_default();
643
644 if a.nrows() != n || a.ncols() != n {
645 return Err(OptimizeError::InvalidInput(format!(
646 "Matrix A has shape ({}, {}), expected ({}, {})",
647 a.nrows(),
648 a.ncols(),
649 n,
650 n
651 )));
652 }
653 if b.len() != n {
654 return Err(OptimizeError::InvalidInput(format!(
655 "Vector b has length {}, expected {}",
656 b.len(),
657 n
658 )));
659 }
660
661 let mut x = x0.clone();
662 let mut rng = StdRng::seed_from_u64(config.seed);
663 let mut objective_history = Vec::new();
664
665 let compute_obj = |x: &Array1<f64>| -> f64 {
667 let ax = a.dot(x);
668 0.5 * x.dot(&ax) - b.dot(x)
669 };
670
671 let mut prev_obj = compute_obj(&x)
672 + regularization_penalty(&x, config.regularization, config.lambda, config.alpha);
673
674 if config.track_objective {
675 objective_history.push(prev_obj);
676 }
677
678 let mut converged = false;
679 let mut iterations = 0;
680
681 for iter in 0..config.max_iter {
682 iterations = iter + 1;
683
684 for _coord_step in 0..n {
685 let coord = match config.strategy {
686 CoordinateSelectionStrategy::Cyclic => _coord_step,
687 CoordinateSelectionStrategy::Randomized => rng.random_range(0..n),
688 CoordinateSelectionStrategy::Greedy => {
689 let grad = a.dot(&x) - b;
691 let mut best = 0;
692 let mut best_val = f64::NEG_INFINITY;
693 for i in 0..n {
694 let abs_g = grad[i].abs();
695 if abs_g > best_val {
696 best_val = abs_g;
697 best = i;
698 }
699 }
700 best
701 }
702 };
703
704 let a_ii = a[[coord, coord]];
705 if a_ii.abs() < 1e-15 {
706 continue; }
708
709 let mut residual_coord = -b[coord];
711 for j in 0..n {
712 residual_coord += a[[coord, j]] * x[j];
713 }
714
715 match config.regularization {
716 RegularizationType::None => {
717 x[coord] -= residual_coord / a_ii;
719 }
720 RegularizationType::L1 => {
721 let rhs = b[coord]
723 - (0..n)
724 .filter(|&j| j != coord)
725 .map(|j| a[[coord, j]] * x[j])
726 .sum::<f64>();
727 x[coord] = soft_threshold(rhs, config.lambda) / a_ii;
728 }
729 RegularizationType::L2 => {
730 let rhs = b[coord]
732 - (0..n)
733 .filter(|&j| j != coord)
734 .map(|j| a[[coord, j]] * x[j])
735 .sum::<f64>();
736 x[coord] = rhs / (a_ii + 2.0 * config.lambda);
737 }
738 RegularizationType::ElasticNet => {
739 let rhs = b[coord]
740 - (0..n)
741 .filter(|&j| j != coord)
742 .map(|j| a[[coord, j]] * x[j])
743 .sum::<f64>();
744 x[coord] = soft_threshold(rhs, config.alpha * config.lambda)
745 / (a_ii + 2.0 * (1.0 - config.alpha) * config.lambda);
746 }
747 }
748 }
749
750 let smooth_obj = compute_obj(&x);
751 let total_obj = smooth_obj
752 + regularization_penalty(&x, config.regularization, config.lambda, config.alpha);
753
754 if config.track_objective {
755 objective_history.push(total_obj);
756 }
757
758 let change = (prev_obj - total_obj).abs();
759 prev_obj = total_obj;
760
761 if change < config.tol {
762 converged = true;
763 break;
764 }
765 }
766
767 let grad = a.dot(&x) - b;
768 let grad_norm = grad.dot(&grad).sqrt();
769 let smooth_obj = compute_obj(&x);
770 let reg_penalty =
771 regularization_penalty(&x, config.regularization, config.lambda, config.alpha);
772
773 Ok(CoordinateDescentResult {
774 x,
775 fun: smooth_obj,
776 fun_regularized: smooth_obj + reg_penalty,
777 iterations,
778 converged,
779 objective_history,
780 grad_norm,
781 })
782}
783
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::{array, Array1, Array2};

    // Cyclic CD with a fixed step should drive a separable quadratic
    // (minimum at the origin) to convergence.
    #[test]
    fn test_cyclic_cd_quadratic_minimum() {
        let objective = |x: &ArrayView1<f64>| -> f64 { x[0] * x[0] + x[1] * x[1] };
        let gradient = |x: &ArrayView1<f64>| -> Array1<f64> { array![2.0 * x[0], 2.0 * x[1]] };

        let x0 = array![5.0, 3.0];
        let config = CoordinateDescentConfig {
            max_iter: 5000,
            tol: 1e-12,
            strategy: CoordinateSelectionStrategy::Cyclic,
            step_size: Some(0.4),
            ..Default::default()
        };

        let result = coordinate_descent_minimize(objective, gradient, &x0, Some(config));
        assert!(result.is_ok());
        let result = result.expect("should succeed");
        assert!(result.converged);
        assert!(result.x[0].abs() < 1e-5);
        assert!(result.x[1].abs() < 1e-5);
        assert!(result.fun < 1e-10);
    }

    // With A = I and lambda = 0.1, the lasso solution is soft_threshold(b, 0.1)
    // componentwise, so the middle component (|b| = 0.05 < lambda) must be
    // exactly zero while the outer ones stay nonzero.
    #[test]
    fn test_lasso_sparse_solution() {
        let a = Array2::eye(3);
        let b = array![0.5, 0.05, 0.5];
        let x0 = array![0.0, 0.0, 0.0];

        let config = CoordinateDescentConfig {
            max_iter: 1000,
            tol: 1e-12,
            regularization: RegularizationType::L1,
            lambda: 0.1,
            ..Default::default()
        };

        let result = quadratic_coordinate_descent(&a, &b, &x0, Some(config));
        assert!(result.is_ok());
        let result = result.expect("should succeed");
        assert!(result.converged);
        assert!(
            result.x[1].abs() < 1e-10,
            "Expected sparse: x[1]={} should be ~0",
            result.x[1]
        );
        assert!(result.x[0].abs() > 0.1);
        assert!(result.x[2].abs() > 0.1);
    }

    // Smoke test: both strategies must produce finite objectives on a
    // diagonal quadratic within the iteration budget.
    #[test]
    fn test_greedy_vs_cyclic_convergence() {
        let n = 10;
        let a = Array2::from_diag(&Array1::from_vec((1..=n).map(|i| i as f64).collect()));
        let b = Array1::ones(n);
        let x0 = Array1::from_vec(vec![10.0; n]);

        let config_cyclic = CoordinateDescentConfig {
            max_iter: 50,
            tol: 1e-20,
            strategy: CoordinateSelectionStrategy::Cyclic,
            track_objective: true,
            ..Default::default()
        };

        let config_greedy = CoordinateDescentConfig {
            max_iter: 50,
            tol: 1e-20,
            strategy: CoordinateSelectionStrategy::Greedy,
            track_objective: true,
            ..Default::default()
        };

        let result_cyclic = quadratic_coordinate_descent(&a, &b, &x0, Some(config_cyclic));
        let result_greedy = quadratic_coordinate_descent(&a, &b, &x0, Some(config_greedy));

        assert!(result_cyclic.is_ok());
        assert!(result_greedy.is_ok());
        let r_cyclic = result_cyclic.expect("cyclic should succeed");
        let r_greedy = result_greedy.expect("greedy should succeed");

        assert!(r_cyclic.fun.is_finite());
        assert!(r_greedy.fun.is_finite());
    }

    // Randomized selection (fixed seed) should still reach the minimizer of
    // a separable quadratic at (1, 2).
    #[test]
    fn test_randomized_cd_converges() {
        let objective = |x: &ArrayView1<f64>| -> f64 {
            0.5 * (x[0] - 1.0).powi(2) + 0.5 * (x[1] - 2.0).powi(2)
        };
        let gradient = |x: &ArrayView1<f64>| -> Array1<f64> { array![x[0] - 1.0, x[1] - 2.0] };

        let x0 = array![10.0, -5.0];
        let config = CoordinateDescentConfig {
            max_iter: 10000,
            tol: 1e-10,
            strategy: CoordinateSelectionStrategy::Randomized,
            step_size: Some(0.9),
            seed: 123,
            ..Default::default()
        };

        let result = coordinate_descent_minimize(objective, gradient, &x0, Some(config));
        assert!(result.is_ok());
        let result = result.expect("should succeed");
        assert!(result.converged);
        assert!((result.x[0] - 1.0).abs() < 1e-4, "x[0]={}", result.x[0]);
        assert!((result.x[1] - 2.0).abs() < 1e-4, "x[1]={}", result.x[1]);
    }

    // Block CD with uniform blocks of size 2 over a 4-D separable quadratic;
    // minimizer is (1, 2, 3, 4).
    #[test]
    fn test_block_cd() {
        let objective = |x: &ArrayView1<f64>| -> f64 {
            (x[0] - 1.0).powi(2)
                + (x[1] - 2.0).powi(2)
                + (x[2] - 3.0).powi(2)
                + (x[3] - 4.0).powi(2)
        };
        let gradient = |x: &ArrayView1<f64>| -> Array1<f64> {
            array![
                2.0 * (x[0] - 1.0),
                2.0 * (x[1] - 2.0),
                2.0 * (x[2] - 3.0),
                2.0 * (x[3] - 4.0)
            ]
        };

        let x0 = array![0.0, 0.0, 0.0, 0.0];
        let config = CoordinateDescentConfig {
            max_iter: 5000,
            tol: 1e-12,
            step_size: Some(0.4),
            ..Default::default()
        };

        let solver = BlockCoordinateDescent::with_uniform_blocks(config, 4, 2);
        let result = solver.minimize(objective, gradient, &x0);
        assert!(result.is_ok());
        let result = result.expect("should succeed");
        assert!(result.converged);
        assert!((result.x[0] - 1.0).abs() < 1e-4);
        assert!((result.x[1] - 2.0).abs() < 1e-4);
        assert!((result.x[2] - 3.0).abs() < 1e-4);
        assert!((result.x[3] - 4.0).abs() < 1e-4);
    }

    // Unregularized quadratic CD should solve A x = b exactly;
    // for this A and b the solution is x = (0.2, 0.6).
    #[test]
    fn test_quadratic_cd_exact() {
        let a = array![[2.0, 1.0], [1.0, 3.0]];
        let b = array![1.0, 2.0];
        let x0 = array![0.0, 0.0];

        let config = CoordinateDescentConfig {
            max_iter: 500,
            tol: 1e-14,
            ..Default::default()
        };

        let result = quadratic_coordinate_descent(&a, &b, &x0, Some(config));
        assert!(result.is_ok());
        let result = result.expect("should succeed");
        assert!(result.converged);
        assert!(
            (result.x[0] - 0.2).abs() < 1e-8,
            "x[0]={}, expected 0.2",
            result.x[0]
        );
        assert!(
            (result.x[1] - 0.6).abs() < 1e-8,
            "x[1]={}, expected ~0.6",
            result.x[1]
        );
    }

    // Ridge with A = I: closed-form solution is b / (1 + 2*lambda) = b / 2.
    #[test]
    fn test_ridge_cd() {
        let a = Array2::eye(3);
        let b = array![1.0, 2.0, 3.0];
        let x0 = array![0.0, 0.0, 0.0];

        let config = CoordinateDescentConfig {
            max_iter: 1000,
            tol: 1e-14,
            regularization: RegularizationType::L2,
            lambda: 0.5,
            ..Default::default()
        };

        let result = quadratic_coordinate_descent(&a, &b, &x0, Some(config));
        assert!(result.is_ok());
        let result = result.expect("should succeed");
        assert!((result.x[0] - 0.5).abs() < 1e-8, "x[0]={}", result.x[0]);
        assert!((result.x[1] - 1.0).abs() < 1e-8, "x[1]={}", result.x[1]);
        assert!((result.x[2] - 1.5).abs() < 1e-8, "x[2]={}", result.x[2]);
    }

    // When track_objective is on, the recorded objective must be
    // monotonically non-increasing (up to FP noise) for this convex problem.
    #[test]
    fn test_objective_history_tracking() {
        let a = Array2::eye(2);
        let b = array![1.0, 1.0];
        let x0 = array![5.0, 5.0];

        let config = CoordinateDescentConfig {
            max_iter: 20,
            tol: 1e-20,
            track_objective: true,
            ..Default::default()
        };

        let result = quadratic_coordinate_descent(&a, &b, &x0, Some(config));
        assert!(result.is_ok());
        let result = result.expect("should succeed");
        assert!(result.objective_history.len() > 1);
        for i in 1..result.objective_history.len() {
            assert!(
                result.objective_history[i] <= result.objective_history[i - 1] + 1e-12,
                "Objective increased at iter {}: {} -> {}",
                i,
                result.objective_history[i - 1],
                result.objective_history[i]
            );
        }
    }
}