use crate::modular_framework::{
    Objective, ObjectiveData, OptimizationSolver, SolverInfo, SolverRecommendations,
};
use scirs2_core::ndarray::{Array1, Array2};
use sklears_core::{
    error::{Result, SklearsError},
    types::Float,
};
use std::collections::HashMap;
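
/// Configuration for [`GradientDescentSolver`].
///
/// A minimal construction sketch; the overridden values below are illustrative
/// only:
///
/// ```ignore
/// let config = GradientDescentConfig {
///     learning_rate: 0.05,
///     use_line_search: true,
///     ..Default::default()
/// };
/// ```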
#[derive(Debug, Clone)]
pub struct GradientDescentConfig {
    pub max_iterations: usize,
    pub tolerance: Float,
    pub learning_rate: Float,
    pub use_line_search: bool,
    pub line_search_config: LineSearchConfig,
    pub verbose: bool,
}

impl Default for GradientDescentConfig {
    fn default() -> Self {
        Self {
            max_iterations: 1000,
            tolerance: 1e-6,
            learning_rate: 0.01,
            use_line_search: false,
            line_search_config: LineSearchConfig::default(),
            verbose: false,
        }
    }
}
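
/// Parameters for the backtracking line search used by [`GradientDescentSolver`].
///
/// `c1` is the Armijo sufficient-decrease constant and `c2` the curvature
/// constant; the current backtracking search only evaluates the Armijo
/// condition, so `c2` is carried in the configuration but not checked.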
#[derive(Debug, Clone)]
pub struct LineSearchConfig {
    pub c1: Float,
    pub c2: Float,
    pub max_line_search_iterations: usize,
    pub initial_step_scale: Float,
    pub step_reduction_factor: Float,
}

impl Default for LineSearchConfig {
    fn default() -> Self {
        Self {
            c1: 1e-4,
            c2: 0.9,
            max_line_search_iterations: 20,
            initial_step_scale: 1.0,
            step_reduction_factor: 0.5,
        }
    }
}

#[derive(Debug, Clone)]
pub struct GradientDescentResult {
    pub coefficients: Array1<Float>,
    pub objective_value: Float,
    pub n_iterations: usize,
    pub converged: bool,
    pub convergence_history: Array1<Float>,
    pub gradient_norm_history: Array1<Float>,
    pub final_gradient_norm: Float,
}
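
/// Plain gradient descent with an optional backtracking line search.
///
/// A usage sketch; `objective` and `n_features` stand in for an [`Objective`]
/// implementation and its dimensionality and are not defined in this module:
///
/// ```ignore
/// let solver = GradientDescentSolver;
/// let config = GradientDescentConfig::default();
/// let result = solver.solve(&objective, &Array1::zeros(n_features), &config)?;
/// println!("converged: {} in {} iterations", result.converged, result.n_iterations);
/// ```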
#[derive(Debug)]
pub struct GradientDescentSolver;

impl OptimizationSolver for GradientDescentSolver {
    type Config = GradientDescentConfig;
    type Result = GradientDescentResult;

    fn solve(
        &self,
        objective: &dyn Objective,
        initial_guess: &Array1<Float>,
        config: &Self::Config,
    ) -> Result<Self::Result> {
        let mut coefficients = initial_guess.clone();
        let mut convergence_history = Vec::new();
        let mut gradient_norm_history = Vec::new();
        let mut converged = false;

        // Placeholder data: the objective is evaluated against a single row of
        // zero-filled features and targets.
        let dummy_data = ObjectiveData {
            features: Array2::zeros((1, coefficients.len())),
            targets: Array1::zeros(1),
            sample_weights: None,
            metadata: Default::default(),
        };

        for iteration in 0..config.max_iterations {
            let (obj_value, gradient) = objective.value_and_gradient(&coefficients, &dummy_data)?;
            let gradient_norm = gradient.mapv(|x| x * x).sum().sqrt();

            convergence_history.push(obj_value);
            gradient_norm_history.push(gradient_norm);

            if config.verbose && iteration % 100 == 0 {
                println!(
                    "Iteration {}: obj={:.6}, ||grad||={:.6}",
                    iteration, obj_value, gradient_norm
                );
            }

            // Converged once the gradient norm falls below the tolerance.
            if gradient_norm < config.tolerance {
                converged = true;
                if config.verbose {
                    println!("Converged after {} iterations", iteration);
                }
                break;
            }

            // Choose the step size: backtracking line search or fixed learning rate.
            let step_size = if config.use_line_search {
                self.line_search(
                    objective,
                    &coefficients,
                    &gradient,
                    &dummy_data,
                    &config.line_search_config,
                )?
            } else {
                config.learning_rate
            };

            // Gradient step: x <- x - alpha * grad f(x).
            coefficients = &coefficients - step_size * &gradient;
        }

        let final_objective = objective.value(&coefficients, &dummy_data)?;
        let final_gradient = objective.gradient(&coefficients, &dummy_data)?;
        let final_gradient_norm = final_gradient.mapv(|x| x * x).sum().sqrt();

        Ok(GradientDescentResult {
            coefficients,
            objective_value: final_objective,
            n_iterations: convergence_history.len(),
            converged,
            convergence_history: Array1::from_vec(convergence_history),
            gradient_norm_history: Array1::from_vec(gradient_norm_history),
            final_gradient_norm,
        })
    }

    fn supports_objective(&self, _objective: &dyn Objective) -> bool {
        true
    }

    fn name(&self) -> &'static str {
        "GradientDescent"
    }

    fn get_recommendations(&self, data: &ObjectiveData) -> SolverRecommendations {
        let n_samples = data.features.nrows();
        let n_features = data.features.ncols();

        let max_iter = if n_samples > 10000 { 100 } else { 1000 };
        let tolerance = if n_features > 1000 { 1e-4 } else { 1e-6 };
        let learning_rate = 1.0 / (n_samples as Float).sqrt();

        SolverRecommendations {
            max_iterations: Some(max_iter),
            tolerance: Some(tolerance),
            step_size: Some(learning_rate),
            use_line_search: Some(n_features > 100),
            notes: vec![
                format!(
                    "Problem size: {} samples, {} features",
                    n_samples, n_features
                ),
                "Consider using line search for better convergence".to_string(),
            ],
        }
    }
}

impl GradientDescentSolver {
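    /// Backtracking line search along the negative gradient direction.
    ///
    /// Starting from `initial_step_scale`, the step `t` is shrunk by
    /// `step_reduction_factor` until the Armijo sufficient-decrease condition
    /// `f(x - t * g) <= f(x) - c1 * t * (g . g)` holds; if the iteration limit
    /// is reached first, the most recently reduced step is returned.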
    fn line_search(
        &self,
        objective: &dyn Objective,
        x: &Array1<Float>,
        direction: &Array1<Float>,
        data: &ObjectiveData,
        config: &LineSearchConfig,
    ) -> Result<Float> {
        let f0 = objective.value(x, data)?;
        let grad0 = objective.gradient(x, data)?;
        let slope = grad0.dot(direction);

        let mut step_size = config.initial_step_scale;

        for _ in 0..config.max_line_search_iterations {
            let x_new = x - step_size * direction;
            let f_new = objective.value(&x_new, data)?;

            // Armijo sufficient-decrease test. The trial point is x - t * direction,
            // so the directional derivative along the step is -slope.
            if f_new <= f0 - config.c1 * step_size * slope {
                return Ok(step_size);
            }

            step_size *= config.step_reduction_factor;
        }

        Ok(step_size)
    }
}

#[derive(Debug, Clone)]
pub struct CoordinateDescentConfig {
    pub max_iterations: usize,
    pub tolerance: Float,
    pub random_selection: bool,
    pub random_seed: Option<u64>,
    pub verbose: bool,
}

impl Default for CoordinateDescentConfig {
    fn default() -> Self {
        Self {
            max_iterations: 1000,
            tolerance: 1e-6,
            random_selection: false,
            random_seed: None,
            verbose: false,
        }
    }
}

#[derive(Debug, Clone)]
pub struct CoordinateDescentResult {
    pub coefficients: Array1<Float>,
    pub objective_value: Float,
    pub n_iterations: usize,
    pub converged: bool,
    pub convergence_history: Array1<Float>,
    pub n_coordinate_updates: usize,
}
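
/// Coordinate descent that takes a small gradient step along one coordinate
/// at a time.
///
/// Each sweep visits every coordinate `j` and applies
/// `w[j] <- w[j] - eta * (d f / d w[j])` with a fixed internal step size;
/// convergence is declared once the largest single-coordinate change within a
/// sweep falls below the configured tolerance.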
#[derive(Debug)]
pub struct CoordinateDescentSolver;

impl OptimizationSolver for CoordinateDescentSolver {
    type Config = CoordinateDescentConfig;
    type Result = CoordinateDescentResult;

    fn solve(
        &self,
        objective: &dyn Objective,
        initial_guess: &Array1<Float>,
        config: &Self::Config,
    ) -> Result<Self::Result> {
        let mut coefficients = initial_guess.clone();
        let n_features = coefficients.len();
        let mut convergence_history = Vec::new();
        let mut converged = false;
        let mut coordinate_updates = 0;

        // Coordinate visitation order: shuffled with a simple seeded
        // Fisher-Yates pass (LCG step, no external RNG dependency) when
        // random selection is requested, otherwise cyclic.
        let mut coord_order: Vec<usize> = (0..n_features).collect();
        if config.random_selection {
            let mut state = config.random_seed.unwrap_or(0x9E37_79B9_7F4A_7C15);
            for i in (1..n_features).rev() {
                state = state
                    .wrapping_mul(6364136223846793005)
                    .wrapping_add(1442695040888963407);
                coord_order.swap(i, (state % (i as u64 + 1)) as usize);
            }
        }

        // Placeholder data, mirroring the gradient descent solver: the
        // objective is evaluated against a single row of zeros.
        let dummy_data = ObjectiveData {
            features: Array2::zeros((1, n_features)),
            targets: Array1::zeros(1),
            sample_weights: None,
            metadata: Default::default(),
        };

        for iteration in 0..config.max_iterations {
            let _obj_value_start = objective.value(&coefficients, &dummy_data)?;
            let mut max_coordinate_change: f64 = 0.0;

            for &coord_idx in &coord_order {
                let old_value = coefficients[coord_idx];

                // Recompute the full gradient and extract the relevant component.
                let gradient = objective.gradient(&coefficients, &dummy_data)?;
                let coord_gradient = gradient[coord_idx];

                // Fixed step size along the selected coordinate.
                let learning_rate = 0.01;
                let new_value = old_value - learning_rate * coord_gradient;

                coefficients[coord_idx] = new_value;
                coordinate_updates += 1;

                let change = (new_value - old_value).abs();
                max_coordinate_change = max_coordinate_change.max(change);
            }

            let obj_value_end = objective.value(&coefficients, &dummy_data)?;
            convergence_history.push(obj_value_end);

            if config.verbose && iteration % 100 == 0 {
                println!(
                    "Iteration {}: obj={:.6}, max_change={:.6}",
                    iteration, obj_value_end, max_coordinate_change
                );
            }

            if max_coordinate_change < config.tolerance {
                converged = true;
                if config.verbose {
                    println!("Converged after {} iterations", iteration);
                }
                break;
            }
        }

        let final_objective = objective.value(&coefficients, &dummy_data)?;

        Ok(CoordinateDescentResult {
            coefficients,
            objective_value: final_objective,
            n_iterations: convergence_history.len(),
            converged,
            convergence_history: Array1::from_vec(convergence_history),
            n_coordinate_updates: coordinate_updates,
        })
    }

    fn supports_objective(&self, _objective: &dyn Objective) -> bool {
        true
    }

    fn name(&self) -> &'static str {
        "CoordinateDescent"
    }

    fn get_recommendations(&self, data: &ObjectiveData) -> SolverRecommendations {
        let n_features = data.features.ncols();

        SolverRecommendations {
            max_iterations: Some(if n_features > 1000 { 100 } else { 1000 }),
            tolerance: Some(1e-6),
            step_size: None,
            use_line_search: Some(false),
            notes: vec![
                "Coordinate descent is particularly effective for L1-regularized problems"
                    .to_string(),
                "Consider random coordinate selection for large problems".to_string(),
            ],
        }
    }
}

#[derive(Debug, Clone)]
pub struct ProximalGradientConfig {
    pub max_iterations: usize,
    pub tolerance: Float,
    pub initial_step_size: Float,
    pub adaptive_step_size: bool,
    pub backtracking_config: BacktrackingConfig,
    pub verbose: bool,
}

impl Default for ProximalGradientConfig {
    fn default() -> Self {
        Self {
            max_iterations: 1000,
            tolerance: 1e-6,
            initial_step_size: 1.0,
            adaptive_step_size: true,
            backtracking_config: BacktrackingConfig::default(),
            verbose: false,
        }
    }
}

#[derive(Debug, Clone)]
pub struct BacktrackingConfig {
    pub beta: Float,
    pub sigma: Float,
    pub max_backtrack_iterations: usize,
}

impl Default for BacktrackingConfig {
    fn default() -> Self {
        Self {
            beta: 0.5,
            sigma: 0.01,
            max_backtrack_iterations: 50,
        }
    }
}

#[derive(Debug, Clone)]
pub struct ProximalGradientResult {
    pub coefficients: Array1<Float>,
    pub objective_value: Float,
    pub n_iterations: usize,
    pub converged: bool,
    pub convergence_history: Array1<Float>,
    pub step_size_history: Array1<Float>,
}
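
/// Proximal gradient (ISTA-style) solver.
///
/// Proximal gradient methods split the objective into a smooth part `f` and a
/// non-smooth part `h` and iterate `x <- prox_{t * h}(x - t * grad f(x))`.
/// This solver is currently a stub: `solve` returns an error because the
/// required smooth/non-smooth decomposition of [`Objective`] is not yet
/// available.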
#[derive(Debug)]
pub struct ProximalGradientSolver;

impl OptimizationSolver for ProximalGradientSolver {
    type Config = ProximalGradientConfig;
    type Result = ProximalGradientResult;

    fn solve(
        &self,
        _objective: &dyn Objective,
        _initial_guess: &Array1<Float>,
        _config: &Self::Config,
    ) -> Result<Self::Result> {
        Err(SklearsError::InvalidOperation(
            "Proximal gradient solver requires decomposition of the objective into smooth and non-smooth parts, which is not yet implemented"
                .to_string(),
        ))
    }

    fn supports_objective(&self, _objective: &dyn Objective) -> bool {
        false
    }

    fn name(&self) -> &'static str {
        "ProximalGradient"
    }

    fn get_recommendations(&self, _data: &ObjectiveData) -> SolverRecommendations {
        SolverRecommendations {
            max_iterations: Some(1000),
            tolerance: Some(1e-6),
            step_size: Some(1.0),
            use_line_search: Some(false),
            notes: vec![
                "Proximal gradient is ideal for problems with non-smooth regularization"
                    .to_string(),
                "Requires objective decomposition into smooth + non-smooth parts".to_string(),
            ],
        }
    }
}
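
/// Convenience constructors that return boxed solvers.
///
/// A minimal usage sketch using only the types defined in this module:
///
/// ```ignore
/// let solver = SolverFactory::gradient_descent();
/// assert_eq!(solver.name(), "GradientDescent");
/// ```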
pub struct SolverFactory;

impl SolverFactory {
    pub fn gradient_descent(
    ) -> Box<dyn OptimizationSolver<Config = GradientDescentConfig, Result = GradientDescentResult>>
    {
        Box::new(GradientDescentSolver)
    }

    pub fn coordinate_descent() -> Box<
        dyn OptimizationSolver<Config = CoordinateDescentConfig, Result = CoordinateDescentResult>,
    > {
        Box::new(CoordinateDescentSolver)
    }

    pub fn proximal_gradient(
    ) -> Box<dyn OptimizationSolver<Config = ProximalGradientConfig, Result = ProximalGradientResult>>
    {
        Box::new(ProximalGradientSolver)
    }
}
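
/// Adapts a solver-specific result into the framework-level
/// [`crate::modular_framework::OptimizationResult`].
///
/// Currently a placeholder: the incoming result is ignored and a zeroed,
/// non-converged result carrying only the solver name is returned.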
pub fn convert_solver_result_to_standard(
    _result: &dyn std::fmt::Debug,
    solver_name: &str,
) -> crate::modular_framework::OptimizationResult {
    crate::modular_framework::OptimizationResult {
        coefficients: Array1::zeros(1),
        intercept: None,
        objective_value: 0.0,
        n_iterations: 0,
        converged: false,
        solver_info: SolverInfo {
            solver_name: solver_name.to_string(),
            metrics: HashMap::new(),
            warnings: Vec::new(),
            convergence_history: None,
        },
    }
}

#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use crate::loss_functions::SquaredLoss;
    use crate::modular_framework::CompositeObjective;
    use crate::regularization_schemes::L2Regularization;

    fn create_test_objective() -> CompositeObjective<'static> {
        // Leak the boxed loss and regularizer to obtain 'static references
        // for the composite objective used in these tests.
        let loss = Box::leak(Box::new(SquaredLoss));
        let reg = Box::leak(Box::new(L2Regularization::new(0.1).unwrap()));
        CompositeObjective::new(loss, Some(reg))
    }

    #[test]
    fn test_gradient_descent_config() {
        let config = GradientDescentConfig::default();
        assert_eq!(config.max_iterations, 1000);
        assert_eq!(config.tolerance, 1e-6);
        assert_eq!(config.learning_rate, 0.01);
        assert!(!config.use_line_search);
    }

    #[test]
    fn test_coordinate_descent_config() {
        let config = CoordinateDescentConfig::default();
        assert_eq!(config.max_iterations, 1000);
        assert_eq!(config.tolerance, 1e-6);
        assert!(!config.random_selection);
    }

    #[test]
    fn test_solver_names() {
        let gd_solver = GradientDescentSolver;
        assert_eq!(gd_solver.name(), "GradientDescent");

        let cd_solver = CoordinateDescentSolver;
        assert_eq!(cd_solver.name(), "CoordinateDescent");

        let pg_solver = ProximalGradientSolver;
        assert_eq!(pg_solver.name(), "ProximalGradient");
    }

    #[test]
    fn test_solver_factory() {
        let gd = SolverFactory::gradient_descent();
        assert_eq!(gd.name(), "GradientDescent");

        let cd = SolverFactory::coordinate_descent();
        assert_eq!(cd.name(), "CoordinateDescent");

        let pg = SolverFactory::proximal_gradient();
        assert_eq!(pg.name(), "ProximalGradient");
    }

    #[test]
    fn test_solver_recommendations() {
        let solver = GradientDescentSolver;
        let data = ObjectiveData {
            features: Array2::zeros((100, 10)),
            targets: Array1::zeros(100),
            sample_weights: None,
            metadata: Default::default(),
        };

        let recommendations = solver.get_recommendations(&data);
        assert!(recommendations.max_iterations.is_some());
        assert!(recommendations.tolerance.is_some());
        assert!(recommendations.step_size.is_some());
    }

    #[test]
    fn test_line_search_config() {
        let config = LineSearchConfig::default();
        assert_eq!(config.c1, 1e-4);
        assert_eq!(config.c2, 0.9);
        assert_eq!(config.max_line_search_iterations, 20);
    }
}