// sklears_linear/solver_implementations.rs

1//! Trait-based Solver Implementations
2//!
3//! This module implements various optimization solvers that work with the modular framework.
4//! All solvers implement the OptimizationSolver trait for consistency and pluggability.
5
6use crate::modular_framework::{
7    Objective, ObjectiveData, OptimizationSolver, SolverInfo, SolverRecommendations,
8};
9use scirs2_core::ndarray::{Array1, Array2};
10use sklears_core::{
11    error::{Result, SklearsError},
12    types::Float,
13};
14use std::collections::HashMap;
15
/// Configuration for gradient descent solvers.
#[derive(Debug, Clone)]
pub struct GradientDescentConfig {
    /// Maximum number of iterations before the solver gives up.
    pub max_iterations: usize,
    /// Convergence tolerance on the Euclidean gradient norm.
    pub tolerance: Float,
    /// Learning rate (fixed step size), used when `use_line_search` is false.
    pub learning_rate: Float,
    /// Whether to use backtracking line search instead of the fixed learning rate.
    pub use_line_search: bool,
    /// Parameters controlling the backtracking line search.
    pub line_search_config: LineSearchConfig,
    /// Whether to print progress information during optimization.
    pub verbose: bool,
}
32
33impl Default for GradientDescentConfig {
34    fn default() -> Self {
35        Self {
36            max_iterations: 1000,
37            tolerance: 1e-6,
38            learning_rate: 0.01,
39            use_line_search: false,
40            line_search_config: LineSearchConfig::default(),
41            verbose: false,
42        }
43    }
44}
45
/// Configuration for backtracking line search.
#[derive(Debug, Clone)]
pub struct LineSearchConfig {
    /// Armijo sufficient-decrease parameter (c1), typically small (1e-4).
    pub c1: Float,
    /// Curvature condition parameter (c2, for strong Wolfe conditions).
    /// NOTE(review): the current backtracking search only checks the Armijo
    /// condition; c2 is stored but not yet used there.
    pub c2: Float,
    /// Maximum number of backtracking iterations before giving up.
    pub max_line_search_iterations: usize,
    /// Initial trial step size.
    pub initial_step_scale: Float,
    /// Multiplicative factor applied to the step after each rejected trial.
    pub step_reduction_factor: Float,
}
60
61impl Default for LineSearchConfig {
62    fn default() -> Self {
63        Self {
64            c1: 1e-4,
65            c2: 0.9,
66            max_line_search_iterations: 20,
67            initial_step_scale: 1.0,
68            step_reduction_factor: 0.5,
69        }
70    }
71}
72
/// Result from gradient descent optimization.
#[derive(Debug, Clone)]
pub struct GradientDescentResult {
    /// Final coefficient values (the last iterate).
    pub coefficients: Array1<Float>,
    /// Objective value evaluated at the final coefficients.
    pub objective_value: Float,
    /// Number of iterations performed (one per history entry).
    pub n_iterations: usize,
    /// Whether the gradient norm dropped below the configured tolerance.
    pub converged: bool,
    /// Objective value recorded at the start of each iteration.
    pub convergence_history: Array1<Float>,
    /// Euclidean gradient norm recorded at the start of each iteration.
    pub gradient_norm_history: Array1<Float>,
    /// Gradient norm evaluated at the final coefficients.
    pub final_gradient_norm: Float,
}
91
/// Standard (steepest-descent) gradient descent solver.
///
/// Stateless unit struct; all behavior lives in its `OptimizationSolver` impl.
#[derive(Debug)]
pub struct GradientDescentSolver;
95
impl OptimizationSolver for GradientDescentSolver {
    type Config = GradientDescentConfig;
    type Result = GradientDescentResult;

    /// Minimize `objective` with fixed-step or line-searched steepest descent.
    ///
    /// Starting from `initial_guess`, each iteration evaluates the objective
    /// value and gradient, records both in history vectors, and declares
    /// convergence once the Euclidean gradient norm drops below
    /// `config.tolerance`. The update is `x ← x − α·∇f(x)`, where α is either
    /// `config.learning_rate` or the backtracking line-search result.
    ///
    /// # Errors
    /// Propagates any error returned by the objective's value/gradient
    /// evaluations.
    fn solve(
        &self,
        objective: &dyn Objective,
        initial_guess: &Array1<Float>,
        config: &Self::Config,
    ) -> Result<Self::Result> {
        let mut coefficients = initial_guess.clone();
        let mut convergence_history = Vec::new();
        let mut gradient_norm_history = Vec::new();
        let mut converged = false;

        // Create dummy data for objective computation (this is a limitation of the current design)
        // In practice, the objective would need to store its own data
        let dummy_data = ObjectiveData {
            features: Array2::zeros((1, coefficients.len())),
            targets: Array1::zeros(1),
            sample_weights: None,
            metadata: Default::default(),
        };

        for iteration in 0..config.max_iterations {
            // Compute objective value and gradient at the current iterate.
            let (obj_value, gradient) = objective.value_and_gradient(&coefficients, &dummy_data)?;
            // Euclidean (L2) norm of the gradient; used as the convergence measure.
            let gradient_norm = gradient.mapv(|x| x * x).sum().sqrt();

            convergence_history.push(obj_value);
            gradient_norm_history.push(gradient_norm);

            // Progress report every 100 iterations when verbose.
            if config.verbose && iteration % 100 == 0 {
                println!(
                    "Iteration {}: obj={:.6}, ||grad||={:.6}",
                    iteration, obj_value, gradient_norm
                );
            }

            // Check convergence before taking a step, so a good initial guess
            // returns immediately without moving.
            if gradient_norm < config.tolerance {
                converged = true;
                if config.verbose {
                    println!("Converged after {} iterations", iteration);
                }
                break;
            }

            // Compute step size: backtracking line search or fixed learning rate.
            let step_size = if config.use_line_search {
                self.line_search(
                    objective,
                    &coefficients,
                    &gradient,
                    &dummy_data,
                    &config.line_search_config,
                )?
            } else {
                config.learning_rate
            };

            // Steepest-descent update: move against the gradient.
            coefficients = &coefficients - step_size * &gradient;
        }

        // Re-evaluate at the final iterate so the reported values describe the
        // returned coefficients (which moved after the last history entry).
        let final_objective = objective.value(&coefficients, &dummy_data)?;
        let final_gradient = objective.gradient(&coefficients, &dummy_data)?;
        let final_gradient_norm = final_gradient.mapv(|x| x * x).sum().sqrt();

        Ok(GradientDescentResult {
            coefficients,
            objective_value: final_objective,
            // Exactly one history entry is pushed per loop iteration, so the
            // history length equals the number of iterations performed.
            n_iterations: convergence_history.len(),
            converged,
            convergence_history: Array1::from_vec(convergence_history),
            gradient_norm_history: Array1::from_vec(gradient_norm_history),
            final_gradient_norm,
        })
    }

    /// Gradient descent needs only first-order information, so any
    /// differentiable objective is accepted.
    fn supports_objective(&self, _objective: &dyn Objective) -> bool {
        true // Gradient descent works with any differentiable objective
    }

    fn name(&self) -> &'static str {
        "GradientDescent"
    }

    /// Heuristic solver settings derived from the problem dimensions in `data`.
    fn get_recommendations(&self, data: &ObjectiveData) -> SolverRecommendations {
        let n_samples = data.features.nrows();
        let n_features = data.features.ncols();

        // Heuristic recommendations based on problem size
        let max_iter = if n_samples > 10000 { 100 } else { 1000 };
        let tolerance = if n_features > 1000 { 1e-4 } else { 1e-6 };
        // Classic 1/sqrt(n) step-size scaling.
        let learning_rate = 1.0 / (n_samples as Float).sqrt();

        SolverRecommendations {
            max_iterations: Some(max_iter),
            tolerance: Some(tolerance),
            step_size: Some(learning_rate),
            use_line_search: Some(n_features > 100),
            notes: vec![
                format!(
                    "Problem size: {} samples, {} features",
                    n_samples, n_features
                ),
                "Consider using line search for better convergence".to_string(),
            ],
        }
    }
}
208
209impl GradientDescentSolver {
210    /// Perform backtracking line search to find appropriate step size
211    fn line_search(
212        &self,
213        objective: &dyn Objective,
214        x: &Array1<Float>,
215        direction: &Array1<Float>,
216        data: &ObjectiveData,
217        config: &LineSearchConfig,
218    ) -> Result<Float> {
219        let f0 = objective.value(x, data)?;
220        let grad0 = objective.gradient(x, data)?;
221        let slope = grad0.dot(direction);
222
223        let mut step_size = config.initial_step_scale;
224
225        for _ in 0..config.max_line_search_iterations {
226            let x_new = x - step_size * direction;
227            let f_new = objective.value(&x_new, data)?;
228
229            // Armijo condition: f(x + α*p) ≤ f(x) + c1*α*∇f(x)ᵀp
230            if f_new <= f0 + config.c1 * step_size * slope {
231                return Ok(step_size);
232            }
233
234            step_size *= config.step_reduction_factor;
235        }
236
237        // If line search fails, return a small step size
238        Ok(step_size)
239    }
240}
241
/// Configuration for the coordinate descent solver.
#[derive(Debug, Clone)]
pub struct CoordinateDescentConfig {
    /// Maximum number of full coordinate sweeps.
    pub max_iterations: usize,
    /// Convergence tolerance on the largest per-coordinate change in a sweep.
    pub tolerance: Float,
    /// Whether to use random coordinate selection.
    /// NOTE(review): random selection is not implemented yet; the solver
    /// currently always sweeps cyclically.
    pub random_selection: bool,
    /// Random seed for reproducibility (used only once random selection lands).
    pub random_seed: Option<u64>,
    /// Whether to print progress information during optimization.
    pub verbose: bool,
}
256
257impl Default for CoordinateDescentConfig {
258    fn default() -> Self {
259        Self {
260            max_iterations: 1000,
261            tolerance: 1e-6,
262            random_selection: false,
263            random_seed: None,
264            verbose: false,
265        }
266    }
267}
268
/// Result from coordinate descent optimization.
#[derive(Debug, Clone)]
pub struct CoordinateDescentResult {
    /// Final coefficient values.
    pub coefficients: Array1<Float>,
    /// Objective value evaluated at the final coefficients.
    pub objective_value: Float,
    /// Number of full sweeps performed (one per history entry).
    pub n_iterations: usize,
    /// Whether the largest coordinate change fell below the tolerance.
    pub converged: bool,
    /// Objective value recorded at the end of each sweep.
    pub convergence_history: Array1<Float>,
    /// Total number of single-coordinate updates performed.
    pub n_coordinate_updates: usize,
}
285
/// Coordinate descent solver (good for L1-regularized problems).
///
/// Stateless unit struct; all behavior lives in its `OptimizationSolver` impl.
#[derive(Debug)]
pub struct CoordinateDescentSolver;
289
290impl OptimizationSolver for CoordinateDescentSolver {
291    type Config = CoordinateDescentConfig;
292    type Result = CoordinateDescentResult;
293
294    fn solve(
295        &self,
296        objective: &dyn Objective,
297        initial_guess: &Array1<Float>,
298        config: &Self::Config,
299    ) -> Result<Self::Result> {
300        let mut coefficients = initial_guess.clone();
301        let n_features = coefficients.len();
302        let mut convergence_history = Vec::new();
303        let mut converged = false;
304        let mut coordinate_updates = 0;
305
306        // Create coordinate selection order
307        let coord_order: Vec<usize> = if config.random_selection {
308            // TODO: Implement random permutation using random_seed
309            (0..n_features).collect()
310        } else {
311            (0..n_features).collect()
312        };
313
314        let dummy_data = ObjectiveData {
315            features: Array2::zeros((1, n_features)),
316            targets: Array1::zeros(1),
317            sample_weights: None,
318            metadata: Default::default(),
319        };
320
321        for iteration in 0..config.max_iterations {
322            let _obj_value_start = objective.value(&coefficients, &dummy_data)?;
323            let mut max_coordinate_change: f64 = 0.0;
324
325            // Update each coordinate
326            for &coord_idx in &coord_order {
327                let old_value = coefficients[coord_idx];
328
329                // For coordinate descent, we would typically compute the optimal update
330                // for this coordinate. This is a simplified implementation.
331                let gradient = objective.gradient(&coefficients, &dummy_data)?;
332                let coord_gradient = gradient[coord_idx];
333
334                // Simple gradient step for this coordinate
335                let learning_rate = 0.01; // This should be adaptive
336                let new_value = old_value - learning_rate * coord_gradient;
337
338                coefficients[coord_idx] = new_value;
339                coordinate_updates += 1;
340
341                let change = (new_value - old_value).abs();
342                max_coordinate_change = max_coordinate_change.max(change);
343            }
344
345            let obj_value_end = objective.value(&coefficients, &dummy_data)?;
346            convergence_history.push(obj_value_end);
347
348            if config.verbose && iteration % 100 == 0 {
349                println!(
350                    "Iteration {}: obj={:.6}, max_change={:.6}",
351                    iteration, obj_value_end, max_coordinate_change
352                );
353            }
354
355            // Check convergence
356            if max_coordinate_change < config.tolerance {
357                converged = true;
358                if config.verbose {
359                    println!("Converged after {} iterations", iteration);
360                }
361                break;
362            }
363        }
364
365        let final_objective = objective.value(&coefficients, &dummy_data)?;
366
367        Ok(CoordinateDescentResult {
368            coefficients,
369            objective_value: final_objective,
370            n_iterations: convergence_history.len(),
371            converged,
372            convergence_history: Array1::from_vec(convergence_history),
373            n_coordinate_updates: coordinate_updates,
374        })
375    }
376
377    fn supports_objective(&self, _objective: &dyn Objective) -> bool {
378        // Coordinate descent works well with separable objectives
379        // For L1 regularization, it's particularly effective
380        true
381    }
382
383    fn name(&self) -> &'static str {
384        "CoordinateDescent"
385    }
386
387    fn get_recommendations(&self, data: &ObjectiveData) -> SolverRecommendations {
388        let n_features = data.features.ncols();
389
390        SolverRecommendations {
391            max_iterations: Some(if n_features > 1000 { 100 } else { 1000 }),
392            tolerance: Some(1e-6),
393            step_size: None, // Not applicable for coordinate descent
394            use_line_search: Some(false),
395            notes: vec![
396                "Coordinate descent is particularly effective for L1-regularized problems"
397                    .to_string(),
398                "Consider random coordinate selection for large problems".to_string(),
399            ],
400        }
401    }
402}
403
/// Configuration for the proximal gradient solver.
#[derive(Debug, Clone)]
pub struct ProximalGradientConfig {
    /// Maximum number of iterations.
    pub max_iterations: usize,
    /// Convergence tolerance.
    pub tolerance: Float,
    /// Initial step size.
    pub initial_step_size: Float,
    /// Whether to adapt the step size via backtracking.
    pub adaptive_step_size: bool,
    /// Backtracking parameters (used when `adaptive_step_size` is true).
    pub backtracking_config: BacktrackingConfig,
    /// Whether to print progress information during optimization.
    pub verbose: bool,
}
420
421impl Default for ProximalGradientConfig {
422    fn default() -> Self {
423        Self {
424            max_iterations: 1000,
425            tolerance: 1e-6,
426            initial_step_size: 1.0,
427            adaptive_step_size: true,
428            backtracking_config: BacktrackingConfig::default(),
429            verbose: false,
430        }
431    }
432}
433
/// Configuration for backtracking in proximal gradient.
#[derive(Debug, Clone)]
pub struct BacktrackingConfig {
    /// Step-size reduction factor (β) applied after each rejected trial.
    pub beta: Float,
    /// Sufficient decrease parameter (σ).
    pub sigma: Float,
    /// Maximum number of backtracking iterations per step.
    pub max_backtrack_iterations: usize,
}
444
445impl Default for BacktrackingConfig {
446    fn default() -> Self {
447        Self {
448            beta: 0.5,
449            sigma: 0.01,
450            max_backtrack_iterations: 50,
451        }
452    }
453}
454
/// Result from proximal gradient optimization.
#[derive(Debug, Clone)]
pub struct ProximalGradientResult {
    /// Final coefficient values.
    pub coefficients: Array1<Float>,
    /// Objective value at the final coefficients.
    pub objective_value: Float,
    /// Number of iterations performed.
    pub n_iterations: usize,
    /// Whether optimization converged.
    pub converged: bool,
    /// Objective values recorded over the iterations.
    pub convergence_history: Array1<Float>,
    /// Step size used at each iteration.
    pub step_size_history: Array1<Float>,
}
471
/// Proximal gradient solver (for non-smooth regularization).
///
/// Currently a stub: `solve` always errors because the `Objective` trait
/// cannot yet express the smooth/non-smooth decomposition this method needs.
#[derive(Debug)]
pub struct ProximalGradientSolver;
475
476impl OptimizationSolver for ProximalGradientSolver {
477    type Config = ProximalGradientConfig;
478    type Result = ProximalGradientResult;
479
480    fn solve(
481        &self,
482        _objective: &dyn Objective,
483        _initial_guess: &Array1<Float>,
484        _config: &Self::Config,
485    ) -> Result<Self::Result> {
486        // NOTE: This is a simplified implementation. In practice, proximal gradient
487        // methods require separating the smooth and non-smooth parts of the objective.
488        // The current Objective trait doesn't directly support this separation.
489
490        Err(SklearsError::InvalidOperation(
491            "Proximal gradient solver requires objective decomposition not yet implemented"
492                .to_string(),
493        ))
494    }
495
496    fn supports_objective(&self, _objective: &dyn Objective) -> bool {
497        // This would check if the objective has a decomposable structure
498        false
499    }
500
501    fn name(&self) -> &'static str {
502        "ProximalGradient"
503    }
504
505    fn get_recommendations(&self, _data: &ObjectiveData) -> SolverRecommendations {
506        SolverRecommendations {
507            max_iterations: Some(1000),
508            tolerance: Some(1e-6),
509            step_size: Some(1.0),
510            use_line_search: Some(false),
511            notes: vec![
512                "Proximal gradient is ideal for problems with non-smooth regularization"
513                    .to_string(),
514                "Requires objective decomposition into smooth + non-smooth parts".to_string(),
515            ],
516        }
517    }
518}
519
/// Factory for creating boxed solver instances.
pub struct SolverFactory;
522
523impl SolverFactory {
524    /// Create a gradient descent solver
525    pub fn gradient_descent(
526    ) -> Box<dyn OptimizationSolver<Config = GradientDescentConfig, Result = GradientDescentResult>>
527    {
528        Box::new(GradientDescentSolver)
529    }
530
531    /// Create a coordinate descent solver
532    pub fn coordinate_descent() -> Box<
533        dyn OptimizationSolver<Config = CoordinateDescentConfig, Result = CoordinateDescentResult>,
534    > {
535        Box::new(CoordinateDescentSolver)
536    }
537
538    /// Create a proximal gradient solver
539    pub fn proximal_gradient(
540    ) -> Box<dyn OptimizationSolver<Config = ProximalGradientConfig, Result = ProximalGradientResult>>
541    {
542        Box::new(ProximalGradientSolver)
543    }
544}
545
546/// Utility function to convert from framework result to standard format
547pub fn convert_solver_result_to_standard(
548    _result: &dyn std::fmt::Debug,
549    solver_name: &str,
550) -> crate::modular_framework::OptimizationResult {
551    // This is a placeholder for result conversion
552    // In practice, each solver result type would implement a conversion trait
553    crate::modular_framework::OptimizationResult {
554        coefficients: Array1::zeros(1), // Placeholder
555        intercept: None,
556        objective_value: 0.0,
557        n_iterations: 0,
558        converged: false,
559        solver_info: SolverInfo {
560            solver_name: solver_name.to_string(),
561            metrics: HashMap::new(),
562            warnings: Vec::new(),
563            convergence_history: None,
564        },
565    }
566}
567
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use crate::loss_functions::SquaredLoss;
    use crate::modular_framework::CompositeObjective;
    use crate::regularization_schemes::L2Regularization;

    // Test helper: build a squared-loss + L2(0.1) composite objective.
    // `Box::leak` deliberately leaks the two small allocations to obtain
    // 'static references for CompositeObjective<'static> — acceptable for the
    // lifetime of a test run.
    fn create_test_objective() -> CompositeObjective<'static> {
        let loss = Box::leak(Box::new(SquaredLoss));
        let reg = Box::leak(Box::new(L2Regularization::new(0.1).unwrap()));
        CompositeObjective::new(loss, Some(reg))
    }

    // GradientDescentConfig defaults match the documented values.
    #[test]
    fn test_gradient_descent_config() {
        let config = GradientDescentConfig::default();
        assert_eq!(config.max_iterations, 1000);
        assert_eq!(config.tolerance, 1e-6);
        assert_eq!(config.learning_rate, 0.01);
        assert!(!config.use_line_search);
    }

    // CoordinateDescentConfig defaults match the documented values.
    #[test]
    fn test_coordinate_descent_config() {
        let config = CoordinateDescentConfig::default();
        assert_eq!(config.max_iterations, 1000);
        assert_eq!(config.tolerance, 1e-6);
        assert!(!config.random_selection);
    }

    // Each solver reports its expected static name.
    #[test]
    fn test_solver_names() {
        let gd_solver = GradientDescentSolver;
        assert_eq!(gd_solver.name(), "GradientDescent");

        let cd_solver = CoordinateDescentSolver;
        assert_eq!(cd_solver.name(), "CoordinateDescent");

        let pg_solver = ProximalGradientSolver;
        assert_eq!(pg_solver.name(), "ProximalGradient");
    }

    // The factory hands back boxed solvers of the right kind.
    #[test]
    fn test_solver_factory() {
        let gd = SolverFactory::gradient_descent();
        assert_eq!(gd.name(), "GradientDescent");

        let cd = SolverFactory::coordinate_descent();
        assert_eq!(cd.name(), "CoordinateDescent");

        let pg = SolverFactory::proximal_gradient();
        assert_eq!(pg.name(), "ProximalGradient");
    }

    // Recommendations are populated for a modest 100x10 problem.
    #[test]
    fn test_solver_recommendations() {
        let solver = GradientDescentSolver;
        let data = ObjectiveData {
            features: Array2::zeros((100, 10)),
            targets: Array1::zeros(100),
            sample_weights: None,
            metadata: Default::default(),
        };

        let recommendations = solver.get_recommendations(&data);
        assert!(recommendations.max_iterations.is_some());
        assert!(recommendations.tolerance.is_some());
        assert!(recommendations.step_size.is_some());
    }

    // LineSearchConfig defaults match the documented values.
    #[test]
    fn test_line_search_config() {
        let config = LineSearchConfig::default();
        assert_eq!(config.c1, 1e-4);
        assert_eq!(config.c2, 0.9);
        assert_eq!(config.max_line_search_iterations, 20);
    }
}