// scirs2_optimize/high_dimensional/coordinate_descent.rs
1//! Coordinate Descent optimization methods
2//!
3//! Implements multiple variants of coordinate descent for high-dimensional optimization:
4//!
5//! - **Cyclic**: Optimize one variable at a time, cycling through all
6//! - **Randomized**: Randomly select coordinate to update
7//! - **Greedy (Gauss-Southwell)**: Select coordinate with largest gradient magnitude
8//! - **Proximal**: Support for L1 (Lasso) and L2 (Ridge) regularization
9//! - **Block**: Update groups of variables together
10//!
11//! Coordinate descent is particularly effective for problems where the per-coordinate
12//! update is cheap and the problem has separable structure.
13
14use crate::error::{OptimizeError, OptimizeResult};
15use scirs2_core::ndarray::{Array1, Array2, ArrayView1};
16use scirs2_core::random::{rngs::StdRng, RngExt, SeedableRng};
17
/// Strategy for selecting which coordinate to update at each inner step.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum CoordinateSelectionStrategy {
    /// Cycle through coordinates in order 0, 1, ..., n-1, 0, 1, ...
    Cyclic,
    /// Randomly select a coordinate uniformly at random
    Randomized,
    /// Select the coordinate with the largest absolute gradient (Gauss-Southwell rule).
    /// Note: this requires a full gradient evaluation to pick the coordinate,
    /// so each step is more expensive than `Cyclic` or `Randomized`.
    Greedy,
}
28
/// Type of separable regularization penalty applied via the proximal update.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum RegularizationType {
    /// No regularization (pure smooth minimization)
    None,
    /// L1 (Lasso) regularization: lambda * ||x||_1 — promotes sparse solutions
    L1,
    /// L2 (Ridge) regularization: lambda * ||x||_2^2
    L2,
    /// Elastic net: alpha * lambda * ||x||_1 + (1 - alpha) * lambda * ||x||_2^2,
    /// with mixing parameter alpha in [0, 1]
    ElasticNet,
}
41
/// Configuration for coordinate descent solver.
///
/// Construct via struct-update syntax over [`Default`], e.g.
/// `CoordinateDescentConfig { lambda: 0.1, ..Default::default() }`.
#[derive(Debug, Clone)]
pub struct CoordinateDescentConfig {
    /// Maximum number of full passes over all coordinates (or blocks)
    pub max_iter: usize,
    /// Convergence tolerance on the absolute change in the (regularized)
    /// objective value between successive passes
    pub tol: f64,
    /// Coordinate selection strategy
    pub strategy: CoordinateSelectionStrategy,
    /// Step size (learning rate). If None, gradient-based solvers fall back
    /// to a fixed default of 0.01; the quadratic solver ignores this field
    /// and performs exact coordinate updates instead.
    pub step_size: Option<f64>,
    /// Regularization type
    pub regularization: RegularizationType,
    /// Regularization strength (lambda)
    pub lambda: f64,
    /// Elastic net mixing parameter alpha in \[0,1\] (only used for ElasticNet)
    pub alpha: f64,
    /// Random seed for reproducibility (used with Randomized strategy)
    pub seed: u64,
    /// Whether to record the objective value after every full pass
    pub track_objective: bool,
}
64
impl Default for CoordinateDescentConfig {
    /// Default configuration: up to 1000 cyclic passes, tolerance 1e-8,
    /// no regularization, and a fixed seed (42) for reproducibility.
    fn default() -> Self {
        Self {
            max_iter: 1000,
            tol: 1e-8,
            strategy: CoordinateSelectionStrategy::Cyclic,
            step_size: None, // solvers fall back to 0.01 when unset
            regularization: RegularizationType::None,
            lambda: 0.0,
            alpha: 0.5,
            seed: 42,
            track_objective: false,
        }
    }
}
80
/// Result of coordinate descent optimization.
#[derive(Debug, Clone)]
pub struct CoordinateDescentResult {
    /// Optimal solution vector
    pub x: Array1<f64>,
    /// Final objective value (smooth part only, not including regularization)
    pub fun: f64,
    /// Final objective value including the regularization penalty
    pub fun_regularized: f64,
    /// Number of full iterations (passes over all coordinates) performed
    pub iterations: usize,
    /// Whether the change in objective fell below the configured tolerance
    pub converged: bool,
    /// History of regularized objective values: the initial value plus one
    /// entry per pass (empty unless `track_objective` was enabled)
    pub objective_history: Vec<f64>,
    /// Euclidean norm of the gradient of the smooth part at the solution
    pub grad_norm: f64,
}
99
/// Soft-thresholding operator used by the L1 proximal step.
///
/// S(x, t) = sign(x) * max(|x| - t, 0): shrinks `x` toward zero by
/// `threshold`, clamping to exactly zero inside `[-threshold, threshold]`.
fn soft_threshold(x: f64, threshold: f64) -> f64 {
    match x {
        v if v > threshold => v - threshold,
        v if v < -threshold => v + threshold,
        _ => 0.0,
    }
}
112
113/// Compute regularization penalty value
114fn regularization_penalty(
115    x: &Array1<f64>,
116    reg_type: RegularizationType,
117    lambda: f64,
118    alpha: f64,
119) -> f64 {
120    match reg_type {
121        RegularizationType::None => 0.0,
122        RegularizationType::L1 => lambda * x.mapv(f64::abs).sum(),
123        RegularizationType::L2 => lambda * x.dot(x),
124        RegularizationType::ElasticNet => {
125            let l1_part = alpha * lambda * x.mapv(f64::abs).sum();
126            let l2_part = (1.0 - alpha) * lambda * x.dot(x);
127            l1_part + l2_part
128        }
129    }
130}
131
/// Coordinate Descent Solver
///
/// Minimizes f(x) + g(x), where f is a smooth objective and g is a separable
/// (possibly non-smooth) regularization term selected via
/// `CoordinateDescentConfig::regularization`.
pub struct CoordinateDescentSolver {
    // Solver settings: strategy, step size, regularization, stopping rule.
    config: CoordinateDescentConfig,
}
139
140impl CoordinateDescentSolver {
141    /// Create a new coordinate descent solver with the given configuration
142    pub fn new(config: CoordinateDescentConfig) -> Self {
143        Self { config }
144    }
145
146    /// Create a solver with default configuration
147    pub fn default_solver() -> Self {
148        Self::new(CoordinateDescentConfig::default())
149    }
150
151    /// Minimize the objective function f(x) given gradient function grad_f(x)
152    ///
153    /// # Arguments
154    /// * `objective` - Smooth objective function f(x)
155    /// * `gradient` - Gradient of the smooth part: grad f(x)
156    /// * `x0` - Initial point
157    ///
158    /// # Returns
159    /// The optimization result containing the solution and convergence info
160    pub fn minimize<F, G>(
161        &self,
162        objective: F,
163        gradient: G,
164        x0: &Array1<f64>,
165    ) -> OptimizeResult<CoordinateDescentResult>
166    where
167        F: Fn(&ArrayView1<f64>) -> f64,
168        G: Fn(&ArrayView1<f64>) -> Array1<f64>,
169    {
170        let n = x0.len();
171        if n == 0 {
172            return Err(OptimizeError::InvalidInput(
173                "Initial point must have at least one dimension".to_string(),
174            ));
175        }
176
177        let mut x = x0.clone();
178        let mut rng = StdRng::seed_from_u64(self.config.seed);
179        let mut objective_history = Vec::new();
180
181        let step_size = self.config.step_size.unwrap_or(0.01);
182
183        let mut prev_obj = objective(&x.view())
184            + regularization_penalty(
185                &x,
186                self.config.regularization,
187                self.config.lambda,
188                self.config.alpha,
189            );
190
191        if self.config.track_objective {
192            objective_history.push(prev_obj);
193        }
194
195        let mut converged = false;
196        let mut iterations = 0;
197
198        for iter in 0..self.config.max_iter {
199            iterations = iter + 1;
200
201            // One full pass over coordinates
202            for _coord_step in 0..n {
203                let coord = match self.config.strategy {
204                    CoordinateSelectionStrategy::Cyclic => _coord_step,
205                    CoordinateSelectionStrategy::Randomized => rng.random_range(0..n),
206                    CoordinateSelectionStrategy::Greedy => {
207                        let grad = gradient(&x.view());
208                        // Select coordinate with largest absolute gradient
209                        let mut best_coord = 0;
210                        let mut best_abs_grad = f64::NEG_INFINITY;
211                        for i in 0..n {
212                            let abs_g = grad[i].abs();
213                            if abs_g > best_abs_grad {
214                                best_abs_grad = abs_g;
215                                best_coord = i;
216                            }
217                        }
218                        best_coord
219                    }
220                };
221
222                // Compute partial gradient for selected coordinate
223                let grad = gradient(&x.view());
224                let grad_coord = grad[coord];
225
226                // Update step depends on regularization
227                match self.config.regularization {
228                    RegularizationType::None => {
229                        x[coord] -= step_size * grad_coord;
230                    }
231                    RegularizationType::L1 => {
232                        // Proximal gradient step: soft threshold
233                        let proposal = x[coord] - step_size * grad_coord;
234                        x[coord] = soft_threshold(proposal, step_size * self.config.lambda);
235                    }
236                    RegularizationType::L2 => {
237                        // L2 gradient includes 2*lambda*x_i term
238                        let total_grad = grad_coord + 2.0 * self.config.lambda * x[coord];
239                        x[coord] -= step_size * total_grad;
240                    }
241                    RegularizationType::ElasticNet => {
242                        // L2 gradient part
243                        let l2_grad =
244                            2.0 * (1.0 - self.config.alpha) * self.config.lambda * x[coord];
245                        let proposal = x[coord] - step_size * (grad_coord + l2_grad);
246                        // L1 proximal part
247                        x[coord] = soft_threshold(
248                            proposal,
249                            step_size * self.config.alpha * self.config.lambda,
250                        );
251                    }
252                }
253            }
254
255            let smooth_obj = objective(&x.view());
256            let total_obj = smooth_obj
257                + regularization_penalty(
258                    &x,
259                    self.config.regularization,
260                    self.config.lambda,
261                    self.config.alpha,
262                );
263
264            if self.config.track_objective {
265                objective_history.push(total_obj);
266            }
267
268            let change = (prev_obj - total_obj).abs();
269            prev_obj = total_obj;
270
271            if change < self.config.tol {
272                converged = true;
273                break;
274            }
275        }
276
277        let final_grad = gradient(&x.view());
278        let grad_norm = final_grad.dot(&final_grad).sqrt();
279        let smooth_obj = objective(&x.view());
280        let reg_penalty = regularization_penalty(
281            &x,
282            self.config.regularization,
283            self.config.lambda,
284            self.config.alpha,
285        );
286
287        Ok(CoordinateDescentResult {
288            x,
289            fun: smooth_obj,
290            fun_regularized: smooth_obj + reg_penalty,
291            iterations,
292            converged,
293            objective_history,
294            grad_norm,
295        })
296    }
297}
298
/// Proximal Coordinate Descent
///
/// Specialized for composite optimization problems of the form:
///   minimize f(x) + sum_i g_i(x_i)
///
/// where f is smooth and each g_i is a separable (possibly non-smooth) penalty.
/// Each `minimize_*` method delegates to `CoordinateDescentSolver` with the
/// corresponding regularization type set on a copy of the configuration.
pub struct ProximalCoordinateDescent {
    // Base configuration; the regularization type is overridden per method.
    config: CoordinateDescentConfig,
}
308
309impl ProximalCoordinateDescent {
310    /// Create a new proximal coordinate descent solver
311    pub fn new(config: CoordinateDescentConfig) -> Self {
312        Self { config }
313    }
314
315    /// Minimize f(x) + lambda * ||x||_1 (Lasso)
316    ///
317    /// Uses coordinate-wise soft-thresholding updates.
318    ///
319    /// # Arguments
320    /// * `objective` - Smooth part f(x)
321    /// * `gradient` - Gradient of f
322    /// * `x0` - Initial point
323    pub fn minimize_lasso<F, G>(
324        &self,
325        objective: F,
326        gradient: G,
327        x0: &Array1<f64>,
328    ) -> OptimizeResult<CoordinateDescentResult>
329    where
330        F: Fn(&ArrayView1<f64>) -> f64,
331        G: Fn(&ArrayView1<f64>) -> Array1<f64>,
332    {
333        let mut config = self.config.clone();
334        config.regularization = RegularizationType::L1;
335        let solver = CoordinateDescentSolver::new(config);
336        solver.minimize(objective, gradient, x0)
337    }
338
339    /// Minimize f(x) + lambda * ||x||_2^2 (Ridge)
340    ///
341    /// # Arguments
342    /// * `objective` - Smooth part f(x)
343    /// * `gradient` - Gradient of f
344    /// * `x0` - Initial point
345    pub fn minimize_ridge<F, G>(
346        &self,
347        objective: F,
348        gradient: G,
349        x0: &Array1<f64>,
350    ) -> OptimizeResult<CoordinateDescentResult>
351    where
352        F: Fn(&ArrayView1<f64>) -> f64,
353        G: Fn(&ArrayView1<f64>) -> Array1<f64>,
354    {
355        let mut config = self.config.clone();
356        config.regularization = RegularizationType::L2;
357        let solver = CoordinateDescentSolver::new(config);
358        solver.minimize(objective, gradient, x0)
359    }
360
361    /// Minimize f(x) + alpha*lambda*||x||_1 + (1-alpha)*lambda*||x||_2^2 (Elastic Net)
362    ///
363    /// # Arguments
364    /// * `objective` - Smooth part f(x)
365    /// * `gradient` - Gradient of f
366    /// * `x0` - Initial point
367    pub fn minimize_elastic_net<F, G>(
368        &self,
369        objective: F,
370        gradient: G,
371        x0: &Array1<f64>,
372    ) -> OptimizeResult<CoordinateDescentResult>
373    where
374        F: Fn(&ArrayView1<f64>) -> f64,
375        G: Fn(&ArrayView1<f64>) -> Array1<f64>,
376    {
377        let mut config = self.config.clone();
378        config.regularization = RegularizationType::ElasticNet;
379        let solver = CoordinateDescentSolver::new(config);
380        solver.minimize(objective, gradient, x0)
381    }
382}
383
/// Block Coordinate Descent
///
/// Updates groups (blocks) of variables together rather than single coordinates.
/// This is useful when variables have strong within-block dependencies.
pub struct BlockCoordinateDescent {
    // Solver settings shared with single-coordinate descent.
    config: CoordinateDescentConfig,
    /// Block definitions: each inner Vec contains coordinate indices for that block
    blocks: Vec<Vec<usize>>,
}
393
394impl BlockCoordinateDescent {
395    /// Create a new block coordinate descent solver
396    ///
397    /// # Arguments
398    /// * `config` - Solver configuration
399    /// * `blocks` - Block definitions, each block is a list of coordinate indices
400    pub fn new(config: CoordinateDescentConfig, blocks: Vec<Vec<usize>>) -> Self {
401        Self { config, blocks }
402    }
403
404    /// Create blocks of equal size from dimension n
405    ///
406    /// # Arguments
407    /// * `config` - Solver configuration
408    /// * `n` - Total number of variables
409    /// * `block_size` - Number of variables per block
410    pub fn with_uniform_blocks(
411        config: CoordinateDescentConfig,
412        n: usize,
413        block_size: usize,
414    ) -> Self {
415        let mut blocks = Vec::new();
416        let mut start = 0;
417        while start < n {
418            let end = (start + block_size).min(n);
419            blocks.push((start..end).collect());
420            start = end;
421        }
422        Self { config, blocks }
423    }
424
425    /// Minimize using block coordinate descent
426    ///
427    /// For each block, computes the gradient restricted to those coordinates
428    /// and performs a gradient step on the block variables.
429    ///
430    /// # Arguments
431    /// * `objective` - Objective function f(x)
432    /// * `gradient` - Full gradient of f
433    /// * `x0` - Initial point
434    pub fn minimize<F, G>(
435        &self,
436        objective: F,
437        gradient: G,
438        x0: &Array1<f64>,
439    ) -> OptimizeResult<CoordinateDescentResult>
440    where
441        F: Fn(&ArrayView1<f64>) -> f64,
442        G: Fn(&ArrayView1<f64>) -> Array1<f64>,
443    {
444        let n = x0.len();
445        if n == 0 {
446            return Err(OptimizeError::InvalidInput(
447                "Initial point must have at least one dimension".to_string(),
448            ));
449        }
450
451        // Validate blocks
452        for (bi, block) in self.blocks.iter().enumerate() {
453            for &idx in block {
454                if idx >= n {
455                    return Err(OptimizeError::InvalidInput(format!(
456                        "Block {} contains index {} which exceeds dimension {}",
457                        bi, idx, n
458                    )));
459                }
460            }
461        }
462
463        let mut x = x0.clone();
464        let step_size = self.config.step_size.unwrap_or(0.01);
465        let mut rng = StdRng::seed_from_u64(self.config.seed);
466        let mut objective_history = Vec::new();
467
468        let mut prev_obj = objective(&x.view())
469            + regularization_penalty(
470                &x,
471                self.config.regularization,
472                self.config.lambda,
473                self.config.alpha,
474            );
475
476        if self.config.track_objective {
477            objective_history.push(prev_obj);
478        }
479
480        let mut converged = false;
481        let mut iterations = 0;
482        let num_blocks = self.blocks.len();
483
484        for iter in 0..self.config.max_iter {
485            iterations = iter + 1;
486
487            // Iterate over blocks
488            for block_step in 0..num_blocks {
489                let block_idx = match self.config.strategy {
490                    CoordinateSelectionStrategy::Cyclic => block_step,
491                    CoordinateSelectionStrategy::Randomized => rng.random_range(0..num_blocks),
492                    CoordinateSelectionStrategy::Greedy => {
493                        // Select block with largest gradient norm
494                        let grad = gradient(&x.view());
495                        let mut best_block = 0;
496                        let mut best_norm = f64::NEG_INFINITY;
497                        for (bi, block) in self.blocks.iter().enumerate() {
498                            let block_norm_sq: f64 = block.iter().map(|&i| grad[i] * grad[i]).sum();
499                            if block_norm_sq > best_norm {
500                                best_norm = block_norm_sq;
501                                best_block = bi;
502                            }
503                        }
504                        best_block
505                    }
506                };
507
508                let block = &self.blocks[block_idx];
509                let grad = gradient(&x.view());
510
511                // Update all coordinates in the block
512                for &coord in block {
513                    match self.config.regularization {
514                        RegularizationType::None => {
515                            x[coord] -= step_size * grad[coord];
516                        }
517                        RegularizationType::L1 => {
518                            let proposal = x[coord] - step_size * grad[coord];
519                            x[coord] = soft_threshold(proposal, step_size * self.config.lambda);
520                        }
521                        RegularizationType::L2 => {
522                            let total_grad = grad[coord] + 2.0 * self.config.lambda * x[coord];
523                            x[coord] -= step_size * total_grad;
524                        }
525                        RegularizationType::ElasticNet => {
526                            let l2_grad =
527                                2.0 * (1.0 - self.config.alpha) * self.config.lambda * x[coord];
528                            let proposal = x[coord] - step_size * (grad[coord] + l2_grad);
529                            x[coord] = soft_threshold(
530                                proposal,
531                                step_size * self.config.alpha * self.config.lambda,
532                            );
533                        }
534                    }
535                }
536            }
537
538            let smooth_obj = objective(&x.view());
539            let total_obj = smooth_obj
540                + regularization_penalty(
541                    &x,
542                    self.config.regularization,
543                    self.config.lambda,
544                    self.config.alpha,
545                );
546
547            if self.config.track_objective {
548                objective_history.push(total_obj);
549            }
550
551            let change = (prev_obj - total_obj).abs();
552            prev_obj = total_obj;
553
554            if change < self.config.tol {
555                converged = true;
556                break;
557            }
558        }
559
560        let final_grad = gradient(&x.view());
561        let grad_norm = final_grad.dot(&final_grad).sqrt();
562        let smooth_obj = objective(&x.view());
563        let reg_penalty = regularization_penalty(
564            &x,
565            self.config.regularization,
566            self.config.lambda,
567            self.config.alpha,
568        );
569
570        Ok(CoordinateDescentResult {
571            x,
572            fun: smooth_obj,
573            fun_regularized: smooth_obj + reg_penalty,
574            iterations,
575            converged,
576            objective_history,
577            grad_norm,
578        })
579    }
580}
581
582/// Convenience function: minimize a smooth objective using cyclic coordinate descent
583///
584/// # Arguments
585/// * `objective` - Objective function f(x) -> f64
586/// * `gradient` - Gradient function grad f(x) -> `Array1<f64>`
587/// * `x0` - Initial point
588/// * `config` - Optional configuration (uses defaults if None)
589pub fn coordinate_descent_minimize<F, G>(
590    objective: F,
591    gradient: G,
592    x0: &Array1<f64>,
593    config: Option<CoordinateDescentConfig>,
594) -> OptimizeResult<CoordinateDescentResult>
595where
596    F: Fn(&ArrayView1<f64>) -> f64,
597    G: Fn(&ArrayView1<f64>) -> Array1<f64>,
598{
599    let config = config.unwrap_or_default();
600    let solver = CoordinateDescentSolver::new(config);
601    solver.minimize(objective, gradient, x0)
602}
603
604/// Convenience function: minimize with L1 (Lasso) regularization
605///
606/// Solves: min_x f(x) + lambda * ||x||_1
607pub fn lasso_coordinate_descent<F, G>(
608    objective: F,
609    gradient: G,
610    x0: &Array1<f64>,
611    lambda: f64,
612    config: Option<CoordinateDescentConfig>,
613) -> OptimizeResult<CoordinateDescentResult>
614where
615    F: Fn(&ArrayView1<f64>) -> f64,
616    G: Fn(&ArrayView1<f64>) -> Array1<f64>,
617{
618    let mut config = config.unwrap_or_default();
619    config.regularization = RegularizationType::L1;
620    config.lambda = lambda;
621    let solver = CoordinateDescentSolver::new(config);
622    solver.minimize(objective, gradient, x0)
623}
624
625/// Coordinate descent for quadratic objectives: min 0.5 * x^T A x - b^T x
626///
627/// When the objective is quadratic with a known Hessian, we can compute exact
628/// coordinate-wise minimizers without a step size parameter.
629///
630/// # Arguments
631/// * `a` - Symmetric positive definite matrix (Hessian)
632/// * `b` - Linear term
633/// * `x0` - Initial point
634/// * `config` - Optional solver configuration
635pub fn quadratic_coordinate_descent(
636    a: &Array2<f64>,
637    b: &Array1<f64>,
638    x0: &Array1<f64>,
639    config: Option<CoordinateDescentConfig>,
640) -> OptimizeResult<CoordinateDescentResult> {
641    let n = x0.len();
642    let config = config.unwrap_or_default();
643
644    if a.nrows() != n || a.ncols() != n {
645        return Err(OptimizeError::InvalidInput(format!(
646            "Matrix A has shape ({}, {}), expected ({}, {})",
647            a.nrows(),
648            a.ncols(),
649            n,
650            n
651        )));
652    }
653    if b.len() != n {
654        return Err(OptimizeError::InvalidInput(format!(
655            "Vector b has length {}, expected {}",
656            b.len(),
657            n
658        )));
659    }
660
661    let mut x = x0.clone();
662    let mut rng = StdRng::seed_from_u64(config.seed);
663    let mut objective_history = Vec::new();
664
665    // Objective: 0.5 * x^T A x - b^T x
666    let compute_obj = |x: &Array1<f64>| -> f64 {
667        let ax = a.dot(x);
668        0.5 * x.dot(&ax) - b.dot(x)
669    };
670
671    let mut prev_obj = compute_obj(&x)
672        + regularization_penalty(&x, config.regularization, config.lambda, config.alpha);
673
674    if config.track_objective {
675        objective_history.push(prev_obj);
676    }
677
678    let mut converged = false;
679    let mut iterations = 0;
680
681    for iter in 0..config.max_iter {
682        iterations = iter + 1;
683
684        for _coord_step in 0..n {
685            let coord = match config.strategy {
686                CoordinateSelectionStrategy::Cyclic => _coord_step,
687                CoordinateSelectionStrategy::Randomized => rng.random_range(0..n),
688                CoordinateSelectionStrategy::Greedy => {
689                    // Gradient = Ax - b
690                    let grad = a.dot(&x) - b;
691                    let mut best = 0;
692                    let mut best_val = f64::NEG_INFINITY;
693                    for i in 0..n {
694                        let abs_g = grad[i].abs();
695                        if abs_g > best_val {
696                            best_val = abs_g;
697                            best = i;
698                        }
699                    }
700                    best
701                }
702            };
703
704            let a_ii = a[[coord, coord]];
705            if a_ii.abs() < 1e-15 {
706                continue; // Skip degenerate coordinate
707            }
708
709            // Compute residual for this coordinate: (Ax - b)[coord]
710            let mut residual_coord = -b[coord];
711            for j in 0..n {
712                residual_coord += a[[coord, j]] * x[j];
713            }
714
715            match config.regularization {
716                RegularizationType::None => {
717                    // Exact minimizer: x_i = (b_i - sum_{j!=i} A_{ij} x_j) / A_{ii}
718                    x[coord] -= residual_coord / a_ii;
719                }
720                RegularizationType::L1 => {
721                    // Exact coordinate update with L1
722                    let rhs = b[coord]
723                        - (0..n)
724                            .filter(|&j| j != coord)
725                            .map(|j| a[[coord, j]] * x[j])
726                            .sum::<f64>();
727                    x[coord] = soft_threshold(rhs, config.lambda) / a_ii;
728                }
729                RegularizationType::L2 => {
730                    // With L2: A_{ii} + 2*lambda in denominator
731                    let rhs = b[coord]
732                        - (0..n)
733                            .filter(|&j| j != coord)
734                            .map(|j| a[[coord, j]] * x[j])
735                            .sum::<f64>();
736                    x[coord] = rhs / (a_ii + 2.0 * config.lambda);
737                }
738                RegularizationType::ElasticNet => {
739                    let rhs = b[coord]
740                        - (0..n)
741                            .filter(|&j| j != coord)
742                            .map(|j| a[[coord, j]] * x[j])
743                            .sum::<f64>();
744                    x[coord] = soft_threshold(rhs, config.alpha * config.lambda)
745                        / (a_ii + 2.0 * (1.0 - config.alpha) * config.lambda);
746                }
747            }
748        }
749
750        let smooth_obj = compute_obj(&x);
751        let total_obj = smooth_obj
752            + regularization_penalty(&x, config.regularization, config.lambda, config.alpha);
753
754        if config.track_objective {
755            objective_history.push(total_obj);
756        }
757
758        let change = (prev_obj - total_obj).abs();
759        prev_obj = total_obj;
760
761        if change < config.tol {
762            converged = true;
763            break;
764        }
765    }
766
767    let grad = a.dot(&x) - b;
768    let grad_norm = grad.dot(&grad).sqrt();
769    let smooth_obj = compute_obj(&x);
770    let reg_penalty =
771        regularization_penalty(&x, config.regularization, config.lambda, config.alpha);
772
773    Ok(CoordinateDescentResult {
774        x,
775        fun: smooth_obj,
776        fun_regularized: smooth_obj + reg_penalty,
777        iterations,
778        converged,
779        objective_history,
780        grad_norm,
781    })
782}
783
784#[cfg(test)]
785mod tests {
786    use super::*;
787    use scirs2_core::ndarray::{array, Array1, Array2};
788
789    /// Test 1: Minimize a simple quadratic f(x) = x_1^2 + x_2^2 to the known optimum (0, 0)
790    #[test]
791    fn test_cyclic_cd_quadratic_minimum() {
792        let objective = |x: &ArrayView1<f64>| -> f64 { x[0] * x[0] + x[1] * x[1] };
793        let gradient = |x: &ArrayView1<f64>| -> Array1<f64> { array![2.0 * x[0], 2.0 * x[1]] };
794
795        let x0 = array![5.0, 3.0];
796        let config = CoordinateDescentConfig {
797            max_iter: 5000,
798            tol: 1e-12,
799            strategy: CoordinateSelectionStrategy::Cyclic,
800            step_size: Some(0.4),
801            ..Default::default()
802        };
803
804        let result = coordinate_descent_minimize(objective, gradient, &x0, Some(config));
805        assert!(result.is_ok());
806        let result = result.expect("should succeed");
807        assert!(result.converged);
808        assert!(result.x[0].abs() < 1e-5);
809        assert!(result.x[1].abs() < 1e-5);
810        assert!(result.fun < 1e-10);
811    }
812
813    /// Test 2: Lasso regression produces a sparse solution
814    #[test]
815    fn test_lasso_sparse_solution() {
816        // Quadratic with Lasso: min 0.5 * x^T A x - b^T x + lambda * ||x||_1
817        // A = identity, b = [0.5, 0.05, 0.5] with lambda = 0.1
818        // L1 should zero out the second component (0.05 < lambda=0.1)
819        let a = Array2::eye(3);
820        let b = array![0.5, 0.05, 0.5];
821        let x0 = array![0.0, 0.0, 0.0];
822
823        let config = CoordinateDescentConfig {
824            max_iter: 1000,
825            tol: 1e-12,
826            regularization: RegularizationType::L1,
827            lambda: 0.1,
828            ..Default::default()
829        };
830
831        let result = quadratic_coordinate_descent(&a, &b, &x0, Some(config));
832        assert!(result.is_ok());
833        let result = result.expect("should succeed");
834        assert!(result.converged);
835        // x[1] should be zero (sparse)
836        assert!(
837            result.x[1].abs() < 1e-10,
838            "Expected sparse: x[1]={} should be ~0",
839            result.x[1]
840        );
841        // x[0] and x[2] should be nonzero (b_i > lambda)
842        assert!(result.x[0].abs() > 0.1);
843        assert!(result.x[2].abs() > 0.1);
844    }
845
846    /// Test 3: Convergence rate comparison - greedy vs cyclic
847    #[test]
848    fn test_greedy_vs_cyclic_convergence() {
849        // Diagonal quadratic: f(x) = sum_i (i+1) * x_i^2 / 2
850        let n = 10;
851        let a = Array2::from_diag(&Array1::from_vec((1..=n).map(|i| i as f64).collect()));
852        let b = Array1::ones(n);
853        let x0 = Array1::from_vec(vec![10.0; n]);
854
855        let config_cyclic = CoordinateDescentConfig {
856            max_iter: 50,
857            tol: 1e-20, // Don't stop early
858            strategy: CoordinateSelectionStrategy::Cyclic,
859            track_objective: true,
860            ..Default::default()
861        };
862
863        let config_greedy = CoordinateDescentConfig {
864            max_iter: 50,
865            tol: 1e-20,
866            strategy: CoordinateSelectionStrategy::Greedy,
867            track_objective: true,
868            ..Default::default()
869        };
870
871        let result_cyclic = quadratic_coordinate_descent(&a, &b, &x0, Some(config_cyclic));
872        let result_greedy = quadratic_coordinate_descent(&a, &b, &x0, Some(config_greedy));
873
874        assert!(result_cyclic.is_ok());
875        assert!(result_greedy.is_ok());
876        let r_cyclic = result_cyclic.expect("cyclic should succeed");
877        let r_greedy = result_greedy.expect("greedy should succeed");
878
879        // Both should converge; greedy should have at least comparable final objective
880        // (greedy is typically faster on ill-conditioned problems)
881        assert!(r_cyclic.fun.is_finite());
882        assert!(r_greedy.fun.is_finite());
883    }
884
885    /// Test 4: Randomized coordinate descent converges
886    #[test]
887    fn test_randomized_cd_converges() {
888        let objective = |x: &ArrayView1<f64>| -> f64 {
889            0.5 * (x[0] - 1.0).powi(2) + 0.5 * (x[1] - 2.0).powi(2)
890        };
891        let gradient = |x: &ArrayView1<f64>| -> Array1<f64> { array![x[0] - 1.0, x[1] - 2.0] };
892
893        let x0 = array![10.0, -5.0];
894        let config = CoordinateDescentConfig {
895            max_iter: 10000,
896            tol: 1e-10,
897            strategy: CoordinateSelectionStrategy::Randomized,
898            step_size: Some(0.9),
899            seed: 123,
900            ..Default::default()
901        };
902
903        let result = coordinate_descent_minimize(objective, gradient, &x0, Some(config));
904        assert!(result.is_ok());
905        let result = result.expect("should succeed");
906        assert!(result.converged);
907        assert!((result.x[0] - 1.0).abs() < 1e-4, "x[0]={}", result.x[0]);
908        assert!((result.x[1] - 2.0).abs() < 1e-4, "x[1]={}", result.x[1]);
909    }
910
911    /// Test 5: Block coordinate descent
912    #[test]
913    fn test_block_cd() {
914        let objective = |x: &ArrayView1<f64>| -> f64 {
915            (x[0] - 1.0).powi(2)
916                + (x[1] - 2.0).powi(2)
917                + (x[2] - 3.0).powi(2)
918                + (x[3] - 4.0).powi(2)
919        };
920        let gradient = |x: &ArrayView1<f64>| -> Array1<f64> {
921            array![
922                2.0 * (x[0] - 1.0),
923                2.0 * (x[1] - 2.0),
924                2.0 * (x[2] - 3.0),
925                2.0 * (x[3] - 4.0)
926            ]
927        };
928
929        let x0 = array![0.0, 0.0, 0.0, 0.0];
930        let config = CoordinateDescentConfig {
931            max_iter: 5000,
932            tol: 1e-12,
933            step_size: Some(0.4),
934            ..Default::default()
935        };
936
937        let solver = BlockCoordinateDescent::with_uniform_blocks(config, 4, 2);
938        let result = solver.minimize(objective, gradient, &x0);
939        assert!(result.is_ok());
940        let result = result.expect("should succeed");
941        assert!(result.converged);
942        assert!((result.x[0] - 1.0).abs() < 1e-4);
943        assert!((result.x[1] - 2.0).abs() < 1e-4);
944        assert!((result.x[2] - 3.0).abs() < 1e-4);
945        assert!((result.x[3] - 4.0).abs() < 1e-4);
946    }
947
948    /// Test 6: Quadratic coordinate descent with exact updates
949    #[test]
950    fn test_quadratic_cd_exact() {
951        // A = [[2, 1], [1, 3]], b = [1, 2]
952        // Solution: A^{-1} b = [1/5, 3/5]
953        let a = array![[2.0, 1.0], [1.0, 3.0]];
954        let b = array![1.0, 2.0];
955        let x0 = array![0.0, 0.0];
956
957        let config = CoordinateDescentConfig {
958            max_iter: 500,
959            tol: 1e-14,
960            ..Default::default()
961        };
962
963        let result = quadratic_coordinate_descent(&a, &b, &x0, Some(config));
964        assert!(result.is_ok());
965        let result = result.expect("should succeed");
966        assert!(result.converged);
967        // A^{-1} = [[3, -1], [-1, 2]] / 5
968        // x = [3/5 - 2/5, -1/5 + 4/5] = [1/5, 3/5]
969        assert!(
970            (result.x[0] - 0.2).abs() < 1e-8,
971            "x[0]={}, expected 0.2",
972            result.x[0]
973        );
974        assert!(
975            (result.x[1] - 0.6).abs() < 1e-8,
976            "x[1]={}, expected ~0.6",
977            result.x[1]
978        );
979    }
980
981    /// Test 7: Ridge regression
982    #[test]
983    fn test_ridge_cd() {
984        let a = Array2::eye(3);
985        let b = array![1.0, 2.0, 3.0];
986        let x0 = array![0.0, 0.0, 0.0];
987
988        let config = CoordinateDescentConfig {
989            max_iter: 1000,
990            tol: 1e-14,
991            regularization: RegularizationType::L2,
992            lambda: 0.5,
993            ..Default::default()
994        };
995
996        let result = quadratic_coordinate_descent(&a, &b, &x0, Some(config));
997        assert!(result.is_ok());
998        let result = result.expect("should succeed");
999        // Solution: x_i = b_i / (1 + 2*lambda) = b_i / 2
1000        assert!((result.x[0] - 0.5).abs() < 1e-8, "x[0]={}", result.x[0]);
1001        assert!((result.x[1] - 1.0).abs() < 1e-8, "x[1]={}", result.x[1]);
1002        assert!((result.x[2] - 1.5).abs() < 1e-8, "x[2]={}", result.x[2]);
1003    }
1004
1005    /// Test 8: Objective history tracking
1006    #[test]
1007    fn test_objective_history_tracking() {
1008        let a = Array2::eye(2);
1009        let b = array![1.0, 1.0];
1010        let x0 = array![5.0, 5.0];
1011
1012        let config = CoordinateDescentConfig {
1013            max_iter: 20,
1014            tol: 1e-20,
1015            track_objective: true,
1016            ..Default::default()
1017        };
1018
1019        let result = quadratic_coordinate_descent(&a, &b, &x0, Some(config));
1020        assert!(result.is_ok());
1021        let result = result.expect("should succeed");
1022        // History should be non-empty (initial + iterations)
1023        assert!(result.objective_history.len() > 1);
1024        // Objective should be monotonically non-increasing
1025        for i in 1..result.objective_history.len() {
1026            assert!(
1027                result.objective_history[i] <= result.objective_history[i - 1] + 1e-12,
1028                "Objective increased at iter {}: {} -> {}",
1029                i,
1030                result.objective_history[i - 1],
1031                result.objective_history[i]
1032            );
1033        }
1034    }
1035}